|
from typing import Dict, List, Any |
|
from scipy.io import wavfile |
|
from transformers import AutoProcessor, MusicgenForConditionalGeneration |
|
import torch |
|
import io |
|
import base64 |
|
import wave |
|
import array |
|
import math |
|
|
|
def generate_sine_wave(freq, duration, sample_rate, amplitude):
    """Generate a mono sine tone as signed 16-bit samples.

    Args:
        freq: Tone frequency, in cycles per second.
        duration: Length of the tone in seconds.
        sample_rate: Number of samples per second.
        amplitude: Peak sample value. Samples that fall outside the signed
            16-bit range are clipped instead of raising ``OverflowError``
            when the array is built.

    Returns:
        An ``array.array`` of typecode ``"h"`` holding
        ``int(sample_rate * duration)`` samples; empty when that count is
        not positive.
    """
    int16_min, int16_max = -32768, 32767
    n_samples = int(sample_rate * duration)

    # Hoisted loop invariant. The phase is still evaluated as
    # (2 * pi * freq * x) / sample_rate, so in-range samples are unchanged.
    angular = 2 * math.pi * freq

    return array.array(
        "h",
        (
            max(
                int16_min,
                min(int16_max, int(amplitude * math.sin(angular * x / sample_rate))),
            )
            for x in range(n_samples)
        ),
    )
|
|
|
|
|
def sine_to_base64(frequency=440.0, duration=1.0, volume=0.5, sample_rate=44100):
    """Render a sine tone to an in-memory WAV file and base64-encode it.

    Called with no arguments this produces the same tone as before
    (440 Hz, 1 second, half volume, 44.1 kHz); the parameters only make
    those previously hard-coded values configurable.

    Args:
        frequency: Tone frequency passed to ``generate_sine_wave``.
        duration: Tone length in seconds.
        volume: Fraction of full scale, between 0.0 and 1.0 inclusive.
        sample_rate: Samples per second; written to the WAV header as an
            integer.

    Returns:
        The complete WAV file (mono, 16-bit PCM) as a base64 ``str``.

    Raises:
        ValueError: If ``volume`` is outside ``[0.0, 1.0]``.
    """
    if not 0.0 <= volume <= 1.0:
        raise ValueError(f"volume must be between 0.0 and 1.0, got {volume!r}")

    # 32767 is the largest positive value of a signed 16-bit sample.
    amplitude = int(volume * 32767)
    sine_wave = generate_sine_wave(frequency, duration, sample_rate, amplitude)

    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, "wb") as wav_file:
        # (nchannels, sampwidth in bytes, framerate, nframes, comptype, compname)
        wav_file.setparams(
            (1, 2, int(sample_rate), len(sine_wave), "NONE", "not compressed")
        )
        wav_file.writeframes(sine_wave.tobytes())

    # Closing the wave writer does not close a caller-supplied file object,
    # so the buffer is still readable here.
    return base64.b64encode(wav_buffer.getvalue()).decode("utf-8")
|
|
|
|
|
def create_params(params, fr):
    """Merge caller-supplied generation options over the defaults.

    Only the keys ``do_sample``, ``guidance_scale`` and ``max_new_tokens``
    are copied from ``params``; any other key is ignored. A ``duration``
    entry, when present, takes precedence over ``max_new_tokens`` and is
    converted to a token count as ``duration * fr``.

    Args:
        params: Mapping of options, or ``None`` to get the defaults.
        fr: Multiplier applied to ``duration`` (presumably the model's
            token frame rate per second -- confirm against the caller).

    Returns:
        A new dict with the three generation options.

    Raises:
        ValueError: If ``duration`` is a string that is not a number.
        TypeError: If ``duration`` cannot be converted to ``float``.
    """
    out = {
        "do_sample": True,
        "guidance_scale": 3,
        "max_new_tokens": 256,
    }

    if not params:
        return out

    for key, value in params.items():
        if key in out:
            out[key] = value

    # Applied last so that an explicit duration wins over max_new_tokens.
    if "duration" in params:
        # Coerce to a number first: a duration that arrives as a string
        # would otherwise be *repeated* by ``*``, and a fractional duration
        # would leave a float where an integer token count is expected.
        out["max_new_tokens"] = round(float(params["duration"]) * fr)

    return out
|
|
|
|
|
class EndpointHandler:
    """Inference endpoint handler that loads a MusicGen processor and model."""

    def __init__(self, path="pbotsaris/musicgen-small"):
        """Load the processor and model from ``path``.

        The model is placed on CUDA in float16 when a GPU is available.
        Otherwise it is loaded in float32 and kept on the CPU, instead of
        failing while moving to a device that does not exist.

        Args:
            path: Model identifier or local directory passed to
                ``from_pretrained``.
        """
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        # Half precision is only selected together with the GPU.
        dtype = torch.float16 if use_cuda else torch.float32

        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=dtype
        )
        self.model.to(device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Handle one request.

        NOTE(review): the payload is currently ignored and neither the
        processor nor the model is used here; the response is the test
        tone returned by ``sine_to_base64()``.

        Args:
            data: The request payload (expected to carry the text prompt
                and generation parameters).

        Returns:
            A one-element list whose dict maps ``"audio"`` to a
            base64-encoded WAV file.
        """
        base64_encoded_wav = sine_to_base64()
        return [{"audio": base64_encoded_wav}]
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: constructing the handler loads the processor and
    # the model with the default path.
    handler = EndpointHandler()
|
|