Spaces:
Runtime error
Runtime error
| # Copyright 2022 Tristan Behrens. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # Lint as: python3 | |
| from flask import Flask, render_template, request, send_file, jsonify, redirect, url_for | |
| from PIL import Image | |
| import os | |
| import io | |
| import random | |
| import base64 | |
| import torch | |
| import wave | |
| from source.logging import create_logger | |
| from source.tokensequence import token_sequence_to_audio, token_sequence_to_image | |
| from source import constants | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
logger = create_logger(__name__)

# The Hugging Face auth token is read from the "authtoken" environment variable.
# os.environ.get yields None when the variable is not set.
auth_token = os.environ.get("authtoken")

# Both the tokenizer and the model are fetched from the same repository.
_MODEL_ID = "ai-guru/lakhclean_mmmtrack_4bars_d-2048"

# Loading the model and its tokenizer.
logger.info("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, use_auth_token=auth_token)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID, use_auth_token=auth_token)
logger.info("Done.")

# Create the app. Static files are served from the URL root (empty prefix).
logger.info("Creating app...")
app = Flask(__name__, static_url_path="")
logger.info("Done.")
# Route for the loading page.
@app.route("/")
def index():
    """Render the landing page of the composer UI.

    Returns:
        The rendered ``index.html`` template. The template receives the
        compose styles, densities and temperatures that ``source.constants``
        provides for the UI.
    """
    return render_template(
        "index.html",
        compose_styles=constants.get_compose_styles_for_ui(),
        densities=constants.get_densities_for_ui(),
        temperatures=constants.get_temperatures_for_ui(),
    )
def _encode_wave_base64(sample_rate, audio_data):
    """Pack audio frames into an in-memory WAV file and return it as base64 text."""
    with io.BytesIO() as buffer:
        # The wave writer is closed by the context manager even if writing
        # fails. It does not close a file object it was handed, so the buffer
        # can still be read afterwards.
        with wave.open(buffer, "wb") as wave_file:
            wave_file.setframerate(sample_rate)
            wave_file.setnchannels(1)
            wave_file.setsampwidth(2)
            wave_file.writeframes(audio_data)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _encode_png_base64(image):
    """Serialize a PIL image to PNG in memory and return it as base64 text."""
    with io.BytesIO() as buffer:
        # No "quality" argument: PNG is lossless and the option has no effect.
        image.save(buffer, "PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


# NOTE(review): no route registration for this view is visible in the file. The
# URL below is inferred from the function name -- confirm it against the
# front-end code that calls this endpoint.
@app.route("/compose", methods=["POST"])
def compose():
    """Generate a piece of music and return it as a JSON response.

    Reads a JSON request body with the keys ``music_style``, ``density`` and
    ``temperature``, maps them through ``source.constants``, calls
    ``generate_sequence`` and converts the resulting token sequence to audio
    and to an image.

    Returns:
        A JSON response with the keys ``tokens`` (the generated token
        sequence), ``audio`` (a WAV file as base64 data URI), ``image`` (a PNG
        file as base64 data URI) and ``status`` (``"OK"``). If the request body
        is not a JSON object or lacks a required key, a JSON error payload
        with HTTP status 400 is returned instead.
    """
    # Get the parameters as JSON. Reject incomplete requests early so that a
    # malformed request yields a 400 instead of an unhandled exception.
    params = request.get_json()
    if not isinstance(params, dict):
        return jsonify({"status": "ERROR", "message": "Expected a JSON object."}), 400
    missing_keys = [
        key for key in ("music_style", "density", "temperature") if key not in params
    ]
    if missing_keys:
        message = "Missing parameters: " + ", ".join(missing_keys)
        return jsonify({"status": "ERROR", "message": message}), 400

    # Translate the UI values into the values used for generation.
    instruments = constants.get_instruments(params["music_style"])
    density = constants.get_density(params["density"])
    temperature = constants.get_temperature(params["temperature"])
    logger.info(f"instruments: {instruments} density: {density} temperature: {temperature}")

    # Generate with the given parameters.
    logger.info("Generating token sequence...")
    generated_sequence = generate_sequence(instruments, density, temperature)
    logger.info(f"Generated token sequence: {generated_sequence}")

    # Get the audio data (described by the original author as an array of
    # int16) and encode it as a mono, 16-bit wave file in memory.
    logger.info("Generating audio...")
    sample_rate, audio_data = token_sequence_to_audio(generated_sequence)
    logger.info(f"Done. Audio data: {len(audio_data)}")
    audio_data_base64 = _encode_wave_base64(sample_rate, audio_data)

    # Convert the token sequence to a PIL image.
    image = token_sequence_to_image(generated_sequence)

    # Save PIL image to harddrive as PNG.
    # NOTE(review): this writes a fixed file name in the working directory on
    # every request; it looks like a debugging aid -- confirm it is still needed.
    logger.debug(f"Saving image to harddrive... {type(image)}")
    image_file_name = "compose.png"
    image.save(image_file_name, "PNG")

    # Encode the image as PNG in memory.
    image_data_base64 = _encode_png_base64(image)

    # Return.
    return jsonify({
        "tokens": generated_sequence,
        "audio": "data:audio/wav;base64," + audio_data_base64,
        "image": "data:image/png;base64," + image_data_base64,
        "status": "OK",
    })
def generate_sequence(instruments, density, temperature):
    """Sample a token sequence from the model for the given instruments.

    The prompt is ``PIECE_START`` followed by one
    ``TRACK_START INST=<instrument> DENSITY=<density>`` header per instrument,
    with the instruments in random order. Generation samples up to a total
    length of 2048 tokens and uses the ``TRACK_END`` token as end-of-sequence
    token.

    Args:
        instruments: Instruments to put into the prompt. A copy is shuffled,
            so the caller's object is left unchanged.
        density: Value inserted into every ``DENSITY=`` token.
        temperature: Sampling temperature handed to ``model.generate``.

    Returns:
        The ids returned by ``model.generate``, decoded with the tokenizer.
    """
    # Work on a copy so the shuffle does not reorder the caller's sequence.
    shuffled = instruments[:]
    random.shuffle(shuffled)

    # Encode the piece start and every track header separately, then join
    # them into a single 1-D tensor of prompt ids.
    prompt_parts = [tokenizer.encode("PIECE_START", return_tensors="pt")[0]]
    prompt_parts.extend(
        tokenizer.encode(
            f"TRACK_START INST={name} DENSITY={density}", return_tensors="pt"
        )[0]
        for name in shuffled
    )
    # Add the batch dimension that model.generate is called with.
    prompt_ids = torch.cat(prompt_parts).unsqueeze(0)

    output_ids = model.generate(
        prompt_ids,
        max_length=2048,
        do_sample=True,
        temperature=temperature,
        eos_token_id=tokenizer.encode("TRACK_END")[0],
    )[0]
    return tokenizer.decode(output_ids)
# Start the Flask server only when this file is executed as a script.
if __name__ == "__main__":
    app.run(port=7860, host="0.0.0.0")