thamada committed on
Commit
1788430
1 Parent(s): 98ac3bd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -0
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
def run_LLM(model, tokenizer, streamer, prompt, *, max_new_tokens=3000000,
            temperature=0.8):
    """Sample a completion for ``prompt`` and return it with its token count.

    Args:
        model: Object exposing a ``device`` attribute and a ``generate``
            method that accepts ``input_ids``, ``max_new_tokens``,
            ``do_sample`` and ``temperature`` keyword arguments.
        tokenizer: Object exposing ``encode(text, return_tensors="pt")`` and
            ``decode(ids)``.
        streamer: Accepted so existing callers keep working; this function
            does not use it and does not forward it to ``model.generate``.
        prompt: Text that is encoded and fed to the model.
        max_new_tokens: Generation limit forwarded to ``model.generate``.
            The default is the value that used to be hard-coded here, so
            callers that do not pass it see no change.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        A ``(output_text, n_tokens)`` tuple: the decoded first sequence
        returned by ``model.generate`` and the length of that sequence.
        NOTE(review): for causal LMs that sequence presumably contains the
        prompt tokens as well as the generated ones - confirm before relying
        on ``n_tokens`` as a "generated tokens" count.
    """
    token_ids = tokenizer.encode(prompt, return_tensors="pt")

    output_ids = model.generate(
        # Inputs must live on the same device as the model weights.
        input_ids=token_ids.to(model.device),
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )

    first_sequence = output_ids[0]
    output_text = tokenizer.decode(first_sequence)

    return (output_text, len(first_sequence))
20
+
21
_MODEL_ID = "cyberagent/calm2-7b-chat"

# Filled on first use so the model weights are loaded once per process
# instead of once per request.
_LLM_CACHE = {}


def _load_llm():
    """Return the ``(model, tokenizer, streamer)`` triple, loading it once."""
    if "llm" not in _LLM_CACHE:
        # The file-level import only brings in AutoModelForCausalLM and
        # AutoTokenizer; without this import TextStreamer is an undefined name.
        from transformers import TextStreamer

        model = AutoModelForCausalLM.from_pretrained(
            _MODEL_ID,
            device_map="cuda",
            torch_dtype="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
        streamer = TextStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        _LLM_CACHE["llm"] = (model, tokenizer, streamer)
    return _LLM_CACHE["llm"]


def display_message():
    """Run the fixed demo prompt through the model and return the decoded text.

    Takes no arguments because the prompt is hard-coded. The model, tokenizer
    and streamer are obtained from ``_load_llm``, so only the first call pays
    the loading cost.

    Returns:
        The text element of the ``(text, n_tokens)`` tuple returned by
        ``run_LLM``; the token count is discarded.
    """
    model, tokenizer, streamer = _load_llm()

    # Same runtime string as the original two-line triple-quoted literal.
    prompt = (
        "わが国の経済について今後の予想を教えてください。\n"
        "ASSISTANT: "
    )

    result, _n_tokens = run_LLM(model, tokenizer, streamer, prompt)

    return result
34
+
35
+
36
if __name__ == "__main__":
    # Build a Gradio UI around display_message: no input components and a
    # single text output.
    iface = gr.Interface(
        fn=display_message,
        inputs=None,
        outputs="text",
    )
    # Start serving the interface.
    iface.launch()