w601sxs commited on
Commit
84280ca
·
1 Parent(s): e477c81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -7,6 +7,13 @@ from trl import SFTTrainer
7
  # import torch
8
  from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
9
 
 
 
 
 
 
 
 
10
  class KeywordsStoppingCriteria(StoppingCriteria):
11
  def __init__(self, keywords_ids:list):
12
  self.keywords = keywords_ids
@@ -42,11 +49,7 @@ probs_to_label = [
42
 
43
  ]
44
 
45
- ref_model = AutoModelForCausalLM.from_pretrained("w601sxs/b1ade-1b", torch_dtype=torch.bfloat16)
46
 
47
- tokenizer = AutoTokenizer.from_pretrained("w601sxs/b1ade-1b")
48
-
49
- ref_model.eval()
50
 
51
 
52
  def get_tokens_and_labels(prompt):
 
7
  # import torch
8
  from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
9
 
10
+
11
+ ref_model = AutoModelForCausalLM.from_pretrained("w601sxs/b1ade-1b", torch_dtype=torch.bfloat16)
12
+
13
+ tokenizer = AutoTokenizer.from_pretrained("w601sxs/b1ade-1b")
14
+
15
+ ref_model.eval()
16
+
17
  class KeywordsStoppingCriteria(StoppingCriteria):
18
  def __init__(self, keywords_ids:list):
19
  self.keywords = keywords_ids
 
49
 
50
  ]
51
 
 
52
 
 
 
 
53
 
54
 
55
  def get_tokens_and_labels(prompt):