{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch as T\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torchaudio\n", "from utils import load_ckpt, print_colored\n", "from tokenizer import make_tokenizer\n", "from model import get_hertz_dev_config\n", "import matplotlib.pyplot as plt\n", "from IPython.display import Audio, display\n", "\n", "\n", "# If you get an error like \"undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12\",\n", "# you need to install PyTorch with the correct CUDA version. Run:\n", "# `pip3 uninstall torch torchaudio && pip3 install torch torchaudio --index-url https://download.pytorch.org/whl/cu121`\n", "\n", "device = 'cuda' if T.cuda.is_available() else 'cpu'\n", "T.cuda.set_device(0)\n", "print_colored(f\"Using device: {device}\", \"grey\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# This code will automatically download them if it can't find them.\n", "audio_tokenizer = make_tokenizer(device)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# We have different checkpoints for the single-speaker and two-speaker models\n", "# Set to True to load and run inference with the two-speaker model\n", "TWO_SPEAKER = False\n", "USE_PURE_AUDIO_ABLATION = False # We trained a base model with no text initialization at all. Toggle this to enable it.\n", "assert not (USE_PURE_AUDIO_ABLATION and TWO_SPEAKER) # We only have a single-speaker version of this model.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_config = get_hertz_dev_config(is_split=TWO_SPEAKER, use_pure_audio_ablation=USE_PURE_AUDIO_ABLATION)\n", "\n", "generator = model_config()\n", "generator = generator.eval().to(T.bfloat16).to(device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def load_and_preprocess_audio(audio_path):\n", " print_colored(\"Loading and preprocessing audio...\", \"blue\", bold=True)\n", " # Load audio file\n", " audio_tensor, sr = torchaudio.load(audio_path)\n", " print_colored(f\"Loaded audio shape: {audio_tensor.shape}\", \"grey\")\n", " \n", " if TWO_SPEAKER:\n", " if audio_tensor.shape[0] == 1:\n", " print_colored(\"Converting mono to stereo...\", \"grey\")\n", " audio_tensor = audio_tensor.repeat(2, 1)\n", " print_colored(f\"Stereo audio shape: {audio_tensor.shape}\", \"grey\")\n", " else:\n", " if audio_tensor.shape[0] == 2:\n", " print_colored(\"Converting stereo to mono...\", \"grey\")\n", " audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)\n", " print_colored(f\"Mono audio shape: {audio_tensor.shape}\", \"grey\")\n", " \n", " # Resample to 16kHz if needed\n", " if sr != 16000:\n", " print_colored(f\"Resampling from {sr}Hz to 16000Hz...\", \"grey\")\n", " resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)\n", " audio_tensor = resampler(audio_tensor)\n", " \n", " # Clip to 5 minutes if needed\n", " max_samples = 16000 * 60 * 5\n", " if audio_tensor.shape[1] > max_samples:\n", " print_colored(\"Clipping audio to 5 minutes...\", \"grey\")\n", " audio_tensor = audio_tensor[:, :max_samples]\n", "\n", " \n", " print_colored(\"Audio preprocessing complete!\", \"green\")\n", " return audio_tensor.unsqueeze(0)\n", 
"\n", "def display_audio(audio_tensor):\n", " audio_tensor = audio_tensor.cpu().squeeze()\n", " if audio_tensor.ndim == 1:\n", " audio_tensor = audio_tensor.unsqueeze(0)\n", " audio_tensor = audio_tensor.float()\n", "\n", " # Make a waveform plot\n", " plt.figure(figsize=(4, 1))\n", " plt.plot(audio_tensor.numpy()[0], linewidth=0.5)\n", " plt.axis('off')\n", " plt.show()\n", "\n", " # Make an audio player\n", " display(Audio(audio_tensor.numpy(), rate=16000))\n", " print_colored(f\"Audio ready for playback ↑\", \"green\", bold=True)\n", " \n", " \n", "\n", "# Our model is very prompt-sensitive, so we recommend experimenting with a diverse set of prompts.\n", "prompt_audio = load_and_preprocess_audio('./prompts/toaskanymore.wav')\n", "display_audio(prompt_audio)\n", "prompt_len_seconds = 3\n", "prompt_len = prompt_len_seconds * 8" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print_colored(\"Encoding prompt...\", \"blue\")\n", "with T.autocast(device_type='cuda', dtype=T.bfloat16):\n", " if TWO_SPEAKER:\n", " encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))\n", " encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))\n", " encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)\n", " else:\n", " encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))\n", "print_colored(f\"Encoded prompt shape: {encoded_prompt_audio.shape}\", \"grey\")\n", "print_colored(\"Prompt encoded successfully!\", \"green\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_completion(encoded_prompt_audio, prompt_len, gen_len=None):\n", " prompt_len_seconds = prompt_len / 8\n", " print_colored(f\"Prompt length: {prompt_len_seconds:.2f}s\", \"grey\")\n", " print_colored(\"Completing audio...\", \"blue\")\n", " encoded_prompt_audio = encoded_prompt_audio[:, :prompt_len]\n", " with T.autocast(device_type='cuda', dtype=T.bfloat16):\n", " completed_audio_batch = generator.completion(\n", " encoded_prompt_audio, \n", " temps=(.8, (0.5, 0.1)), # (token_temp, (categorical_temp, gaussian_temp))\n", " use_cache=True,\n", " gen_len=gen_len)\n", "\n", " completed_audio = completed_audio_batch\n", " print_colored(f\"Decoding completion...\", \"blue\")\n", " if TWO_SPEAKER:\n", " decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())\n", " decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())\n", " decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)\n", " else:\n", " decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())\n", " print_colored(f\"Decoded completion shape: {decoded_completion.shape}\", \"grey\")\n", "\n", " print_colored(\"Preparing audio for playback...\", \"blue\")\n", "\n", " audio_tensor = decoded_completion.cpu().squeeze()\n", " if audio_tensor.ndim == 1:\n", " audio_tensor = audio_tensor.unsqueeze(0)\n", " audio_tensor = audio_tensor.float()\n", "\n", " if audio_tensor.abs().max() > 1:\n", " audio_tensor = audio_tensor / audio_tensor.abs().max()\n", "\n", " return audio_tensor[:, max(prompt_len*2000 - 16000, 0):]\n", "\n", "num_completions = 10\n", "print_colored(f\"Generating {num_completions} completions...\", \"blue\")\n", "for _ in range(num_completions):\n", " completion = get_completion(encoded_prompt_audio, prompt_len, 
], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }