vtrv.vls committed on
Commit
6f92fa3
·
1 Parent(s): 6609139
Files changed (2) hide show
  1. app.py +12 -3
  2. models.py +26 -1
app.py CHANGED
@@ -6,11 +6,12 @@ from datetime import datetime
6
  import pandas as pd
7
 
8
  from utils import generate, send_to_s3
9
- from models import get_tinyllama, response_tinyllama
10
  from constants import css, js_code, js_light
11
 
12
  MERA_table = None
13
- TINY_LLAMA = None
 
14
 
15
  S3_SESSION = None
16
 
@@ -28,6 +29,14 @@ def tiny_gen(content, chat_history):
28
  send_to_s3(res, f'protobench/tiny_{str(datetime.now()).replace(" ", "_")}.json', S3_SESSION)
29
  return '', chat_history
30
 
 
 
 
 
 
 
 
 
31
  def tab_arena():
32
  with gradio.Row():
33
  with gradio.Column():
@@ -51,7 +60,7 @@ def tab_arena():
51
  # return "", chat_history
52
 
53
  msg.submit(giga_gen, [msg, chatbot_left], [msg, chatbot_left])
54
- msg.submit(tiny_gen, [msg, chatbot_right], [msg, chatbot_right])
55
 
56
  # with gradio.Column():
57
  # gradio.ChatInterface(
 
6
  import pandas as pd
7
 
8
  from utils import generate, send_to_s3
9
+ from models import get_tinyllama, response_tinyllama, response_qwen2ins1b
10
  from constants import css, js_code, js_light
11
 
12
  MERA_table = None
13
+ TINYLLAMA = None
14
+ QWEN2INS1B = None
15
 
16
  S3_SESSION = None
17
 
 
29
  send_to_s3(res, f'protobench/tiny_{str(datetime.now()).replace(" ", "_")}.json', S3_SESSION)
30
  return '', chat_history
31
 
32
def qwen_gen(content, chat_history):
    """Generate a Qwen2-1.5B-Instruct reply for *content* and append it to the chat.

    Mirrors tiny_gen: records the user turn, queries the model over the full
    history, logs the raw response to S3, and returns ('', chat_history) so the
    Gradio message box is cleared.

    Args:
        content: the user's new message (str).
        chat_history: list of [user, assistant] turn pairs; mutated in place.

    Returns:
        ('', chat_history) — empty string clears the input widget.
    """
    # Open a new [user, assistant] turn; the assistant slot is filled below.
    chat_history.append([content])
    res = response_qwen2ins1b(QWEN2INS1B, chat_history)
    chat_history[-1].append(res)
    # BUG FIX: filename prefix was 'tiny_' (copy-paste from tiny_gen), which
    # mislabeled Qwen logs as TinyLlama logs in S3 — use 'qwen_' instead.
    send_to_s3(res, f'protobench/qwen_{str(datetime.now()).replace(" ", "_")}.json', S3_SESSION)
    return '', chat_history
38
+
39
+
40
  def tab_arena():
41
  with gradio.Row():
42
  with gradio.Column():
 
60
  # return "", chat_history
61
 
62
  msg.submit(giga_gen, [msg, chatbot_left], [msg, chatbot_left])
63
+ msg.submit(qwen_gen, [msg, chatbot_right], [msg, chatbot_right])
64
 
65
  # with gradio.Column():
66
  # gradio.ChatInterface(
models.py CHANGED
@@ -5,6 +5,10 @@ def get_tinyllama():
5
  tinyllama = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
6
  return tinyllama
7
 
 
 
 
 
8
  def response_tinyllama(
9
  model=None,
10
  messages=None
@@ -24,4 +28,25 @@ def response_tinyllama(
24
  prompt = model.tokenizer.apply_chat_template(messages_dict, tokenize=False, add_generation_prompt=True)
25
  outputs = model(prompt, max_new_tokens=64, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
26
 
27
- return outputs[0]['generated_text'].split('<|assistant|>')[1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  tinyllama = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
6
  return tinyllama
7
 
8
def get_qwen2ins1b():
    """Build and return the Qwen2-1.5B-Instruct text-generation pipeline.

    Loads the model in float16 with automatic device placement, matching
    the configuration used by get_tinyllama.
    """
    qwen_pipe = pipeline(
        "text-generation",
        model="Qwen/Qwen2-1.5B-Instruct",
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return qwen_pipe
11
+
12
  def response_tinyllama(
13
  model=None,
14
  messages=None
 
28
  prompt = model.tokenizer.apply_chat_template(messages_dict, tokenize=False, add_generation_prompt=True)
29
  outputs = model(prompt, max_new_tokens=64, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
30
 
31
+ return outputs[0]['generated_text'].split('<|assistant|>')[1].strip()
32
+
33
def response_qwen2ins1b(
    model=None,
    messages=None
):
    """Run the Qwen2-1.5B-Instruct pipeline over a chat history.

    Args:
        model: a transformers text-generation pipeline (see get_qwen2ins1b);
            its tokenizer must support apply_chat_template.
        messages: list of [user, assistant] turn pairs; the final turn may
            hold only the user message awaiting a reply.

    Returns:
        The pipeline's full generated text. NOTE(review): unlike
        response_tinyllama, the '<|assistant|>' split is disabled here, so
        the returned string presumably includes the prompt — confirm this
        is intentional.
    """
    # Seed the dialog with the fixed system persona, then replay each turn.
    dialog = [
        {
            "role": "system",
            "content": "You are a friendly and helpful chatbot",
        }
    ]
    for turn in messages:
        dialog.append({'role': 'user', 'content': turn[0]})
        # A completed turn carries the assistant's earlier reply as well.
        if len(turn) >= 2:
            dialog.append({'role': 'assistant', 'content': turn[1]})

    prompt = model.tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=True)
    generations = model(prompt, max_new_tokens=64, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)

    return generations[0]['generated_text']