"""Streamlit app that generates Taylor Swift-style song lyrics from a user
prompt using a locally fine-tuned causal language model."""

import torch  # BUG FIX: was used via torch.no_grad() but never imported
import streamlit as st
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

# Path to the fine-tuned checkpoint (hoisted from the hard-coded literals).
MODEL_DIR = "./TaylorSwiftFineTunedModel/"


@st.cache_resource
def _load_model():
    """Load tokenizer and model once per process.

    Streamlit re-executes the whole script on every interaction; without
    caching, the multi-hundred-MB model was reloaded from disk on each
    button press.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForCausalLM.from_pretrained(MODEL_DIR)
    # GPT-style models have no pad token; reuse EOS so generate() can pad.
    model.config.pad_token_id = model.config.eos_token_id
    model.eval()  # inference-only mode
    return tokenizer, model


def song_generator(input_prompt):
    """Generate song lyrics continuing *input_prompt*.

    Parameters
    ----------
    input_prompt : str
        Two or three seed verses typed by the user.

    Returns
    -------
    str
        The decoded generated lyrics with special tokens stripped.
    """
    tokenizer, model = _load_model()
    # BUG FIX: tokenizer(...) returns a BatchEncoding mapping; generate()
    # needs the input_ids tensor itself, not the whole mapping. Also pass
    # the attention mask so padding/EOS reuse doesn't skew sampling.
    encoded = tokenizer(input_prompt, return_tensors="pt")
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=300,
            do_sample=True,
            temperature=0.7,
            top_k=15,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=2,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# --- Streamlit UI ---------------------------------------------------------
st.title('Taylor-swift style song generator')
st.header('Song generation Model')
query = st.text_input("Enter 2 or 3 verses ", "")
submit = st.button('Generate')
input_song = query

if submit:
    st.subheader('Song generated is ')
    with st.spinner(text='This may take a moment...'):
        output_sentence = song_generator(input_song)
        st.write(output_sentence)