Calvin committed on
Commit
229734c
·
1 Parent(s): ec3ec03

updated model and interface

Browse files
Files changed (1) hide show
  1. app.py +53 -16
app.py CHANGED
@@ -48,8 +48,8 @@ tokenizer.pad_token = tokenizer.eos_token
48
  model = PeftModel.from_pretrained(model, PEFT_MODEL)
49
 
50
  generation_config = model.generation_config
51
- generation_config.max_new_tokens = 200
52
- generation_config.temperature = 0.7
53
  generation_config.top_p = 0.7
54
  generation_config.num_return_sequences = 1
55
  generation_config.pad_token_id = tokenizer.eos_token_id
@@ -61,21 +61,58 @@ pipeline = transformers.pipeline(
61
  tokenizer=tokenizer,
62
  )
63
 
64
- def query_model(message, history):
65
 
66
- prompt = f"""
67
- <human>: {message}
68
- <assistant>:
69
- """.strip()
70
-
71
- result = pipeline(
72
- prompt,
73
- generation_config=generation_config,
74
- )
75
 
76
- # parsed_result = result[0]["generated_text"].split("<assistant>:")[1][1:]
 
77
 
78
- return parsed_result
 
 
 
 
 
79
 
80
-
81
- gr.ChatInterface(query_model, textbox=gr.Textbox(placeholder="Ask anything about Fetch!", container=False, scale=7),).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Attach the fine-tuned PEFT/LoRA adapter weights on top of the base model.
model = PeftModel.from_pretrained(model, PEFT_MODEL)

# Decoding defaults for the text-generation pipeline.
# NOTE(review): the UI sliders defined in main() mutate this object in
# place, so these values are only the initial settings.
generation_config = model.generation_config
generation_config.max_new_tokens = 150  # cap on generated reply length
generation_config.temperature = 0.6  # sampling temperature (UI: "Creativity")
generation_config.top_p = 0.7  # nucleus-sampling cutoff
generation_config.num_return_sequences = 1  # one candidate reply per prompt
generation_config.pad_token_id = tokenizer.eos_token_id  # pad with EOS since the tokenizer has no pad token
 
61
  tokenizer=tokenizer,
62
  )
63
 
64
def main():
    """Build and launch the Gradio Blocks UI for the Fetch Rewards chatbot.

    The interface consists of a chat log, two sliders that mutate the
    module-level ``generation_config`` in place (temperature and
    max_new_tokens), a message textbox, and a clear button. Bot replies are
    streamed back one character at a time for a typing effect.

    Side effects: starts the Gradio queue and launches the web server.
    """
    with gr.Blocks() as demo:

        def update_temp(temp):
            # Mutates the shared generation config used by `pipeline`.
            generation_config.temperature = temp

        def update_tokens(tokens):
            # Mutates the shared generation config used by `pipeline`.
            generation_config.max_new_tokens = tokens

        chatbot = gr.Chatbot(label="Fetch Rewards Chatbot")

        temperature = gr.Slider(0, 1, value=0.6, step=0.1, label="Creativity", interactive=True)
        temperature.change(fn=update_temp, inputs=temperature)

        # Default matches generation_config.max_new_tokens (150). The
        # previous default of 100 misrepresented the active setting until
        # the slider was first moved.
        tokens = gr.Slider(50, 200, value=150, step=50, label="Length", interactive=True)
        tokens.change(fn=update_tokens, inputs=tokens)

        msg = gr.Textbox(label="", placeholder="Ask anything about Fetch!")
        clear = gr.Button("Clear Log")

        def user(user_message, history):
            # Append the user turn with a placeholder (None) assistant slot
            # that bot() fills in; clear the textbox by returning "".
            return "", history + [[user_message, None]]

        def bot(history):
            # Generate a reply for the most recent user message and yield
            # the growing history so Gradio renders a typing animation.
            message = history[-1][0]
            prompt = f"""
<human>: {message}
<assistant>:
""".strip()

            result = pipeline(
                prompt,
                generation_config=generation_config,
            )
            generated = result[0]["generated_text"]

            # Take the first line after the "<assistant>:" tag. The original
            # indexed split(...)[1] blindly and raised IndexError whenever the
            # model output did not contain the tag; fall back to the raw text.
            if "<assistant>:" in generated:
                parsed_result = generated.split("<assistant>:")[1][1:].split("\n")[0]
            else:
                parsed_result = generated.strip()

            history[-1][1] = ""
            for character in parsed_result:
                history[-1][1] += character
                time.sleep(0.01)
                yield history

        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(lambda: None, None, chatbot, queue=False)

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()