{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "M74Gs_TjYl_B"
      },
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github"
      },
      "source": [
        "### SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation\n",
        "\n",
        "[arxiv](https://arxiv.org/abs/2211.12194) | [project](https://sadtalker.github.io) | [GitHub](https://github.com/Winfredy/SadTalker)\n",
        "\n",
        "Wenxuan Zhang, Xiaodong Cun, Xuan Wang, Yong Zhang, Xi Shen, Yu Guo, Ying Shan, Fei Wang.\n",
        "\n",
        "Xi'an Jiaotong University, Tencent AI Lab, Ant Group\n",
        "\n",
        "CVPR 2023\n",
        "\n",
        "TL;DR: A realistic and stylized talking head video generation method from a single image and audio.\n"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "kA89DV-sKS4i"
      },
      "source": [
        "Installation (around 5 mins)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qJ4CplXsYl_E"
      },
      "outputs": [],
      "source": [
        "### make sure that CUDA is available in Edit -> Notebook settings -> GPU\n",
        "# Query the GPU name plus total and free memory, printed as CSV without a header row.\n",
        "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Mdq6j4E5KQAR"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Refresh the package index first so the installs below resolve against current lists.\n",
        "!apt-get update\n",
        "\n",
        "# `-y` answers apt's confirmation prompt so the cell cannot stall waiting for input.\n",
        "!apt-get install -y software-properties-common\n",
        "!sudo apt-get install -y python3.8 python3.8-distutils\n",
        "\n",
        "# Register the interpreters only after python3.8 exists on disk;\n",
        "# the higher priority (2) makes python3.8 the preferred `python3`.\n",
        "!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.8 2\n",
        "!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.9 1\n",
        "\n",
        "!python --version\n",
        "\n",
        "# Force-remove the existing pip packages, then reinstall pip from apt.\n",
        "!sudo dpkg --remove --force-remove-reinstreq python3-pip python3-setuptools python3-wheel\n",
        "!apt-get install -y python3-pip\n",
        "\n",
        "print('Git clone project and install requirements...')\n",
        "!git clone https://github.com/Winfredy/SadTalker &> /dev/null\n",
        "%cd SadTalker\n",
        "\n",
        "# `!export ...` would only affect its own throw-away subshell, so set the variable\n",
        "# on the kernel process instead; every later `!` command inherits it.\n",
        "os.environ['PYTHONPATH'] = os.pathsep.join(\n",
        "    path for path in ('/content/SadTalker', os.environ.get('PYTHONPATH')) if path\n",
        ")\n",
        "\n",
        "!python3.8 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113\n",
        "!apt-get install -y ffmpeg &> /dev/null\n",
        "!python3.8 -m pip install -r requirements.txt"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "DddcKB_nKsnk"
      },
      "source": [
        "Download models (around 1 min)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eDw3_UN8K2xa"
      },
      "outputs": [],
      "source": [
        "# Delete any existing `checkpoints` directory, then run scripts/download_models.sh.\n",
        "print('Download pre-trained models...')\n",
        "!rm -rf checkpoints\n",
        "!bash scripts/download_models.sh"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kK7DYeo7Yl_H"
      },
      "outputs": [],
      "source": [
        "# borrow from makeittalk\n",
        "import glob\n",
        "import os\n",
        "\n",
        "import ipywidgets as widgets\n",
        "import matplotlib.pyplot as plt\n",
        "from IPython.display import display\n",
        "\n",
        "\n",
        "def show_source_image(name):\n",
        "    \"\"\"Draw examples/source_image/<name>.png inline with the axes hidden.\"\"\"\n",
        "    plt.imshow(plt.imread('examples/source_image/{}.png'.format(name)))\n",
        "    plt.axis('off')\n",
        "    plt.show()\n",
        "\n",
        "\n",
        "def on_change(change):\n",
        "    \"\"\"Redraw the preview when the dropdown's selected value changes.\"\"\"\n",
        "    if change['type'] == 'change' and change['name'] == 'value':\n",
        "        show_source_image(change['new'])\n",
        "\n",
        "\n",
        "print(\"Choose the image name to animate: (saved in folder 'examples/')\")\n",
        "# Use the documented glob.glob (glob.glob1 is an undocumented helper) and sort the file names.\n",
        "png_files = sorted(os.path.basename(path) for path in glob.glob('examples/source_image/*.png'))\n",
        "# splitext removes only the final extension, so names that contain dots stay intact.\n",
        "img_list = [os.path.splitext(file_name)[0] for file_name in png_files]\n",
        "if not img_list:\n",
        "    raise FileNotFoundError('No .png images found in examples/source_image')\n",
        "\n",
        "# Pre-select 'full3' when it is present; otherwise fall back to the first image\n",
        "# instead of letting the Dropdown reject a value that is not among its options.\n",
        "initial_name = 'full3' if 'full3' in img_list else img_list[0]\n",
        "default_head_name = widgets.Dropdown(options=img_list, value=initial_name)\n",
        "default_head_name.observe(on_change)\n",
        "display(default_head_name)\n",
        "show_source_image(default_head_name.value)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "-khNZcnGK4UK"
      },
      "source": [
        "Animation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ToBlDusjK5sS"
      },
      "outputs": [],
      "source": [
        "# selected audio from examples/driven_audio\n",
        "# The source image is the one currently selected in the dropdown created earlier.\n",
        "img = 'examples/source_image/{}.png'.format(default_head_name.value)\n",
        "print(img)\n",
        "# Quote the interpolated path so a file name containing spaces stays a single argument.\n",
        "!python3.8 inference.py --driven_audio ./examples/driven_audio/RD_Radio31_000.wav \\\n",
        "           --source_image \"{img}\" \\\n",
        "           --result_dir ./results --still --preprocess full --enhancer gfpgan"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fAjwGmKKYl_I"
      },
      "outputs": [],
      "source": [
        "# visualize code from makeittalk\n",
        "import glob\n",
        "import os\n",
        "import sys\n",
        "from base64 import b64encode\n",
        "\n",
        "from IPython.display import HTML, display\n",
        "\n",
        "# get the last from results: glob returns matches in arbitrary order, so pick the\n",
        "# most recently modified video explicitly instead of whatever happens to come first.\n",
        "mp4_candidates = glob.glob('./results/*.mp4')\n",
        "if not mp4_candidates:\n",
        "    raise FileNotFoundError('No .mp4 found in ./results - run the Animation cell first.')\n",
        "mp4_name = max(mp4_candidates, key=os.path.getmtime)\n",
        "\n",
        "# Read through a context manager so the file handle is closed promptly.\n",
        "with open(mp4_name, 'rb') as mp4_file:\n",
        "    mp4 = mp4_file.read()\n",
        "# Embed the video as a base64 data URL so the player needs no separate file.\n",
        "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
        "\n",
        "print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n",
        "display(HTML(\"\"\"\n",
        "  <video width=256 controls>\n",
        "        <source src=\"%s\" type=\"video/mp4\">\n",
        "  </video>\n",
        "  \"\"\" % data_url))\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "base",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.9.7"
    },
    "vscode": {
      "interpreter": {
        "hash": "db5031b3636a3f037ea48eb287fd3d023feb9033aefc2a9652a92e470fb0851b"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}