Fralet commited on
Commit
2fe5850
1 Parent(s): 2f02ad5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -16
app.py CHANGED
@@ -197,29 +197,44 @@ if st.button("Predict Personality"):
197
  if not displayed:
198
  st.write("No predictions exceed the confidence threshold.")
199
  """
 
200
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
201
  import torch
202
 
203
- # Check if CUDA is available, otherwise use CPU
 
 
 
 
204
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
205
 
206
- # Load the model and tokenizer
207
- nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli').to(device)
208
- tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
 
 
 
 
209
 
210
- premise = 'A few years ago, I was juggling a demanding job, volunteer commitments, and personal relationships, all while trying to manage chronic health issues. The challenge was overwhelming at times, but I approached it by prioritizing open communication with my employer and loved ones about my limits. I learned to delegate and accept help, which was difficult for me as I usually prefer to keep the peace by handling things myself. This experience taught me the importance of setting boundaries and the strength in vulnerability.'
211
- hypothesis = 'This response is Helper personality.'
 
 
212
 
213
- # Tokenize the input text pair
214
- inputs = tokenizer.encode(premise, hypothesis, return_tensors='pt', truncation_strategy='only_first').to(device)
 
215
 
216
- # Perform inference
217
- logits = nli_model(inputs)[0]
 
 
218
 
219
- # Process logits to get probabilities
220
- entail_contradiction_logits = logits[:, [0, 2]]
221
- probs = entail_contradiction_logits.softmax(dim=1)
222
- prob_label_is_true = probs[:, 1]
223
 
224
- # Print the probability that the label is true
225
- print(prob_label_is_true)
 
197
  if not displayed:
198
  st.write("No predictions exceed the confidence threshold.")
199
  """
200
+ import streamlit as st
201
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
202
  import torch
203
 
204
+ # Initialize the tokenizer and model
205
+ tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
206
+ model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
207
+
208
+ # Move model to appropriate device
209
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
210
+ model.to(device)
211
 
212
+ # Streamlit app
213
+ def main():
214
+ st.title("Natural Language Inference with BART")
215
+
216
+ # Text input for premise and hypothesis
217
+ premise = st.text_area("Enter the premise:", value="", height=150)
218
+ hypothesis = st.text_input("Enter the hypothesis:")
219
 
220
+ if st.button("Analyze"):
221
+ if premise and hypothesis:
222
+ # Tokenize and encode the premise and hypothesis
223
+ inputs = tokenizer.encode(premise, hypothesis, return_tensors='pt', truncation_strategy='only_first').to(device)
224
 
225
+ # Model inference
226
+ with torch.no_grad():
227
+ logits = model(inputs)[0]
228
 
229
+ # Calculate probabilities
230
+ entail_contradiction_logits = logits[:, [0, 2]]
231
+ probs = entail_contradiction_logits.softmax(dim=1)
232
+ prob_label_is_true = probs[:, 1].item()
233
 
234
+ # Display results
235
+ st.write(f"Probability of Entailment: {prob_label_is_true:.4f}")
236
+ else:
237
+ st.error("Please enter both a premise and a hypothesis.")
238
 
239
+ if __name__ == "__main__":
240
+ main()