w601sxs committed
Commit 1a4e31a · Parent: a261c47

Update app.py

Files changed (1):
  1. app.py (+46 -13)
app.py CHANGED
@@ -2,19 +2,18 @@ import gradio as gr
 import torch
 from peft import PeftModel, PeftConfig, LoraConfig
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from datasets import load_dataset
+from trl import SFTTrainer
 # import torch
 from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
-import numpy as np
-import spacy
-from spacy import displacy
-from spacy.tokens import Span
-from spacy.tokens import Doc
+
 
 ref_model = AutoModelForCausalLM.from_pretrained("w601sxs/b1ade-1b", torch_dtype=torch.bfloat16)
+ref_model = ref_model.to('cuda')
+ref_model.eval()
 
 tokenizer = AutoTokenizer.from_pretrained("w601sxs/b1ade-1b")
 
-ref_model.eval()
 
 class KeywordsStoppingCriteria(StoppingCriteria):
     def __init__(self, keywords_ids:list):
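The hunk cuts `KeywordsStoppingCriteria` off after `__init__`. A minimal sketch of how such a keyword stopping criterion is typically completed (an illustration under that assumption, not the file's actual body):

```python
import torch
from transformers import StoppingCriteria

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids: list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the most recently generated token is a keyword id.
        return input_ids[0][-1].item() in self.keywords
```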
@@ -30,6 +29,7 @@ stop_words = ['>', ' >','> ']
 stop_ids = [tokenizer.encode(w)[0] for w in stop_words]
 stop_criteria = KeywordsStoppingCriteria(stop_ids)
 
+import numpy as np
 
 if tokenizer.pad_token_id is None:
     tokenizer.pad_token_id = tokenizer.eos_token_id
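For context, `stop_criteria` only takes effect when handed to `generate` via `StoppingCriteriaList` (imported at the top of the file). A sketch of that wiring; the file's real call, partially visible in the next hunk, passes more arguments:

```python
from transformers import StoppingCriteriaList

# Sketch: hand the criterion to generate() so decoding halts at '>'.
# Uses ref_model, tokenizer and stop_criteria as defined in app.py.
inputs = tokenizer(["Hello"], return_tensors="pt").to("cuda")
outputs = ref_model.generate(
    **inputs,
    max_new_tokens=32,
    stopping_criteria=StoppingCriteriaList([stop_criteria]),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```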
@@ -43,16 +43,14 @@ probs_to_label = [
     (0.5, "50%"),
     (0.1, "10%"),
     (0.01, "1%"),
+
 ]
-
-
-
-
+import numpy as np
 def get_tokens_and_labels(prompt):
     """
     Given the prompt (text), return a list of tuples (decoded_token, label)
     """
-    inputs = tokenizer([prompt], return_tensors="pt")
+    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
     outputs = ref_model.generate(
         **inputs,
         max_new_tokens=1000,
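`probs_to_label` pairs probability thresholds with display labels (the color map later in the diff suggests the full list also has 99%, 95% and 90% buckets). One plausible reading of the elided labelling step in `get_tokens_and_labels`, as a sketch rather than the committed code:

```python
# Sketch: bucket a token probability into the first matching label,
# assuming probs_to_label is ordered from high to low threshold.
def label_for_prob(prob, probs_to_label):
    for threshold, label in probs_to_label:
        if prob >= threshold:
            return label
    return None  # below every threshold: token stays unlabelled

label_for_prob(0.62, [(0.5, "50%"), (0.1, "10%"), (0.01, "1%")])  # -> "50%"
```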
@@ -91,10 +89,45 @@ def get_tokens_and_labels(prompt):
 
     return highlighted_out
 
+import spacy
+from spacy import displacy
+from spacy.tokens import Span
+from spacy.tokens import Doc
+
+def render_output(prompt):
+    output = get_tokens_and_labels(prompt)
+    nlp = spacy.blank("en")
+    doc = nlp(''.join([a[0] for a in output]).replace('Ġ',' ').replace('Ċ','\n'))
+    words = [a[0].replace('Ġ',' ').replace('Ċ','\n') for a in output]#[:indices[2]]
+    doc = Doc(nlp.vocab, words=words)
+
+    doc.spans["sc"]=[]
+    c = 0
+
+    for outs in output:
+        tmpouts = outs[0].replace('Ġ','').replace('Ċ','\n')
+        # print(c, "to", c+len(tmpouts)," : ", tmpouts)
+
+        if outs[1] is not None:
+            doc.spans["sc"].append(Span(doc, c, c+1, outs[1] ))
+
+        c+=1
 
+        # if c>indices[2]-1:
+        #     break
 
 
+    options = {'colors' : {
+        '99%': '#44ce1b',
+        '95%': '#bbdb44',
+        '90%': '#f7e379',
+        '50%': '#fec12a',
+        '10%': '#f2a134',
+        '1%': '#e51f1f',
+        '': '#e51f1f',
+    }}
 
+    return displacy.render(doc, style="span", options = options)
 
 
 def predict(text):
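In `render_output`, `Ġ` and `Ċ` are GPT-2-style byte-level BPE markers for a leading space and a newline, hence the replacements before the tokens become `Doc` words. The displacy span-rendering pattern the function relies on, shown in isolation with toy tokens instead of model output:

```python
import spacy
from spacy import displacy
from spacy.tokens import Doc, Span

# One Doc word per token; one labelled Span per token to highlight.
nlp = spacy.blank("en")
doc = Doc(nlp.vocab, words=["The", "sky", "is", "blue"])
doc.spans["sc"] = [Span(doc, 3, 4, "90%")]  # label the token "blue"
html = displacy.render(doc, style="span",
                       options={"colors": {"90%": "#f7e379"}})
```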
@@ -107,9 +140,9 @@ def predict(text):
 
 
 demo = gr.Interface(
-    fn=get_tokens_and_labels,
+    fn=render_output,
     inputs='text',
-    outputs='text',
+    outputs='html',
 )
 
 demo.launch()
 
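With `fn=render_output` and `outputs='html'`, Gradio injects the displacy markup directly into the page. A quick headless smoke test of the same pipeline (assumes a CUDA device is available, since the commit moves both the model and the inputs to 'cuda'):

```python
# Sketch: exercise the pipeline without the UI and inspect the spans.
if __name__ == "__main__":
    html = render_output("What is the capital of France?")
    with open("preview.html", "w") as f:
        f.write(html)
```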