File size: 1,222 Bytes
ae49b72 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from transformers import AutoTokenizer, AutoModelForTextToWaveform
import torch
def init_musicgen(model_size='facebook/musicgen-small'):
    """
    Initialize the MusicGen model and tokenizer.

    Args:
        model_size: Checkpoint identifier handed to ``from_pretrained`` for
            both the tokenizer and the model. Defaults to
            ``'facebook/musicgen-small'``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        Exception: Any error raised while loading either object is printed
            and then re-raised unchanged.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_size)
        model = AutoModelForTextToWaveform.from_pretrained(model_size)
    except Exception as e:
        print(f"Error loading model or tokenizer: {e}")
        # Bare ``raise`` re-raises the active exception as-is, without
        # adding this handler as an extra frame in the traceback.
        raise
    return tokenizer, model
def generate_music(text_prompt, tokenizer, model):
    """
    Generate music based on the input text prompt using the pre-initialized model.

    Args:
        text_prompt: Text description of the music to generate.
        tokenizer: Callable invoked as ``tokenizer(text_prompt,
            return_tensors='pt')``; its result must be a mapping containing
            an ``'input_ids'`` tensor.
        model: Object exposing ``generate(**inputs)`` that returns a tensor.
            If it has a ``device`` attribute, the tokenized inputs are moved
            there before generation.

    Returns:
        The generated waveform as a NumPy array with all size-1 dimensions
        squeezed out.

    Raises:
        ValueError: If tokenization produces no ``'input_ids'`` entry or an
            ``'input_ids'`` tensor with zero elements.
    """
    print(f"Generating music for prompt: {text_prompt}")

    # Tokenize the input text
    inputs = tokenizer(text_prompt, return_tensors='pt')

    # Check if inputs are valid. ``numel()`` catches a tensor that is empty
    # along any dimension (e.g. shape (1, 0)), not only an empty batch.
    input_ids = inputs['input_ids'] if 'input_ids' in inputs else None
    if input_ids is None or input_ids.numel() == 0:
        raise ValueError("Tokenized inputs are empty or invalid. Ensure your input prompt is valid.")

    # Keep inputs on the same device as the model; otherwise a model that was
    # moved to an accelerator fails with a device-mismatch error.
    device = getattr(model, 'device', None)
    if device is not None and hasattr(inputs, 'to'):
        inputs = inputs.to(device)

    # Generate the music waveform without tracking gradients
    with torch.no_grad():
        generated_waveform = model.generate(**inputs)

    return generated_waveform.cpu().numpy().squeeze()