Azure99 committed
Commit 86adf77
1 Parent(s): 99add8e

Update app.py

Files changed (1)
  1. app.py +42 -9
app.py CHANGED
@@ -6,18 +6,12 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 MAX_INPUT_LIMIT = 3584
+MAX_NEW_TOKENS = 1536
 MODEL_NAME = "Azure99/blossom-v5.1-9b"
 
 model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
-GENERATE_CONFIG = dict(
-    max_new_tokens=1536,
-    temperature=0.5,
-    top_p=0.85,
-    top_k=50,
-    repetition_penalty=1.05
-)
 
 def get_input_ids(inst, history):
     prefix = ("A chat between a human and an artificial intelligence bot. "
@@ -38,15 +32,22 @@ def get_input_ids(inst, history):
 
 
 @spaces.GPU
-def chat(inst, history):
+def chat(inst, history, temperature, top_p, repetition_penalty):
     with torch.no_grad():
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         input_ids = get_input_ids(inst, history)
         if len(input_ids) > MAX_INPUT_LIMIT:
             yield "The input is too long, please clear the history."
             return
+        generate_config = dict(
+            max_new_tokens=MAX_NEW_TOKENS,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty
+        )
+        print(generate_config)
         generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(model.device), do_sample=True,
-                                 streamer=streamer, **GENERATE_CONFIG)
+                                 streamer=streamer, **generate_config)
         Thread(target=model.generate, kwargs=generation_kwargs).start()
 
         outputs = ""
@@ -55,6 +56,36 @@ def chat(inst, history):
             yield outputs
 
 
+additional_inputs = [
+    gr.Slider(
+        label="Temperature",
+        value=0.5,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Controls randomness in choosing words.",
+    ),
+    gr.Slider(
+        label="Top-P",
+        value=0.85,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Picks words until their combined probability is at least top_p.",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.05,
+        minimum=1.0,
+        maximum=1.2,
+        step=0.01,
+        interactive=True,
+        info="Repetition Penalty: Controls how much repetition is penalized.",
+    )
+]
+
 gr.ChatInterface(chat,
                  chatbot=gr.Chatbot(show_label=False, height=500, show_copy_button=True, render_markdown=True),
                  textbox=gr.Textbox(placeholder="", container=False, scale=7),
@@ -63,6 +94,8 @@ gr.ChatInterface(chat,
                  '<a href="https://github.com/Azure99/BlossomLM">GitHub</a>',
                  theme="soft",
                  examples=["Hello", "What is MBTI", "用Python实现二分查找", "为switch写一篇小红书种草文案,带上emoji"],
+                 additional_inputs=additional_inputs,
+                 additional_inputs_accordion=gr.Accordion(label="Config", open=True),
                  clear_btn="🗑️Clear",
                  undo_btn="↩️Undo",
                  retry_btn="🔄Retry",
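A minimal sketch (not the Space's actual app.py; it assumes only gradio is installed and uses a dummy echo generator in place of the model) of the wiring this change relies on: gr.ChatInterface passes each component listed in additional_inputs to the callback as an extra positional argument after the message and history, in list order, which is why chat now takes temperature, top_p, and repetition_penalty.

# Minimal sketch, not the Space's actual code: a dummy generator stands in for
# the real streaming model call so the additional_inputs wiring runs on its own.
import gradio as gr

def chat(inst, history, temperature, top_p, repetition_penalty):
    # gr.ChatInterface calls this with (message, history) first, then one value
    # per additional_inputs component, in the same order as the list below.
    yield f"{inst} (temperature={temperature}, top_p={top_p}, repetition_penalty={repetition_penalty})"

additional_inputs = [
    gr.Slider(label="Temperature", value=0.5, minimum=0.0, maximum=1.0, step=0.05),
    gr.Slider(label="Top-P", value=0.85, minimum=0.0, maximum=1.0, step=0.05),
    gr.Slider(label="Repetition penalty", value=1.05, minimum=1.0, maximum=1.2, step=0.01),
]

gr.ChatInterface(chat,
                 additional_inputs=additional_inputs,
                 additional_inputs_accordion=gr.Accordion(label="Config", open=True)).launch()

The order of the components in additional_inputs must match the order of the extra parameters on chat; that positional contract is what lets the commit move the sampling settings from the module-level GENERATE_CONFIG into per-request arguments.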