Commit
·
5d957b8
1
Parent(s):
3dd948f
testin the trick
Browse files- modeling_stacked.py +2 -2
- test_floret.py +7 -0
modeling_stacked.py
CHANGED
@@ -41,8 +41,8 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
|
41 |
|
42 |
def forward(self, input_ids, attention_mask=None, **kwargs):
|
43 |
# Convert input_ids to strings using tokenizer
|
44 |
-
print(f"Check if it arrives here: {input_ids}")
|
45 |
-
predictions, probabilities = self.model_floret.predict([input_ids], k=1)
|
46 |
# if input_ids is not None:
|
47 |
# tokenizer = kwargs.get("tokenizer")
|
48 |
# texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
|
|
|
41 |
|
42 |
def forward(self, input_ids, attention_mask=None, **kwargs):
|
43 |
# Convert input_ids to strings using tokenizer
|
44 |
+
print(f"Check if it arrives here: {input_ids}, ---, {type(input_ids)}")
|
45 |
+
# predictions, probabilities = self.model_floret.predict([input_ids], k=1)
|
46 |
# if input_ids is not None:
|
47 |
# tokenizer = kwargs.get("tokenizer")
|
48 |
# texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
|
test_floret.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import floret
|
2 |
+
|
3 |
+
model_floret = floret.load_model("LID-40-3-2000000-1-4.bin")
|
4 |
+
|
5 |
+
input_ids = 'this is a text'
|
6 |
+
print(model_floret.predict([input_ids], k=1))
|
7 |
+
|