Update mT5Model.py
Browse files- mT5Model.py +2 -1
mT5Model.py
CHANGED
@@ -16,6 +16,8 @@ def runModel(model_name, sequence_to_classify, candidate_labels, hypothesis_temp
|
|
16 |
NEUTRAL_LABEL = "▁1"
|
17 |
CONTRADICTS_LABEL = "▁2"
|
18 |
|
|
|
|
|
19 |
label_inds = tokenizer.convert_tokens_to_ids([ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL])
|
20 |
|
21 |
# construct sequence of premise, hypothesis pairs
|
@@ -23,7 +25,6 @@ def runModel(model_name, sequence_to_classify, candidate_labels, hypothesis_temp
|
|
23 |
# format for mt5 xnli task
|
24 |
seqs = [process_nli(premise=premise, hypothesis=hypothesis) for premise, hypothesis in pairs]
|
25 |
|
26 |
-
model, tokenizer = setModel(model_name)
|
27 |
inputs = tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)
|
28 |
out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True, num_beams=1)
|
29 |
|
|
|
16 |
NEUTRAL_LABEL = "▁1"
|
17 |
CONTRADICTS_LABEL = "▁2"
|
18 |
|
19 |
+
model, tokenizer = setModel(model_name)
|
20 |
+
|
21 |
label_inds = tokenizer.convert_tokens_to_ids([ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL])
|
22 |
|
23 |
# construct sequence of premise, hypothesis pairs
|
|
|
25 |
# format for mt5 xnli task
|
26 |
seqs = [process_nli(premise=premise, hypothesis=hypothesis) for premise, hypothesis in pairs]
|
27 |
|
|
|
28 |
inputs = tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)
|
29 |
out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True, num_beams=1)
|
30 |
|