{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcbbe26c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the directory above the notebook's working directory first on the\n",
    "# import path (presumably where the local `audiodiffusion` package lives).\n",
    "sys.path.insert(0, os.path.abspath(os.pardir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b451ab22",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import random\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from datasets import load_dataset\n",
    "from IPython.display import Audio\n",
    "from diffusers import AutoencoderKL\n",
    "from audiodiffusion.mel import Mel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "324cef44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: values a reader might want to tune.\n",
    "VAE_PATH = '../models/autoencoder-kl'  # resolved relative to the working directory\n",
    "SEED = 42\n",
    "\n",
    "# Seed every RNG so Restart & Run All is reproducible: `random` drives the\n",
    "# dataset example choice and torch drives the random latents. NumPy is\n",
    "# seeded too in case downstream helpers draw from it.\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "mel = Mel()\n",
    "vae = AutoencoderKL.from_pretrained(VAE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da55ce79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the configuration the VAE was loaded with.\n",
    "vae.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fea99ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load all splits of the dataset; only the 'train' split is used below.\n",
    "DATASET_NAME = 'teticio/audio-diffusion-256'\n",
    "ds = load_dataset(DATASET_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "426c6edd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one random training example, show its image and play it back as audio.\n",
    "train_split = ds['train']\n",
    "example = random.choice(train_split)\n",
    "image = example['image']\n",
    "display(image)\n",
    "audio = mel.image_to_audio(image)\n",
    "Audio(data=audio, rate=mel.get_sample_rate())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d123f8a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# encode: PIL image -> (1, 3, H, W) float32 tensor in [-1, 1] -> sampled latents\n",
    "input_image = np.asarray(image.convert('RGB'))  # (H, W, 3) uint8\n",
    "input_image = ((input_image / 255) * 2 - 1).transpose(2, 0, 1)  # (3, H, W) float64\n",
    "# torch.from_numpy + unsqueeze avoids building a tensor from a Python list of\n",
    "# ndarrays, which PyTorch warns is extremely slow.\n",
    "input_tensor = torch.from_numpy(input_image).unsqueeze(0).to(torch.float32)\n",
    "with torch.no_grad():  # inference only; no autograd graph needed\n",
    "    posterior = vae.encode(input_tensor).latent_dist\n",
    "    latents = posterior.sample()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a3f5c17",
   "metadata": {},
   "outputs": [],
   "source": [
    "def latents_to_image(latent_batch: torch.Tensor) -> Image.Image:\n",
    "    \"\"\"Decode VAE latents and return the first batch item as a greyscale PIL image.\n",
    "\n",
    "    The decoder output is clamped to [-1, 1], mapped to [0, 255] uint8,\n",
    "    reordered from (n, c, h, w) to (h, w, c) for the first batch item and\n",
    "    converted to mode 'L' (the form passed to `mel.image_to_audio` below).\n",
    "    Shared by the reconstruction and sampling cells.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():  # inference only; no autograd graph needed\n",
    "        decoded = vae.decode(latent_batch)['sample']\n",
    "    decoded = torch.clamp(decoded, -1., 1.)\n",
    "    decoded = (decoded + 1.0) / 2.0  # -1,1 -> 0,1; n,c,h,w\n",
    "    pixels = (decoded.cpu().numpy() * 255).round().astype(\"uint8\")\n",
    "    pixels = pixels.transpose(0, 2, 3, 1)[0]  # first batch item; h,w,c\n",
    "    return Image.fromarray(pixels).convert('L')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "482c458f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reconstruct: decode the latents of the encoded dataset image\n",
    "reconstructed_image = latents_to_image(latents)\n",
    "display(reconstructed_image)\n",
    "Audio(data=mel.image_to_audio(reconstructed_image), rate=mel.get_sample_rate())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f10db020",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample: decode standard-normal noise with the same shape as the latents\n",
    "sampled_image = latents_to_image(torch.randn_like(latents))\n",
    "display(sampled_image)\n",
    "Audio(data=mel.image_to_audio(sampled_image), rate=mel.get_sample_rate())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46019770",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "huggingface",
   "language": "python",
   "name": "huggingface"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}