Commit
·
f729b09
1
Parent(s):
52d99b3
testin the trick
Browse files- modeling_stacked.py +14 -9
modeling_stacked.py
CHANGED
@@ -27,20 +27,20 @@ def get_info(label_map):
|
|
27 |
# return cls()
|
28 |
|
29 |
|
30 |
-
class SafeFloretWrapper
|
31 |
"""
|
32 |
A safe wrapper for floret model that keeps it off-device to avoid segmentation faults.
|
|
|
33 |
"""
|
34 |
|
35 |
-
def __init__(self,
|
36 |
-
|
37 |
-
self.
|
38 |
|
39 |
-
def
|
40 |
# Floret expects strings, not tensors
|
41 |
-
|
42 |
-
|
43 |
-
return torch.tensor(predictions)
|
44 |
|
45 |
|
46 |
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
@@ -53,7 +53,7 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
|
53 |
|
54 |
# Load floret model
|
55 |
self.dummy_param = nn.Parameter(torch.zeros(1))
|
56 |
-
self.
|
57 |
# self.model_floret = SafeFloretWrapper(model_floret)
|
58 |
# input_ids = "this is a text"
|
59 |
|
@@ -72,6 +72,11 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
|
72 |
texts = input_ids
|
73 |
else:
|
74 |
raise ValueError(f"Unexpected input type: {type(input_ids)}")
|
|
|
|
|
|
|
|
|
|
|
75 |
# print(self.model_floret(input_ids))
|
76 |
# if input_ids is not None:
|
77 |
# tokenizer = kwargs.get("tokenizer")
|
|
|
27 |
# return cls()
|
28 |
|
29 |
|
30 |
+
class SafeFloretWrapper:
|
31 |
"""
|
32 |
A safe wrapper for floret model that keeps it off-device to avoid segmentation faults.
|
33 |
+
This class is pure Python and never interacts with PyTorch tensors or devices.
|
34 |
"""
|
35 |
|
36 |
+
def __init__(self, model_path):
|
37 |
+
print(f"Loading floret model from {model_path}")
|
38 |
+
self.model_floret = floret.load_model(model_path)
|
39 |
|
40 |
+
def predict(self, texts, k=1):
|
41 |
# Floret expects strings, not tensors
|
42 |
+
predictions, probabilities = self.model_floret.predict(texts, k=k)
|
43 |
+
return predictions, probabilities
|
|
|
44 |
|
45 |
|
46 |
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
|
|
53 |
|
54 |
# Load floret model
|
55 |
self.dummy_param = nn.Parameter(torch.zeros(1))
|
56 |
+
self.safe_floret = SafeFloretWrapper(self.config.filename)
|
57 |
# self.model_floret = SafeFloretWrapper(model_floret)
|
58 |
# input_ids = "this is a text"
|
59 |
|
|
|
72 |
texts = input_ids
|
73 |
else:
|
74 |
raise ValueError(f"Unexpected input type: {type(input_ids)}")
|
75 |
+
|
76 |
+
# Use the SafeFloretWrapper to get predictions
|
77 |
+
predictions, probabilities = self.safe_floret.predict(texts)
|
78 |
+
print(f"Predictions: {predictions}")
|
79 |
+
print(f"Probabilities: {probabilities}")
|
80 |
# print(self.model_floret(input_ids))
|
81 |
# if input_ids is not None:
|
82 |
# tokenizer = kwargs.get("tokenizer")
|