emanuelaboros commited on
Commit
f729b09
·
1 Parent(s): 52d99b3

testin the trick

Browse files
Files changed (1) hide show
  1. modeling_stacked.py +14 -9
modeling_stacked.py CHANGED
@@ -27,20 +27,20 @@ def get_info(label_map):
27
  # return cls()
28
 
29
 
30
- class SafeFloretWrapper(nn.Module):
31
  """
32
  A safe wrapper for floret model that keeps it off-device to avoid segmentation faults.
 
33
  """
34
 
35
- def __init__(self, floret_model):
36
- super().__init__()
37
- self.floret_model = floret_model
38
 
39
- def forward(self, texts):
40
  # Floret expects strings, not tensors
41
- _, predictions = self.floret_model.predict([texts], k=1)
42
- # Convert predictions to tensors for Hugging Face compatibility
43
- return torch.tensor(predictions)
44
 
45
 
46
  class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
@@ -53,7 +53,7 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
53
 
54
  # Load floret model
55
  self.dummy_param = nn.Parameter(torch.zeros(1))
56
- self.model_floret = floret.load_model(self.config.filename)
57
  # self.model_floret = SafeFloretWrapper(model_floret)
58
  # input_ids = "this is a text"
59
 
@@ -72,6 +72,11 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
72
  texts = input_ids
73
  else:
74
  raise ValueError(f"Unexpected input type: {type(input_ids)}")
 
 
 
 
 
75
  # print(self.model_floret(input_ids))
76
  # if input_ids is not None:
77
  # tokenizer = kwargs.get("tokenizer")
 
27
  # return cls()
28
 
29
 
30
+ class SafeFloretWrapper:
31
  """
32
  A safe wrapper for floret model that keeps it off-device to avoid segmentation faults.
33
+ This class is pure Python and never interacts with PyTorch tensors or devices.
34
  """
35
 
36
+ def __init__(self, model_path):
37
+ print(f"Loading floret model from {model_path}")
38
+ self.model_floret = floret.load_model(model_path)
39
 
40
+ def predict(self, texts, k=1):
41
  # Floret expects strings, not tensors
42
+ predictions, probabilities = self.model_floret.predict(texts, k=k)
43
+ return predictions, probabilities
 
44
 
45
 
46
  class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
 
53
 
54
  # Load floret model
55
  self.dummy_param = nn.Parameter(torch.zeros(1))
56
+ self.safe_floret = SafeFloretWrapper(self.config.filename)
57
  # self.model_floret = SafeFloretWrapper(model_floret)
58
  # input_ids = "this is a text"
59
 
 
72
  texts = input_ids
73
  else:
74
  raise ValueError(f"Unexpected input type: {type(input_ids)}")
75
+
76
+ # Use the SafeFloretWrapper to get predictions
77
+ predictions, probabilities = self.safe_floret.predict(texts)
78
+ print(f"Predictions: {predictions}")
79
+ print(f"Probabilities: {probabilities}")
80
  # print(self.model_floret(input_ids))
81
  # if input_ids is not None:
82
  # tokenizer = kwargs.get("tokenizer")