|
print("start to run")

import streamlit as st

import os
import subprocess
import sys

# NOTE(review): installing dependencies at runtime is fragile and slow — these
# belong in a requirements.txt. Kept for deployment compatibility, but:
#  * use `sys.executable -m pip` via subprocess with a list (no shell string),
#  * drop the duplicated "torch" entry and the bogus "torch.utils" name
#    (not a real PyPI package — a typosquatting risk) and the unused
#    "torchvision".
subprocess.run(
    [sys.executable, "-m", "pip", "install",
     "torch", "transformers", "sentencepiece", "accelerate"],
    check=False,  # best-effort, matching the original os.system behavior
)

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

print("[code] All module has imported.")
|
|
|
|
|
@st.cache_resource
def _load_model_and_tokenizer():
    """Load the open-calm-1b model and tokenizer exactly once.

    Streamlit re-executes the whole script on every widget interaction;
    without st.cache_resource the 1B-parameter model would be re-downloaded
    and re-loaded on every rerun.
    """
    model = AutoModelForCausalLM.from_pretrained(
        "cyberagent/open-calm-1b", device_map="auto", torch_dtype=torch.float16
    )
    tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-1b")
    return model, tokenizer


model, tokenizer = _load_model_and_tokenizer()

print("[code] model loaded")
|
|
|
def generate_text(input_text, max_new_tokens, temperature, top_p, repetition_penalty):
    """Generate a sampled continuation of *input_text* with the module-level model.

    Args:
        input_text: Prompt string fed to the tokenizer.
        max_new_tokens: Upper bound on the number of newly generated tokens.
        temperature: Sampling temperature (sampling is always on: do_sample=True).
        top_p: Nucleus-sampling probability cutoff.
        repetition_penalty: Multiplicative penalty discouraging repeated tokens.

    Returns:
        The decoded text, including the original prompt (the prompt tokens
        are not sliced off the generated sequence).
    """
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # GPT-NeoX-style tokenizers (e.g. cyberagent/open-calm) often define no
    # pad token, so tokenizer.pad_token_id is None; fall back to EOS so
    # model.generate() does not warn or misbehave. Explicit `is None` check
    # (not `or`) so a legitimate pad id of 0 is kept.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    with torch.no_grad():  # inference only — no gradient bookkeeping
        tokens = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=pad_id,
        )

    return tokenizer.decode(tokens[0], skip_special_tokens=True)
|
|
|
|
|
st.title("Causal Language Modeling")

st.write("AIによる文章生成")


# Generation controls: prompt text plus the standard sampling knobs.
input_text = st.text_area("入力テキスト")

max_new_tokens = st.slider("生成する最大トークン数", min_value=1, max_value=512, value=64)

temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.7)

top_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.9)

repetition_penalty = st.slider("Repetition Penalty", min_value=0.1, max_value=2.0, value=1.05)


if st.button("生成"):
    if not input_text.strip():
        # Guard: don't run the model on an empty/whitespace-only prompt.
        st.warning("入力テキストを入力してください。")
    else:
        # Generation of up to 512 tokens can take a while — show progress.
        with st.spinner("生成中..."):
            output = generate_text(input_text, max_new_tokens, temperature, top_p, repetition_penalty)
        st.write("生成されたテキスト:")
        st.write(output)
|
|