|
from fastapi import FastAPI |
|
import streamlit as st |
|
from transformers import ( |
|
AutoTokenizer, |
|
AutoConfig, |
|
AutoModelForCausalLM, |
|
StoppingCriteriaList, |
|
MaxLengthCriteria, |
|
) |
|
|
|
# FastAPI application instance.
# NOTE(review): `app` is not referenced anywhere in the visible source -- the UI
# below is rendered by Streamlit, not served through FastAPI routes. Confirm that
# nothing outside this file imports `app` before removing it.
app = FastAPI()
|
|
|
|
|
|
|
# Directory holding the fine-tuned tokenizer and model weights.
MODEL_DIR = "./TaylorSwiftFineTunedModel/"


@st.cache_resource
def _load_model_and_tokenizer(model_dir):
    """Load the tokenizer and causal-LM from *model_dir* once and reuse them.

    Streamlit re-executes this whole script on every widget interaction, so
    without this cache the weights would be re-read from disk on every
    generation request. ``st.cache_resource`` keeps the returned objects alive
    across reruns (a plain ``functools.lru_cache`` would not, because the
    script's namespace is rebuilt on each rerun).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)
    # Reuse the EOS id as the padding id. Presumably the checkpoint defines no
    # dedicated pad token -- confirm against the model config.
    model.config.pad_token_id = model.config.eos_token_id
    return tokenizer, model


def song_generator(input_prompt, model_dir=MODEL_DIR, max_length=300):
    """Continue *input_prompt* with the fine-tuned model using contrastive search.

    Args:
        input_prompt: Opening lyrics to continue.
        model_dir: Directory containing the fine-tuned tokenizer and model.
        max_length: Maximum total length in tokens (prompt included) of each
            generated sequence.

    Returns:
        A list of decoded strings, one per generated sequence (special tokens
        stripped).
    """
    tokenizer, model = _load_model_and_tokenizer(model_dir)
    inputs = tokenizer(input_prompt, return_tensors="pt")
    # penalty_alpha > 0 together with top_k > 1, no sampling and a single beam
    # selects contrastive search inside `generate`. This replaces the
    # deprecated/removed `model.contrastive_search` method. Passing `max_length`
    # here is equivalent to the former MaxLengthCriteria(max_length=300).
    outputs = model.generate(
        **inputs,
        penalty_alpha=0.6,
        top_k=15,
        do_sample=False,
        num_beams=1,
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
|
|
# --- Streamlit UI: collect a lyric prompt and display the generated song. ---
st.title('Taylor-swift style song generator')

st.header('Song generation Model')

query = st.text_input("Enter 2 or 3 verses ", "")

submit = st.button('Generate')

if submit:
    # An empty prompt tokenizes to zero input ids, which generation cannot
    # continue from -- ask the user for text instead of crashing.
    if not query.strip():
        st.warning("Please enter a few lines of lyrics to start from.")
    else:
        st.subheader('Song generated is ')
        with st.spinner(text='This may take a moment...'):
            output_sentence = song_generator(query)
        # song_generator returns one decoded string per sequence; only one is
        # generated for a single prompt.
        st.write(output_sentence[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|