vtrv.vls committed
Commit 2375d69 · 1 parent: 57a81f5
Files changed (3)
  1. app.py +5 -7
  2. models.py +18 -14
  3. utils.py +1 -1
app.py CHANGED
@@ -6,7 +6,7 @@ from datetime import datetime
 import pandas as pd
 
 from utils import generate, send_to_s3
-from models import get_tiny_llama, response_tiny_llama
+from models import get_tinyllama, response_tinyllama
 from constants import css, js_code, js_light
 
 MERA_table = None
@@ -16,18 +16,16 @@ S3_SESSION = None
 
 def giga_gen(content, chat_history):
     chat_history.append([content])
-    print(chat_history)
     res = generate(chat_history,'auth_token.json')
     chat_history[-1].append(res)
     send_to_s3(res, f'protobench/giga_{str(datetime.now()).replace(" ", "_")}.json', S3_SESSION)
-    print(chat_history)
     return '', chat_history
 
 def tiny_gen(content, chat_history):
-    res = response_tiny_llama(TINY_LLAMA, content)
-    chat_history.append((content, res))
+    chat_history.append([content])
+    res = response_tinyllama(TINY_LLAMA, content)
+    chat_history[-1].append(res)
     send_to_s3(res, f'protobench/tiny_{str(datetime.now()).replace(" ", "_")}.json', S3_SESSION)
-    print(chat_history)
     return '', chat_history
 
 def tab_arena():
@@ -158,7 +156,7 @@ if __name__ == "__main__":
     # data_load(args.result_file)
     # TYPES = ["number", "markdown", "number"]
 
-    TINY_LLAMA = get_tiny_llama()
+    TINY_LLAMA = get_tinyllama()
 
     try:
         session = boto3.session.Session()
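Net effect in app.py: the debug print(chat_history) calls are dropped, the TinyLlama helpers are renamed to match models.py, and tiny_gen now builds chat history the same way giga_gen does: open a [user] turn, generate, then append the reply to that turn. A minimal sketch of that shared shape, assuming nothing beyond the diff (fake_gen and its echo reply are hypothetical stand-ins; the Gradio wiring and S3 upload are omitted):

def fake_gen(content, chat_history):
    chat_history.append([content])   # open a new [user, assistant] turn
    res = f'echo: {content}'         # stand-in for the model call
    chat_history[-1].append(res)     # close the turn with the reply
    return '', chat_history

_, history = fake_gen('hello', [])
print(history)  # [['hello', 'echo: hello']]

One thing the diff leaves open: tiny_gen still passes the raw content string, while the rewritten response_tinyllama below iterates over [user, assistant] pairs, so passing chat_history instead may have been the intent.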
models.py CHANGED
@@ -1,23 +1,27 @@
 import torch
 from transformers import pipeline
 
-def get_tiny_llama():
-    pipe = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
-    return pipe
+def get_tinyllama():
+    tinyllama = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
+    return tinyllama
 
-def response_tiny_llama(
-    pipe=None,
-    content="How many helicopters can a human eat in one sitting?"
+def response_tinyllama(
+    model=None,
+    messages=None
 ):
-    # We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
-    messages = [
+
+    messages_dict = [
         {
             "role": "system",
-            "content": "You are a friendly chatbot who always responds in the style of a pirate",
-        },
-        {"role": "user", "content": content},
+            "content": "You are a friendly and helpful chatbot",
+        }
     ]
-    prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    outputs = pipe(prompt, max_new_tokens=32, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
+    for step in messages:
+        messages_dict.append({'role': 'user', 'content': step[0]})
+        if len(step) >= 2:
+            messages_dict.append({'role': 'assistant', 'content': step[1]})
 
-    return outputs[0]['generated_text'].split('<|assistant|>')[1]
+    prompt = model.tokenizer.apply_chat_template(messages_dict, tokenize=False, add_generation_prompt=True)
+    outputs = model(prompt, max_new_tokens=32, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
+
+    return outputs[0]['generated_text'].split('<|assistant|>')[1].strip()
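models.py replaces the hard-coded pirate prompt with a history-aware builder: every [user, assistant] pair in messages becomes two chat-template entries, a trailing user-only turn is left open, and add_generation_prompt=True asks the template for the final <|assistant|> cue. A hedged usage sketch, assuming the per-turn list shape used in app.py (the question string is illustrative):

from models import get_tinyllama, response_tinyllama

pipe = get_tinyllama()  # text-generation pipeline over TinyLlama-1.1B-Chat-v1.0

history = [['What is the capital of France?']]  # one open [user] turn
reply = response_tinyllama(pipe, history)       # text after '<|assistant|>'
print(reply)

One caveat worth flagging: split('<|assistant|>')[1] takes the first assistant segment. That is fine for a history with no completed turns, but once earlier [user, assistant] pairs are folded in, the prompt itself already contains <|assistant|> markers (TinyLlama uses a Zephyr-style template), and [1] returns an old reply rather than the new generation; [-1] would not.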
utils.py CHANGED
@@ -57,7 +57,7 @@ def generate(content=None, auth_file=None):
         auth_token = json.load(f)
 
     if datetime.fromtimestamp(auth_token['expires_at']/1000) <= datetime.now() - timedelta(seconds=60):
-        gen_auth_token()
+        gen_auth_token(auth_file)
     with open(auth_file) as f:
         auth_token = json.load(f)
 
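The utils.py fix threads auth_file through to gen_auth_token, so the refreshed token is written to the same file that is re-read two lines later. For reference, a standalone sketch of the refresh condition (token_expired is a hypothetical name; expires_at is assumed to be a Unix timestamp in milliseconds, hence the /1000):

from datetime import datetime, timedelta

def token_expired(expires_at_ms, skew_s=60):
    # Mirrors the diff: true only once the token has already been
    # expired for at least skew_s seconds.
    expires = datetime.fromtimestamp(expires_at_ms / 1000)
    return expires <= datetime.now() - timedelta(seconds=skew_s)

stale_ms = int((datetime.now() - timedelta(minutes=2)).timestamp() * 1000)
print(token_expired(stale_ms))  # True: expired two minutes ago, so refresh

Note the comparison direction: the check only fires once the token has been expired for a full minute; comparing against datetime.now() + timedelta(seconds=60) would refresh shortly before expiry instead, which is the more common use of a 60-second skew, but the commit keeps the original comparison.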