testapp / app.py
thamada's picture
Create app.py
1788430 verified
raw
history blame
1.24 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
def run_LLM(model, tokenizer, streamer, prompt, max_new_tokens=300):
    """Generate a completion for *prompt* and return ``(text, n_tokens)``.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: tokenizer exposing ``encode`` / ``decode``.
        streamer: optional streamer forwarded to ``generate`` for incremental
            output (previously accepted but never used — bug fix); ``None``
            disables streaming.
        prompt: input text to complete.
        max_new_tokens: generation budget. Was hard-coded to 3,000,000
            (the commented-out line showed the intended 300), which in
            practice never terminates; now a parameter defaulting to 300.

    Returns:
        tuple: ``(output_text, n_tokens)`` — decoded sequence (prompt
        included, since the full ``output_ids[0]`` is decoded) and its
        total token count.
    """
    token_ids = tokenizer.encode(prompt, return_tensors="pt")
    output_ids = model.generate(
        input_ids=token_ids.to(model.device),
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.8,
        streamer=streamer,  # forward so a TextStreamer actually streams
    )
    n_tokens = len(output_ids[0])
    output_text = tokenizer.decode(output_ids[0])
    return (output_text, n_tokens)
def display_message():
    """Run one hard-coded Japanese prompt through calm2-7b-chat and return the text.

    NOTE(review): the 7B model and tokenizer are re-downloaded/re-loaded on
    every Gradio invocation; for a served app they should be loaded once at
    module import time — confirm before deploying.
    """
    model = AutoModelForCausalLM.from_pretrained(
        "cyberagent/calm2-7b-chat",
        device_map="cuda",  # assumes a CUDA device is present — TODO confirm
        torch_dtype="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm2-7b-chat")
    # Bug fix: TextStreamer was referenced without being imported (NameError
    # at runtime); it is now part of the top-of-file transformers import.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    prompt = """わが国の経済について今後の予想を教えてください。
ASSISTANT: """
    (result, n_tokens) = run_LLM(model, tokenizer, streamer, prompt)
    return result
if __name__ == '__main__':
    # Minimal Gradio app: no input widgets, display_message's string as output.
    demo = gr.Interface(fn=display_message, inputs=None, outputs="text")
    demo.launch()