import streamlit as st
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from model.funcs import execution_time


@st.cache_resource  # load the model once, not on every Streamlit rerun
def load_model():
    """Load the fine-tuned checkpoint and the base model's tokenizer."""
    model_path = "17/"  # local directory with the fine-tuned checkpoint
    model_name = "sberbank-ai/rugpt3small_based_on_gpt2"
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_path)
    return tokenizer, model


tokenizer, model = load_model()
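
# Optional sketch (an assumption, not part of the original app): run inference
# on a GPU when one is available. input_ids inside generate_text would then
# also need a matching .to(device) call.
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model.to(device)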

@execution_time  # project helper; presumably a timing decorator (otherwise the import above is unused)
def generate_text(prompt):
    """Generate a continuation of the user's prompt with the fine-tuned model."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    model.eval()
    with torch.no_grad():
        out = model.generate(
            input_ids,
            do_sample=True,  # sample instead of pure beam search
            num_beams=2,
            temperature=1.5,
            top_p=0.9,  # nucleus sampling
            max_length=150,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)


prompt = st.text_input("Ask a question")
generate = st.button("Generate")

if generate:
    if not prompt:
        st.write("42")  # fallback reply when no question is entered
    else:
        st.write(generate_text(prompt))
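
# A minimal sketch of how to launch the app, assuming this file is saved as
# app.py (the filename is an assumption; any name works):
#
#   streamlit run app.py
#
# Streamlit serves the UI at http://localhost:8501 by default.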