emanuelaboros committed on
Commit
a0e83ff
·
1 Parent(s): 200f4e7

testing the trick

Browse files
Files changed (1) hide show
  1. modeling_stacked.py +4 -4
modeling_stacked.py CHANGED
@@ -53,7 +53,7 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
53
 
54
  # Load floret model
55
  self.dummy_param = nn.Parameter(torch.zeros(1))
56
- self.safe_floret = SafeFloretWrapper(self.config.config.filename)
57
  # print(self.config.config)
58
  print(type(self.config))
59
  print(self.config.config.filename)
@@ -66,7 +66,7 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
66
  def forward(self, input_ids, attention_mask=None, **kwargs):
67
  # Convert input_ids to strings using tokenizer
68
  print(
69
- f"Check if it arrives here: {input_ids}, ---, {type(input_ids)} ----- {type(self.safe_floret)}"
70
  )
71
  if isinstance(input_ids, str):
72
  # If the input is a single string, make it a list for floret
@@ -77,7 +77,7 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
77
  raise ValueError(f"Unexpected input type: {type(input_ids)}")
78
 
79
  # Use the SafeFloretWrapper to get predictions
80
- predictions, probabilities = self.safe_floret.predict(texts)
81
  print(f"Predictions: {predictions}")
82
  print(f"Probabilities: {probabilities}")
83
  # print(self.model_floret(input_ids))
@@ -94,7 +94,7 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
94
  # return torch.tensor(predictions)
95
  # else:
96
  # If no text is found, return dummy output
97
- return torch.zeros((1, 2)) # Dummy tensor with shape (batch_size, num_classes)
98
 
99
  def state_dict(self, *args, **kwargs):
100
  # Return an empty state dictionary
 
53
 
54
  # Load floret model
55
  self.dummy_param = nn.Parameter(torch.zeros(1))
56
+ self.model_floret = floret.load_model(self.config.config.filename)
57
  # print(self.config.config)
58
  print(type(self.config))
59
  print(self.config.config.filename)
 
66
  def forward(self, input_ids, attention_mask=None, **kwargs):
67
  # Convert input_ids to strings using tokenizer
68
  print(
69
+ f"Check if it arrives here: {input_ids}, ---, {type(input_ids)} ----- {type(self.model_floret)}"
70
  )
71
  if isinstance(input_ids, str):
72
  # If the input is a single string, make it a list for floret
 
77
  raise ValueError(f"Unexpected input type: {type(input_ids)}")
78
 
79
  # Use the SafeFloretWrapper to get predictions
80
+ predictions, probabilities = self.model_floret.predict(texts, k=1)
81
  print(f"Predictions: {predictions}")
82
  print(f"Probabilities: {probabilities}")
83
  # print(self.model_floret(input_ids))
 
94
  # return torch.tensor(predictions)
95
  # else:
96
  # If no text is found, return dummy output
97
+ return predictions # Dummy tensor with shape (batch_size, num_classes)
98
 
99
  def state_dict(self, *args, **kwargs):
100
  # Return an empty state dictionary