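"""Gradio demo app for a dance style classifier.

Loads an Audio Spectrogram Transformer (AST) checkpoint described by
models/config/train_local.yaml and predicts which dances fit an uploaded
or recorded song clip.
"""
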
from pathlib import Path

import gradio as gr
import numpy as np
import os
from functools import cache
from models.audio_spectrogram_transformer import AST, ASTExtractorWrapper
from models.training_environment import TrainingEnvironment
import torch
from torch import nn
import yaml
import torchaudio

CONFIG_FILE = Path("models/config/train_local.yaml")
MODEL_CLS = AST
EXTRACTOR = ASTExtractorWrapper


class DancePredictor:
    def __init__(
        self,
        weight_path: str,
        labels: list[str],
        expected_duration=6,
        threshold=0.5,
        resample_frequency=16000,
        device="cpu",
    ):
        super().__init__()
        self.expected_duration = expected_duration
        self.threshold = threshold
        self.resample_frequency = resample_frequency
        self.labels = np.array(labels)
        self.device = device
        self.model = self.get_model(weight_path)
        self.extractor = ASTExtractorWrapper()

    def get_model(self, weight_path: str) -> nn.Module:
        weights = torch.load(weight_path, map_location=self.device)["state_dict"]
        model = AST(self.labels).to(self.device)
        # Checkpoint parameters are stored under a "model." prefix; strip it so
        # the keys line up with the bare AST module's state dict.
        for key in list(weights):
            weights[key.replace("model.", "")] = weights.pop(key)
        model.load_state_dict(weights, strict=False)
        return model.to(self.device).eval()

    @classmethod
    def from_config(cls, config_path: str) -> "DancePredictor":
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        weight_path = config["checkpoint"]
        labels = sorted(config["dance_ids"])
        expected_duration = 6
        threshold = 0.5
        resample_frequency = 16000
        device = "mps"
        return DancePredictor(
            weight_path,
            labels,
            expected_duration,
            threshold,
            resample_frequency,
            device,
        )

    def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
        # Duplicate mono input so the pipeline always sees a stereo signal.
        if waveform.ndim == 1:
            waveform = np.stack([waveform, waveform]).T
        # Gradio provides (samples, channels); transpose to channels-first.
        waveform = torch.from_numpy(waveform.T)
        # Round-trip through the wav codec, which also decodes the samples
        # back as normalized float32.
        waveform = torchaudio.functional.apply_codec(
            waveform, sample_rate, "wav", channels_first=True
        )
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, self.resample_frequency
        )
        # Keep only the expected clip length.
        waveform = waveform[
            :, : self.resample_frequency * self.expected_duration
        ]  # TODO PAD
        features = self.extractor(waveform)
        features = features.unsqueeze(0).to(self.device)
        results = self.model(features)
        results = nn.functional.softmax(results.squeeze(0), dim=0)
        results = results.detach().cpu().numpy()
        # Report only the dances whose probability clears the threshold.
        result_mask = results > self.threshold
        probs = results[result_mask]
        dances = self.labels[result_mask]
        return {dance: float(prob) for dance, prob in zip(dances, probs)}


# Cache the predictor so repeated requests do not reload the checkpoint.
@cache
def get_model(config_path: str) -> DancePredictor:
    model = DancePredictor.from_config(config_path)
    return model


def predict(audio: tuple[int, np.ndarray]) -> dict[str, float] | str:
    sample_rate, waveform = audio
    model = get_model(CONFIG_FILE)
    results = model(waveform, sample_rate)
    return results if len(results) else "Dance Not Found"


def demo():
    title = "Dance Classifier"
    description = "What should I dance to this song? Pass some audio to the Dance Classifier to find out!"
    song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
    example_audio = [
        str(song) for song in song_samples.iterdir() if song.name[0] != "."
    ]
    all_dances = get_model(CONFIG_FILE).labels

    recording_interface = gr.Interface(
        fn=predict,
        description="Record at least **6 seconds** of the song.",
        inputs=gr.Audio(source="microphone", label="Song Recording"),
        outputs=gr.Label(label="Dances"),
        examples=example_audio,
    )

    uploading_interface = gr.Interface(
        fn=predict,
        inputs=gr.Audio(label="Song Audio File"),
        outputs=gr.Label(label="Dances"),
        examples=example_audio,
    )

    with gr.Blocks() as app:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)
        gr.TabbedInterface(
            [uploading_interface, recording_interface], ["Upload Song", "Record Song"]
        )
        with gr.Accordion("See all dances", open=False):
            gr.Markdown("\n".join(f"- {dance}" for dance in all_dances))

    return app


if __name__ == "__main__":
    demo().launch()