# Hugging Face Spaces page header captured with this file: "Spaces: Runtime error"
"""Streamlit demo: classify a premise/hypothesis pair with three models.

When the user presses "Submit", each model's predicted label
(entailment / neutral / contradiction) is written to the page.
"""
import streamlit as st
from utils import get_roberta, get_gpt, get_distilbert
import torch

st.title('Sentence Entailment')

col1, col2 = st.columns([1, 1])
with col1:
    sentence1 = st.text_input('Premise')
with col2:
    sentence2 = st.text_input('Hypothesis')

btn = st.button("Submit")

# Class index -> label name.
# NOTE(review): assumes all three fine-tuned models use this same label order;
# confirm against each model's config.
label_dict = {
    0: 'entailment',
    1: 'neutral',
    2: 'contradiction',
}


def _predict_label(tokenizer, model, texts, **tokenizer_kwargs):
    """Return the label name of the highest-scoring class for one input.

    Args:
        tokenizer: Tokenizer callable, invoked as
            ``tokenizer(*texts, return_tensors="pt", **tokenizer_kwargs)``.
        model: Model callable whose output is indexable by ``'logits'``.
        texts: Positional text arguments for the tokenizer, e.g.
            ``(premise, hypothesis)`` for pair encoding, or a 1-tuple holding
            an already-joined string.
        **tokenizer_kwargs: Extra keyword arguments forwarded to the tokenizer.

    Returns:
        The ``label_dict`` entry for the argmax over the model's logits.
    """
    inputs = tokenizer(*texts, return_tensors="pt", **tokenizer_kwargs)
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(**inputs)['logits']
    # argmax() with no dim flattens the logits; that is correct here because
    # exactly one example is encoded per call.
    return label_dict[logits.argmax().item()]


if btn:
    if not (sentence1.strip() and sentence2.strip()):
        # Running the models on blank text would still print a label, but one
        # with no meaning; ask for input instead.
        st.warning('Please enter both a premise and a hypothesis.')
    else:
        # RoBERTa and DistilBERT receive the two sentences as a text pair.
        pair_kwargs = {'padding': True, 'truncation': True, 'max_length': 512}

        roberta_tokenizer, roberta_model = get_roberta()
        st.write('ROBERTA', _predict_label(
            roberta_tokenizer, roberta_model, (sentence1, sentence2), **pair_kwargs))

        distilbert_tokenizer, distilbert_model = get_distilbert()
        st.write('DistilBERT', _predict_label(
            distilbert_tokenizer, distilbert_model, (sentence1, sentence2), **pair_kwargs))

        # GPT receives a single string with a literal ' [SEP] ' separator.
        # NOTE(review): padding='max_length' requires a pad token, which GPT-2
        # tokenizers do not define by default. Confirm that get_gpt() sets one,
        # otherwise this call raises.
        gpt_tokenizer, gpt_model = get_gpt()
        st.write('GPT', _predict_label(
            gpt_tokenizer, gpt_model, (sentence1 + ' [SEP] ' + sentence2,),
            truncation=True, padding='max_length', max_length=512))