Spaces:
Runtime error
Runtime error
File size: 3,354 Bytes
f8987ba 06452a1 f8987ba d6f4621 f8987ba 77b63e6 06452a1 f8987ba d6f4621 06452a1 d6f4621 f8987ba d6f4621 06452a1 77b63e6 8f192a0 06452a1 77b63e6 f8987ba 77b63e6 d6f4621 77b63e6 d6f4621 06452a1 d6f4621 77b63e6 d6f4621 77b63e6 d6f4621 06452a1 d6f4621 f8987ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import streamlit as st
from googletrans import Translator
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# from huggingface_hub import snapshot_download
# Sidebar model picker: the selected entry decides which checkpoint is loaded below.
page = st.sidebar.selectbox("Model ", ["Finetuned on News data", "Pretrained GPT2"])
# googletrans client used to show an English translation of each generated sequence.
translator = Translator()
@st.cache(allow_output_mutation=True)  # avoid re-downloading the model on every Streamlit rerun
def load_model(model_name):
    """Load and return ``(model, tokenizer)`` for a Hugging Face checkpoint.

    Cached so that the (large) causal-LM weights are fetched and instantiated
    only once per checkpoint, not on every widget interaction / script rerun.

    Args:
        model_name: Hub repo id, e.g. ``'flax-community/Sinhala-gpt2'``.

    Returns:
        Tuple of (AutoModelForCausalLM, AutoTokenizer).
    """
    with st.spinner('Waiting for the model to load.....'):
        # snapshot_download('flax-community/Sinhala-gpt2')
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Use EOS as pad token so open-ended generation does not warn about a
        # missing pad token id.
        model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)
    st.success('Model loaded!!')
    return model, tokenizer
# Generation controls (sidebar): prompt text plus sampling knobs.
seed = st.sidebar.text_input('Starting text', 'ආයුබෝවන්')  # default prompt: Sinhala greeting "Ayubowan"
seq_num = st.sidebar.number_input('Number of sequences to generate ', 1, 20, 5)  # min 1, max 20, default 5
max_len = st.sidebar.number_input('Length of a sequence ', 5, 300, 100)  # min 5, max 300, default 100
gen_bt = st.sidebar.button('Generate')  # True only on the rerun triggered by a click
def generate(model, tokenizer, seed, seq_num, max_len):
    """Sample ``seq_num`` continuations of ``seed`` and return them as strings.

    Uses top-k / nucleus (top-p) sampling with a mild temperature;
    ``no_repeat_ngram_size=2`` suppresses verbatim bigram repetition.

    Args:
        model: causal LM exposing ``generate()`` (transformers-style).
        tokenizer: tokenizer exposing ``encode()`` / ``decode()``.
        seed: prompt text to continue.
        seq_num: number of sequences to sample.
        max_len: maximum total length (in tokens) of each sequence.

    Returns:
        List of ``seq_num`` decoded strings (special tokens stripped).
    """
    input_ids = tokenizer.encode(seed, return_tensors='pt')
    outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=max_len,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        num_return_sequences=seq_num,
        no_repeat_ngram_size=2,
        early_stopping=True,  # NOTE(review): only affects beam search; a no-op under pure sampling
    )
    # Decode every sampled sequence, dropping special tokens (EOS / pad).
    return [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]
# Page-specific header and checkpoint; the generation/translation flow below is
# shared by both pages (it was previously duplicated verbatim in each branch).
if page == 'Pretrained GPT2':
    st.title('Sinhala Text generation with GPT2')
    st.markdown('A simple demo using Sinhala-gpt2 model trained during hf-flax week')
    model_name = 'flax-community/Sinhala-gpt2'
else:
    st.title('Sinhala Text generation with Finetuned GPT2')
    st.markdown('This model has been finetuned Sinhala-gpt2 model with 6000 news articles(~12MB)')
    model_name = 'keshan/sinhala-gpt2-newswire'

model, tokenizer = load_model(model_name)

if gen_bt:
    try:
        with st.spinner('Generating...'):
            # generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
            # seqs = generator(seed, max_length=max_len, num_return_sequences=seq_num)
            seqs = generate(model, tokenizer, seed, seq_num, max_len)
        for i, seq in enumerate(seqs):
            st.info(f'Generated sequence {i+1}:')
            st.write(seq)
            st.info('English translation (by Google Translation):')
            st.write(translator.translate(seq, src='si', dest='en').text)
    except Exception as e:
        # st.exception takes the exception object itself (renders the traceback),
        # not a formatted string.
        st.exception(e)

# Trailing separator, shown on both pages.
st.markdown('____________')
|