Update mT5Model.py
Browse files- mT5Model.py +2 -0
mT5Model.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
from torch.nn.functional import softmax
|
2 |
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
|
|
|
3 |
|
4 |
def process_nli(premise: str, hypothesis: str):
|
5 |
""" process to required xnli format with task prefix """
|
6 |
return "".join(['xnli: premise: ', premise, ' hypothesis: ', hypothesis])
|
7 |
|
|
|
8 |
def setModel(model_name):
|
9 |
tokenizer = MT5Tokenizer.from_pretrained(model_name)
|
10 |
model = MT5ForConditionalGeneration.from_pretrained(model_name)
|
|
|
1 |
from torch.nn.functional import softmax
|
2 |
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
|
3 |
+
import streamlit as st
|
4 |
|
5 |
def process_nli(premise: str, hypothesis: str):
|
6 |
""" process to required xnli format with task prefix """
|
7 |
return "".join(['xnli: premise: ', premise, ' hypothesis: ', hypothesis])
|
8 |
|
9 |
+
@st.cache(allow_output_mutation=True)
|
10 |
def setModel(model_name):
|
11 |
tokenizer = MT5Tokenizer.from_pretrained(model_name)
|
12 |
model = MT5ForConditionalGeneration.from_pretrained(model_name)
|