|
import functools

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
|
|
def run_LLM(model, tokenizer, streamer, prompt, *, max_new_tokens=3_000_000,
            temperature=0.8):
    """Sample a completion for ``prompt`` and return it with its token count.

    Args:
        model: Causal language model exposing ``generate`` and ``device``
            and, optionally, ``config.max_position_embeddings``.
        tokenizer: Tokenizer matching ``model``; used to encode the prompt
            and to decode the generated sequence.
        streamer: Streamer handed to ``model.generate`` so tokens can be
            emitted while they are produced. ``None`` disables streaming.
        prompt: Text to complete.
        max_new_tokens: Upper bound on the number of newly generated tokens.
            When the model config advertises a context window, the effective
            bound is also limited to the room left after the prompt.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        A ``(output_text, n_tokens)`` tuple: the decoded first sequence
        returned by ``model.generate`` and the number of tokens in that
        sequence. For decoder-only models that sequence starts with the
        prompt tokens, so both values cover the prompt as well.
    """
    token_ids = tokenizer.encode(prompt, return_tensors="pt")
    prompt_len = len(token_ids[0])

    # Requesting more tokens than the context window can hold is never
    # useful, so cap the budget at the space remaining after the prompt
    # (but always allow at least one token so generate() stays valid).
    context_window = getattr(
        getattr(model, "config", None), "max_position_embeddings", None
    )
    if context_window is not None:
        max_new_tokens = max(1, min(max_new_tokens, context_window - prompt_len))

    output_ids = model.generate(
        input_ids=token_ids.to(model.device),
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        # Forward the caller's streamer; previously it was silently ignored.
        streamer=streamer,
    )

    generated = output_ids[0]
    return (tokenizer.decode(generated), len(generated))
|
|
|
@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the chat model and its tokenizer once and reuse them afterwards."""
    model_name = "cyberagent/calm2-7b-chat"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cuda",
        torch_dtype="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def display_message() -> str:
    """Run the fixed demo prompt through the chat model and return the text.

    The model and tokenizer come from a cached loader so that repeated calls
    (one per UI request) do not reload the weights each time.

    Returns:
        The decoded text produced by ``run_LLM`` for the built-in prompt.
    """
    model, tokenizer = _load_model_and_tokenizer()
    # A streamer is cheap to build, so create a fresh one per call rather
    # than sharing its internal buffer between requests.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # The model's documented chat format is "USER: ...\nASSISTANT: "; the
    # user turn must carry its role prefix.
    prompt = """USER: わが国の経済について今後の予想を教えてください。
ASSISTANT: """

    # The token count returned alongside the text is not needed here.
    result, _ = run_LLM(model, tokenizer, streamer, prompt)
    return result
|
|
|
|
|
if __name__ == "__main__":
    # Expose display_message through a Gradio UI with no input widgets and a
    # single text output, then start serving it.
    demo = gr.Interface(
        fn=display_message,
        inputs=None,
        outputs="text",
    )
    demo.launch()
|
|