Fralet commited on
Commit
2992308
1 Parent(s): 245e663

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -198,19 +198,28 @@ if st.button("Predict Personality"):
198
  st.write("No predictions exceed the confidence threshold.")
199
  """
200
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
201
- nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
 
 
 
 
 
 
202
  tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
203
 
204
  premise = 'A few years ago, I was juggling a demanding job, volunteer commitments, and personal relationships, all while trying to manage chronic health issues. The challenge was overwhelming at times, but I approached it by prioritizing open communication with my employer and loved ones about my limits. I learned to delegate and accept help, which was difficult for me as I usually prefer to keep the peace by handling things myself. This experience taught me the importance of setting boundaries and the strength in vulnerability.'
205
- hypothesis = f'This example is Helper.'
206
 
207
- # run through model pre-trained on MNLI
208
- x = tokenizer.encode(premise, hypothesis, return_tensors='pt',
209
- truncation_strategy='only_first')
210
- logits = nli_model(x.to(device))[0]
211
 
212
- # we throw away "neutral" (dim 1) and take the probability of
213
- # "entailment" (2) as the probability of the label being true
214
- entail_contradiction_logits = logits[:,[0,2]]
 
 
215
  probs = entail_contradiction_logits.softmax(dim=1)
216
- prob_label_is_true = probs[:,1]
 
 
 
 
198
  st.write("No predictions exceed the confidence threshold.")
199
  """
200
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
201
+ import torch
202
+
203
+ # Check if CUDA is available, otherwise use CPU
204
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
205
+
206
+ # Load the model and tokenizer
207
+ nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli').to(device)
208
  tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
209
 
210
  premise = 'A few years ago, I was juggling a demanding job, volunteer commitments, and personal relationships, all while trying to manage chronic health issues. The challenge was overwhelming at times, but I approached it by prioritizing open communication with my employer and loved ones about my limits. I learned to delegate and accept help, which was difficult for me as I usually prefer to keep the peace by handling things myself. This experience taught me the importance of setting boundaries and the strength in vulnerability.'
211
+ hypothesis = 'This example is Helper.'
212
 
213
+ # Tokenize the input text pair
214
+ inputs = tokenizer.encode(premise, hypothesis, return_tensors='pt', truncation_strategy='only_first').to(device)
 
 
215
 
216
+ # Perform inference
217
+ logits = nli_model(inputs)[0]
218
+
219
+ # Process logits to get probabilities
220
+ entail_contradiction_logits = logits[:, [0, 2]]
221
  probs = entail_contradiction_logits.softmax(dim=1)
222
+ prob_label_is_true = probs[:, 1]
223
+
224
+ # Print the probability that the label is true
225
+ print(prob_label_is_true)