musicgen-small / handler.py
pbotsaris's picture
added temp sine wave to test base64 encoding
edf8016
raw
history blame
3.42 kB
from typing import Dict, List, Any
from scipy.io import wavfile
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
import io
import base64
import wave
import array
import math
def generate_sine_wave(freq, duration, sample_rate, amplitude):
    """Synthesize a mono sine tone as signed 16-bit integer samples.

    Args:
        freq: Tone frequency in Hz.
        duration: Length of the tone in seconds.
        sample_rate: Samples per second.
        amplitude: Peak sample value. Every rounded sample must fit in a
            signed 16-bit integer, otherwise ``array.array("h", ...)``
            raises ``OverflowError``.

    Returns:
        ``array.array("h")`` holding ``int(sample_rate * duration)`` samples.
    """
    n_samples = int(sample_rate * duration)
    # round(), not int(): int() truncates toward zero, biasing every
    # non-integral sample toward silence instead of taking the nearest value.
    samples = [
        round(amplitude * math.sin(2 * math.pi * freq * x / sample_rate))
        for x in range(n_samples)
    ]
    return array.array("h", samples)  # "h" = signed short (16-bit PCM)
def sine_to_base64(frequency=440.0, duration=1.0, volume=0.5, sample_rate=44100):
    """Render a sine tone as a base64-encoded mono 16-bit WAV file.

    Called with no arguments it produces a 1 s, 440 Hz tone at half volume,
    sampled at 44.1 kHz.

    Args:
        frequency: Tone frequency in Hz.
        duration: Length of the tone in seconds.
        volume: Linear gain in ``[0.0, 1.0]`` relative to 16-bit full scale.
        sample_rate: Samples per second written into the WAV header.

    Returns:
        The complete WAV file (header and data), base64-encoded, as a UTF-8 str.

    Raises:
        ValueError: If ``volume`` is outside ``[0.0, 1.0]``; larger values
            would overflow the signed 16-bit sample range.
    """
    if not 0.0 <= volume <= 1.0:
        raise ValueError(f"volume must be in [0.0, 1.0], got {volume!r}")
    amplitude = int(volume * 32767)  # 32767 = max signed 16-bit sample
    samples = generate_sine_wave(frequency, duration, sample_rate, amplitude)

    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, "wb") as wav_file:
        wav_file.setnchannels(1)  # mono
        wav_file.setsampwidth(2)  # 2 bytes per sample, matching array typecode "h"
        wav_file.setframerate(sample_rate)
        # The frame count in the header is derived from the data written here.
        wav_file.writeframes(samples.tobytes())
    return base64.b64encode(wav_buffer.getvalue()).decode("utf-8")
def create_params(params, fr):
    """Build the keyword arguments for ``model.generate``.

    Starts from the defaults ``do_sample=True``, ``guidance_scale=3``,
    ``max_new_tokens=256``. Only keys already present among those defaults
    are copied from ``params``; any other key is ignored. A ``duration``
    entry takes precedence over an explicit ``max_new_tokens`` and is
    converted to ``round(duration * fr)`` tokens.

    Args:
        params: Caller-supplied generation options, or ``None`` for defaults.
        fr: Multiplier converting ``duration`` into a token count (the
            handler passes the audio encoder's ``frame_rate``).

    Returns:
        A new dict of generation keyword arguments.
    """
    out = {"do_sample": True, "guidance_scale": 3, "max_new_tokens": 256}
    if params is None:
        return out
    # Whitelist: only override keys that have a default.
    out.update({key: value for key, value in params.items() if key in out})
    if "duration" in params:
        # generate() expects an integer token budget; rounding also absorbs
        # float error such as 0.29 * 100 == 28.999999999999996.
        out["max_new_tokens"] = int(round(params["duration"] * fr))
    return out
class EndpointHandler:
    """Inference-endpoint handler that turns a text prompt into base64 WAV audio.

    While ``use_test_sine`` is true (the default, preserving the current
    temporary behaviour) ``__call__`` ignores the payload and returns a fixed
    test tone. This exercises the base64 response path without running the
    model.
    """

    def __init__(self, path="pbotsaris/musicgen-small", use_test_sine=True):
        # Load the processor and a half-precision model once, then move the
        # model to the GPU for the lifetime of the handler.
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        )
        self.model.to("cuda")
        # Temporary switch: respond with a synthetic tone instead of model output.
        self.use_test_sine = use_test_sine

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Handle one inference request.

        Args:
            data: The payload. The text prompt is under ``"inputs"`` and the
                optional generation options are under ``"parameters"``.

        Returns:
            ``[{"audio": <base64-encoded WAV file as str>}]``.
        """
        if self.use_test_sine:
            return [{"audio": sine_to_base64()}]
        return [{"audio": self._generate_base64_wav(data)}]

    def _generate_base64_wav(self, data):
        """Run the model on the payload's prompt and return base64-encoded WAV."""
        # Read without mutating the caller's dict; fall back to the whole
        # payload as the prompt when no "inputs" key is present.
        prompt = data.get("inputs", data)
        params = create_params(
            data.get("parameters"), self.model.config.audio_encoder.frame_rate
        )
        inputs = self.processor(
            text=[prompt], padding=True, return_tensors="pt"
        ).to("cuda")
        with torch.autocast("cuda"):
            outputs = self.model.generate(**inputs, **params)
        # Take the first batch item's first channel; presumably outputs is
        # (batch, channels, samples) — confirm against the transformers docs.
        # The model runs in float16, which scipy's WAV writer does not accept,
        # so upcast to float32.
        audio = outputs[0, 0].to(torch.float32).cpu().numpy()
        sampling_rate = getattr(self.model.config.audio_encoder, "sampling_rate", 32000)
        wav_buffer = io.BytesIO()
        wavfile.write(wav_buffer, rate=sampling_rate, data=audio)
        return base64.b64encode(wav_buffer.getvalue()).decode("utf-8")
if __name__ == "__main__":
    # Manual smoke test: constructing the handler downloads the default
    # checkpoint and moves the model onto the GPU.
    handler = EndpointHandler(path="pbotsaris/musicgen-small")