Commit
·
a0e83ff
1
Parent(s):
200f4e7
testin the trick
Browse files- modeling_stacked.py +4 -4
modeling_stacked.py
CHANGED
@@ -53,7 +53,7 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
|
53 |
|
54 |
# Load floret model
|
55 |
self.dummy_param = nn.Parameter(torch.zeros(1))
|
56 |
-
self.
|
57 |
# print(self.config.config)
|
58 |
print(type(self.config))
|
59 |
print(self.config.config.filename)
|
@@ -66,7 +66,7 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
|
66 |
def forward(self, input_ids, attention_mask=None, **kwargs):
|
67 |
# Convert input_ids to strings using tokenizer
|
68 |
print(
|
69 |
-
f"Check if it arrives here: {input_ids}, ---, {type(input_ids)} ----- {type(self.
|
70 |
)
|
71 |
if isinstance(input_ids, str):
|
72 |
# If the input is a single string, make it a list for floret
|
@@ -77,7 +77,7 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
|
77 |
raise ValueError(f"Unexpected input type: {type(input_ids)}")
|
78 |
|
79 |
# Use the SafeFloretWrapper to get predictions
|
80 |
-
predictions, probabilities = self.
|
81 |
print(f"Predictions: {predictions}")
|
82 |
print(f"Probabilities: {probabilities}")
|
83 |
# print(self.model_floret(input_ids))
|
@@ -94,7 +94,7 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
|
94 |
# return torch.tensor(predictions)
|
95 |
# else:
|
96 |
# If no text is found, return dummy output
|
97 |
-
return
|
98 |
|
99 |
def state_dict(self, *args, **kwargs):
|
100 |
# Return an empty state dictionary
|
|
|
53 |
|
54 |
# Load floret model
|
55 |
self.dummy_param = nn.Parameter(torch.zeros(1))
|
56 |
+
self.model_floret = floret.load_model(self.config.config.filename)
|
57 |
# print(self.config.config)
|
58 |
print(type(self.config))
|
59 |
print(self.config.config.filename)
|
|
|
66 |
def forward(self, input_ids, attention_mask=None, **kwargs):
|
67 |
# Convert input_ids to strings using tokenizer
|
68 |
print(
|
69 |
+
f"Check if it arrives here: {input_ids}, ---, {type(input_ids)} ----- {type(self.model_floret)}"
|
70 |
)
|
71 |
if isinstance(input_ids, str):
|
72 |
# If the input is a single string, make it a list for floret
|
|
|
77 |
raise ValueError(f"Unexpected input type: {type(input_ids)}")
|
78 |
|
79 |
# Use the SafeFloretWrapper to get predictions
|
80 |
+
predictions, probabilities = self.model_floret.predict(texts, k=1)
|
81 |
print(f"Predictions: {predictions}")
|
82 |
print(f"Probabilities: {probabilities}")
|
83 |
# print(self.model_floret(input_ids))
|
|
|
94 |
# return torch.tensor(predictions)
|
95 |
# else:
|
96 |
# If no text is found, return dummy output
|
97 |
+
return predictions # Dummy tensor with shape (batch_size, num_classes)
|
98 |
|
99 |
def state_dict(self, *args, **kwargs):
|
100 |
# Return an empty state dictionary
|