Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -197,29 +197,44 @@ if st.button("Predict Personality"):
|
|
197 |
if not displayed:
|
198 |
st.write("No predictions exceed the confidence threshold.")
|
199 |
"""
|
|
|
200 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
201 |
import torch
|
202 |
|
203 |
-
#
|
|
|
|
|
|
|
|
|
204 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
205 |
|
206 |
-
#
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
209 |
|
210 |
-
|
211 |
-
|
|
|
|
|
212 |
|
213 |
-
#
|
214 |
-
|
|
|
215 |
|
216 |
-
#
|
217 |
-
|
|
|
|
|
218 |
|
219 |
-
#
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
|
224 |
-
|
225 |
-
|
|
|
197 |
if not displayed:
|
198 |
st.write("No predictions exceed the confidence threshold.")
|
199 |
"""
|
200 |
+
import streamlit as st
|
201 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
202 |
import torch
|
203 |
|
204 |
+
# Initialize the tokenizer and model
|
205 |
+
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
|
206 |
+
model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
|
207 |
+
|
208 |
+
# Move model to appropriate device
|
209 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
210 |
+
model.to(device)
|
211 |
|
212 |
+
# Streamlit app
|
213 |
+
def main():
|
214 |
+
st.title("Natural Language Inference with BART")
|
215 |
+
|
216 |
+
# Text input for premise and hypothesis
|
217 |
+
premise = st.text_area("Enter the premise:", value="", height=150)
|
218 |
+
hypothesis = st.text_input("Enter the hypothesis:")
|
219 |
|
220 |
+
if st.button("Analyze"):
|
221 |
+
if premise and hypothesis:
|
222 |
+
# Tokenize and encode the premise and hypothesis
|
223 |
+
inputs = tokenizer.encode(premise, hypothesis, return_tensors='pt', truncation_strategy='only_first').to(device)
|
224 |
|
225 |
+
# Model inference
|
226 |
+
with torch.no_grad():
|
227 |
+
logits = model(inputs)[0]
|
228 |
|
229 |
+
# Calculate probabilities
|
230 |
+
entail_contradiction_logits = logits[:, [0, 2]]
|
231 |
+
probs = entail_contradiction_logits.softmax(dim=1)
|
232 |
+
prob_label_is_true = probs[:, 1].item()
|
233 |
|
234 |
+
# Display results
|
235 |
+
st.write(f"Probability of Entailment: {prob_label_is_true:.4f}")
|
236 |
+
else:
|
237 |
+
st.error("Please enter both a premise and a hypothesis.")
|
238 |
|
239 |
+
if __name__ == "__main__":
|
240 |
+
main()
|