diff --git "a/EAT_v2.ipynb" "b/EAT_v2.ipynb"
new file mode 100644
--- /dev/null
+++ "b/EAT_v2.ipynb"
@@ -0,0 +1,2659 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title set font and wrap\n",
+ "from IPython.display import HTML, display\n",
+ "\n",
+ "def set_css():\n",
+ " display(HTML('''\n",
+ " \n",
+ " '''))\n",
+ "get_ipython().events.register('pre_run_cell', set_css)"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "2eK7PTTQiAWv"
+ },
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LBBjwIn0ROsF"
+ },
+ "outputs": [],
+ "source": [
+ "#!git clone https://github.com/yuangan/EAT_code.git\n",
+ "!wget https://huggingface.co/waveydaveygravy/styletalk/resolve/main/EAT_code.zip\n",
+ "!unzip /content/EAT_code.zip\n",
+ "%cd /content/EAT_code\n",
+ "!pip install -r requirements.txt\n",
+ "!pip install resampy\n",
+ "#!pip install face_detection\n",
+ "!pip install python_speech_features\n",
+ "#@title make directories if not done yet\n",
+ "!mkdir tensorflow\n",
+ "!mkdir Results\n",
+ "!mkdir ckpt\n",
+ "!mkdir demo\n",
+ "!mkdir Utils\n",
+ "%cd /content/EAT_code/tensorflow\n",
+ "!mkdir models\n",
+ "%cd /content/EAT_code\n",
+ "print(\"done\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title in /content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml change batch_size to 1 line 70"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "_Mxxg6d-WakT"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
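+ {
+ "cell_type": "code",
+ "source": [
+ "#@title optional: set batch_size with sed instead of editing by hand (sketch)\n",
+ "# Minimal alternative to the manual edit above. It assumes every batch_size\n",
+ "# entry in this config should become 1 -- check the grep output before and after.\n",
+ "!grep -n 'batch_size' /content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml\n",
+ "!sed -i 's/batch_size: .*/batch_size: 1/' /content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml\n",
+ "!grep -n 'batch_size' /content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },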
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title make directories if not done yet\n",
+ "!mkdir tensorflow\n",
+ "!mkdir ckpt\n",
+ "!mkdir demo\n",
+ "!mkdir Utils\n",
+ "%cd /content/EAT_code/demo\n",
+ "!mkdir imgs1\n",
+ "!mkdir imgs_cropped1\n",
+ "!mkdir imgs_latent1\n",
+ "%cd /content/EAT_code/tensorflow\n",
+ "!mkdir models\n",
+ "%cd /content/EAT_code"
+ ],
+ "metadata": {
+ "id": "3bf2Yff2DEdH"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title donwload and place deepspeech model\n",
+ "%cd /content/EAT_code\n",
+ "!gdown --id 1KK15n2fOdfLECWN5wvX54mVyDt18IZCo && unzip -q ckpt.zip -d ckpt\n",
+ "!gdown --id 1MeFGC7ig-vgpDLdhh2vpTIiElrhzZmgT && unzip -q demo.zip -d demo\n",
+ "!gdown --id 1HGVzckXh-vYGZEUUKMntY1muIbkbnRcd && unzip -q Utils.zip -d Utils\n",
+ "%cd /content/EAT_code/tensorflow/models\n",
+ "!wget https://github.com/osmr/deepspeech_features/releases/download/v0.0.1/deepspeech-0_1_0-b90017e8.pb.zip\n",
+ "%cd /content/EAT_code"
+ ],
+ "metadata": {
+ "id": "ucK4MQbq0yHx"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title download models\n",
+ "%cd /content/EAT_code\n",
+ "\n",
+ "!gdown --id 1KK15n2fOdfLECWN5wvX54mVyDt18IZCo && unzip -q ckpt.zip -d ckpt\n",
+ "!gdown --id 1MeFGC7ig-vgpDLdhh2vpTIiElrhzZmgT && unzip -q demo.zip -d demo\n",
+ "!gdown --id 1HGVzckXh-vYGZEUUKMntY1muIbkbnRcd && unzip -q Utils.zip -d Utils"
+ ],
+ "metadata": {
+ "id": "eDDFgToSSEgj",
+ "cellView": "form"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
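+ {
+ "cell_type": "code",
+ "source": [
+ "#@title sanity-check that the expected files are in place (sketch)\n",
+ "# Quick check before running anything heavy: these are the paths the later cells\n",
+ "# reference (demo checkpoint, keypoint-extractor checkpoint, deepspeech zip).\n",
+ "import os\n",
+ "for p in ['/content/EAT_code/ckpt/deepprompt_eam3d_all_final_313.pth.tar',\n",
+ "          '/content/EAT_code/ckpt/pretrain_new_274.pth.tar',\n",
+ "          '/content/EAT_code/tensorflow/models/deepspeech-0_1_0-b90017e8.pb.zip']:\n",
+ "    print(p, '->', 'found' if os.path.exists(p) else 'MISSING')"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },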
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title upload custom mp4 to videos\n",
+ "%cd /content/EAT_code/preprocess/video\n",
+ "from google.colab import files\n",
+ "uploaded = files.upload()\n",
+ "%cd /content/EAT_code"
+ ],
+ "metadata": {
+ "id": "rlfKO73uUXFT",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 106
+ },
+ "outputId": "e8116307-0a3f-4eea-ebd1-925f3e17a4bf"
+ },
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "/content/EAT_code/preprocess/video\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Saving bo_1resized.mp4 to bo_1resized.mp4\n",
+ "/content/EAT_code\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title extract boundary boxes\n",
+ "!python /content/EAT_code/preprocess/extract_bbox.py"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "llzj0RprSPu6"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title place custom video in preprocess/video\n",
+ "%cd /content/EAT_code/preprocess\n",
+ "!python /content/EAT_code/preprocess/preprocess_video.py # --deepspeech \"/content/EAT_code/tensorflow/models/deepspeech-0_1_0-b90017e8.pb\""
+ ],
+ "metadata": {
+ "id": "p99eLmnjW-e1"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
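+ {
+ "cell_type": "code",
+ "source": [
+ "#@title inspect the preprocessed output (sketch)\n",
+ "# The demo cell further down points --root_wav at demo/video_processed1/<video name>,\n",
+ "# so this simply lists the per-video folders under demo/video_processed* to confirm\n",
+ "# the expected subfolders (images_evp_25, latent_evp_25, deepfeature32, poseimg, wav) exist.\n",
+ "import glob, os\n",
+ "for d in sorted(glob.glob('/content/EAT_code/demo/video_processed*/*')):\n",
+ "    if os.path.isdir(d):\n",
+ "        print(d, '->', sorted(os.listdir(d)))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },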
+ {
+ "cell_type": "code",
+ "source": [
+ " print(f\"Number of paths: {len(self.A_paths)}\")\n",
+ " print(\"Paths:\")\n",
+ " for path in self.A_paths:\n",
+ " print(path)"
+ ],
+ "metadata": {
+ "id": "B4LWLwEvJOsp"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "# Replace \"path/to/your/file.npy\" with the actual path to your file\n",
+ "data = np.load(\"/content/EAT_code/demo/video_processed/obama/latent_evp_25/obama.npy\")\n",
+ "\n",
+ "print(f\"Type: {type(data)}\")\n",
+ "print(f\"Shape: {data.shape}\")"
+ ],
+ "metadata": {
+ "id": "V-hPYJBgYIKx"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!python /content/EAT_code/demo.py --root_wav /content/EAT_code/demo/video_processed1/bo_1resized --emo ang --save_dir /content/EAT_code/Results"
+ ],
+ "metadata": {
+ "id": "APJUoQkgaqa9",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "outputId": "e8150705-1bae-4420-8306-069aa274fb1d"
+ },
+ "execution_count": 14,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "deepprompt_eam3d_all_final_313\n",
+ "cuda is available\n",
+ "/usr/local/lib/python3.10/dist-packages/torch/functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:2228.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ " 0% 0/1 [00:00, ?it/s]\n",
+ " 0% 0/1 [00:00, ?it/s]\u001b[A/content/EAT_code/modules/model_transformer.py:158: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
+ " pred = F.softmax(pred)\n",
+ "\n",
+ "\n",
+ " 0% 0/177 [00:00, ?it/s]\u001b[A\u001b[A/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py:4193: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.\n",
+ " warnings.warn(\n",
+ "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py:1944: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n",
+ " warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n",
+ "\n",
+ "\n",
+ " 1% 1/177 [00:00<00:57, 3.04it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 1% 2/177 [00:00<00:58, 3.01it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 2% 3/177 [00:00<00:58, 2.99it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 2% 4/177 [00:01<00:57, 2.99it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 3% 5/177 [00:01<00:57, 2.98it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 3% 6/177 [00:02<00:57, 2.98it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 4% 7/177 [00:02<00:57, 2.98it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 5% 8/177 [00:02<00:56, 2.98it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 5% 9/177 [00:03<00:56, 2.98it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 6% 10/177 [00:03<00:56, 2.98it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 6% 11/177 [00:03<00:55, 2.98it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 7% 12/177 [00:04<00:55, 2.97it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 7% 13/177 [00:04<00:55, 2.97it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 8% 14/177 [00:04<00:54, 2.97it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 8% 15/177 [00:05<00:54, 2.97it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 9% 16/177 [00:05<00:54, 2.97it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 10% 17/177 [00:05<00:53, 2.97it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 10% 18/177 [00:06<00:53, 2.97it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 11% 19/177 [00:06<00:53, 2.97it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 11% 20/177 [00:06<00:52, 2.97it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 12% 21/177 [00:07<00:52, 2.97it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 12% 22/177 [00:07<00:52, 2.97it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 13% 23/177 [00:07<00:51, 2.96it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 14% 24/177 [00:08<00:51, 2.96it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 14% 25/177 [00:08<00:51, 2.96it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 15% 26/177 [00:08<00:51, 2.95it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 15% 27/177 [00:09<00:50, 2.95it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 16% 28/177 [00:09<00:50, 2.95it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 16% 29/177 [00:09<00:50, 2.95it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 17% 30/177 [00:10<00:49, 2.95it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 18% 31/177 [00:10<00:49, 2.95it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 18% 32/177 [00:10<00:49, 2.95it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 19% 33/177 [00:11<00:48, 2.95it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 19% 34/177 [00:11<00:48, 2.95it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 20% 35/177 [00:11<00:48, 2.95it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 20% 36/177 [00:12<00:47, 2.94it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 21% 37/177 [00:12<00:47, 2.94it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 21% 38/177 [00:12<00:47, 2.95it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 22% 39/177 [00:13<00:46, 2.94it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 23% 40/177 [00:13<00:46, 2.94it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 23% 41/177 [00:13<00:46, 2.94it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 24% 42/177 [00:14<00:45, 2.94it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 24% 43/177 [00:14<00:45, 2.94it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 25% 44/177 [00:14<00:45, 2.94it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 25% 45/177 [00:15<00:45, 2.93it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 26% 46/177 [00:15<00:44, 2.93it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 27% 47/177 [00:15<00:44, 2.93it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 27% 48/177 [00:16<00:44, 2.92it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 28% 49/177 [00:16<00:43, 2.93it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 28% 50/177 [00:16<00:43, 2.92it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 29% 51/177 [00:17<00:43, 2.92it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 29% 52/177 [00:17<00:42, 2.92it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 30% 53/177 [00:17<00:42, 2.92it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 31% 54/177 [00:18<00:42, 2.92it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 31% 55/177 [00:18<00:41, 2.92it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 32% 56/177 [00:18<00:41, 2.92it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 32% 57/177 [00:19<00:41, 2.91it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 33% 58/177 [00:19<00:40, 2.91it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 33% 59/177 [00:20<00:40, 2.91it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 34% 60/177 [00:20<00:40, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 34% 61/177 [00:20<00:39, 2.91it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 35% 62/177 [00:21<00:39, 2.91it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 36% 63/177 [00:21<00:39, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 36% 64/177 [00:21<00:38, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 37% 65/177 [00:22<00:38, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 37% 66/177 [00:22<00:38, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 38% 67/177 [00:22<00:37, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 38% 68/177 [00:23<00:37, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 39% 69/177 [00:23<00:37, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 40% 70/177 [00:23<00:36, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 40% 71/177 [00:24<00:36, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 41% 72/177 [00:24<00:36, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 41% 73/177 [00:24<00:35, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 42% 74/177 [00:25<00:35, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 42% 75/177 [00:25<00:35, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 43% 76/177 [00:25<00:34, 2.91it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 44% 77/177 [00:26<00:34, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 44% 78/177 [00:26<00:34, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 45% 79/177 [00:26<00:33, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 45% 80/177 [00:27<00:33, 2.90it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 46% 81/177 [00:27<00:33, 2.89it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 46% 82/177 [00:27<00:32, 2.89it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 47% 83/177 [00:28<00:32, 2.89it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 47% 84/177 [00:28<00:32, 2.88it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 48% 85/177 [00:28<00:32, 2.87it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 49% 86/177 [00:29<00:31, 2.87it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 49% 87/177 [00:29<00:31, 2.87it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 50% 88/177 [00:30<00:30, 2.87it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 50% 89/177 [00:30<00:30, 2.87it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 51% 90/177 [00:30<00:30, 2.87it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 51% 91/177 [00:31<00:29, 2.87it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 52% 92/177 [00:31<00:29, 2.87it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 53% 93/177 [00:31<00:29, 2.87it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 53% 94/177 [00:32<00:28, 2.87it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 54% 95/177 [00:32<00:28, 2.87it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 54% 96/177 [00:32<00:28, 2.86it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 55% 97/177 [00:33<00:27, 2.86it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 55% 98/177 [00:33<00:27, 2.86it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 56% 99/177 [00:33<00:27, 2.86it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 56% 100/177 [00:34<00:26, 2.86it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 57% 101/177 [00:34<00:26, 2.86it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 58% 102/177 [00:34<00:26, 2.86it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 58% 103/177 [00:35<00:25, 2.86it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 59% 104/177 [00:35<00:25, 2.86it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 59% 105/177 [00:35<00:25, 2.86it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 60% 106/177 [00:36<00:24, 2.86it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 60% 107/177 [00:36<00:24, 2.86it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 61% 108/177 [00:37<00:24, 2.85it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 62% 109/177 [00:37<00:23, 2.85it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 62% 110/177 [00:37<00:23, 2.85it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 63% 111/177 [00:38<00:23, 2.85it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 63% 112/177 [00:38<00:22, 2.85it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 64% 113/177 [00:38<00:22, 2.85it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 64% 114/177 [00:39<00:22, 2.85it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 65% 115/177 [00:39<00:21, 2.85it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 66% 116/177 [00:39<00:21, 2.85it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 66% 117/177 [00:40<00:21, 2.85it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 67% 118/177 [00:40<00:20, 2.85it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 67% 119/177 [00:40<00:20, 2.84it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 68% 120/177 [00:41<00:20, 2.84it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 68% 121/177 [00:41<00:19, 2.84it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 69% 122/177 [00:41<00:19, 2.83it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 69% 123/177 [00:42<00:19, 2.84it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 70% 124/177 [00:42<00:18, 2.83it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 71% 125/177 [00:43<00:18, 2.83it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 71% 126/177 [00:43<00:18, 2.81it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 72% 127/177 [00:43<00:17, 2.82it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 72% 128/177 [00:44<00:17, 2.82it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 73% 129/177 [00:44<00:17, 2.82it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 73% 130/177 [00:44<00:16, 2.82it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 74% 131/177 [00:45<00:16, 2.82it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 75% 132/177 [00:45<00:15, 2.82it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 75% 133/177 [00:45<00:15, 2.81it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 76% 134/177 [00:46<00:15, 2.81it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 76% 135/177 [00:46<00:14, 2.81it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 77% 136/177 [00:46<00:14, 2.81it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 77% 137/177 [00:47<00:14, 2.81it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 78% 138/177 [00:47<00:13, 2.81it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 79% 139/177 [00:47<00:13, 2.81it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 79% 140/177 [00:48<00:13, 2.80it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 80% 141/177 [00:48<00:12, 2.81it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 80% 142/177 [00:49<00:12, 2.81it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 81% 143/177 [00:49<00:12, 2.81it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 81% 144/177 [00:49<00:11, 2.80it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 82% 145/177 [00:50<00:11, 2.80it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 82% 146/177 [00:50<00:11, 2.81it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 83% 147/177 [00:50<00:10, 2.80it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 84% 148/177 [00:51<00:10, 2.80it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 84% 149/177 [00:51<00:09, 2.80it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 85% 150/177 [00:51<00:09, 2.80it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 85% 151/177 [00:52<00:09, 2.80it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 86% 152/177 [00:52<00:08, 2.80it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 86% 153/177 [00:52<00:08, 2.80it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 87% 154/177 [00:53<00:08, 2.80it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 88% 155/177 [00:53<00:07, 2.80it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 88% 156/177 [00:54<00:07, 2.79it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 89% 157/177 [00:54<00:07, 2.79it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 89% 158/177 [00:54<00:06, 2.79it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 90% 159/177 [00:55<00:06, 2.79it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 90% 160/177 [00:55<00:06, 2.79it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 91% 161/177 [00:55<00:05, 2.79it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 92% 162/177 [00:56<00:05, 2.77it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 92% 163/177 [00:56<00:05, 2.77it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 93% 164/177 [00:56<00:04, 2.78it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 93% 165/177 [00:57<00:04, 2.78it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 94% 166/177 [00:57<00:03, 2.77it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 94% 167/177 [00:58<00:03, 2.77it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 95% 168/177 [00:58<00:03, 2.77it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 95% 169/177 [00:58<00:02, 2.77it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 96% 170/177 [00:59<00:02, 2.78it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 97% 171/177 [00:59<00:02, 2.77it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 97% 172/177 [00:59<00:01, 2.77it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 98% 173/177 [01:00<00:01, 2.76it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 98% 174/177 [01:00<00:01, 2.76it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 99% 175/177 [01:00<00:00, 2.76it/s]\u001b[A\u001b[A\n",
+ "\n",
+ " 99% 176/177 [01:01<00:00, 2.76it/s]\u001b[A\u001b[A\n",
+ "\n",
+ "100% 177/177 [01:01<00:00, 2.87it/s]\n",
+ "\n",
+ "100% 1/1 [01:04<00:00, 64.70s/it]\n",
+ "100% 1/1 [01:04<00:00, 64.70s/it]\n"
+ ]
+ }
+ ]
+ },
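+ {
+ "cell_type": "code",
+ "source": [
+ "#@title preview the newest result (sketch)\n",
+ "# Plays the most recently written mp4 from the --save_dir used above\n",
+ "# (/content/EAT_code/Results). Base64-embedding keeps it inline in Colab;\n",
+ "# for long clips it may be easier to just download the file instead.\n",
+ "import glob, os, base64\n",
+ "from IPython.display import HTML, display\n",
+ "mp4s = sorted(glob.glob('/content/EAT_code/Results/*.mp4'), key=os.path.getmtime)\n",
+ "if not mp4s:\n",
+ "    print('no mp4 found in /content/EAT_code/Results')\n",
+ "else:\n",
+ "    data = base64.b64encode(open(mp4s[-1], 'rb').read()).decode()\n",
+ "    display(HTML(f\"<video width=512 controls src='data:video/mp4;base64,{data}'></video>\"))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },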
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd /content\n",
+ "!zip -r EAT_code.zip /EAT_code"
+ ],
+ "metadata": {
+ "id": "XZuRqjR0EGuY"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
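+ {
+ "cell_type": "code",
+ "source": [
+ "#@title optional: pick arguments for the in-notebook demo.py below (sketch)\n",
+ "# The next cell parses sys.argv, so setting it here is one way to pass\n",
+ "# --root_wav and --emo without the command line (its output path is hard-coded\n",
+ "# to ./demo/output/<name>/). The values below mirror the earlier demo.py run.\n",
+ "import sys\n",
+ "sys.argv = ['demo.py',\n",
+ "            '--root_wav', '/content/EAT_code/demo/video_processed1/bo_1resized',\n",
+ "            '--emo', 'ang']"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },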
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title demo.py with fixed paths\n",
+ "\n",
+ "import os\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import yaml\n",
+ "from modules.generator import OcclusionAwareSPADEGeneratorEam\n",
+ "from modules.keypoint_detector import KPDetector, HEEstimator\n",
+ "import argparse\n",
+ "import imageio\n",
+ "from modules.transformer import Audio2kpTransformerBBoxQDeepPrompt as Audio2kpTransformer\n",
+ "from modules.prompt import EmotionDeepPrompt, EmotionalDeformationTransformer\n",
+ "from scipy.io import wavfile\n",
+ "\n",
+ "from modules.model_transformer import get_rotation_matrix, keypoint_transformation\n",
+ "from skimage import io, img_as_float32\n",
+ "from skimage.transform import resize\n",
+ "import torchaudio\n",
+ "import soundfile as sf\n",
+ "from scipy.spatial import ConvexHull\n",
+ "\n",
+ "import torch.nn.functional as F\n",
+ "import glob\n",
+ "from tqdm import tqdm\n",
+ "import gzip\n",
+ "\n",
+ "emo_label = ['ang', 'con', 'dis', 'fea', 'hap', 'neu', 'sad', 'sur']\n",
+ "emo_label_full = ['angry', 'contempt', 'disgusted', 'fear', 'happy', 'neutral', 'sad', 'surprised']\n",
+ "latent_dim = 16\n",
+ "\n",
+ "MEL_PARAMS_25 = {\n",
+ " \"n_mels\": 80,\n",
+ " \"n_fft\": 2048,\n",
+ " \"win_length\": 640,\n",
+ " \"hop_length\": 640\n",
+ "}\n",
+ "\n",
+ "to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS_25)\n",
+ "mean, std = -4, 4\n",
+ "\n",
+ "expU = torch.from_numpy(np.load('/content/EAT_code/expPCAnorm_fin/U_mead.npy')[:,:32])\n",
+ "expmean = torch.from_numpy(np.load('/content/EAT_code/expPCAnorm_fin/mean_mead.npy'))\n",
+ "\n",
+ "root_wav = './demo/video_processed/M003_neu_1_001'\n",
+ "def normalize_kp(kp_source, kp_driving, kp_driving_initial,\n",
+ " use_relative_movement=True, use_relative_jacobian=True):\n",
+ "\n",
+ " kp_new = {k: v for k, v in kp_driving.items()}\n",
+ " if use_relative_movement:\n",
+ " kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])\n",
+ " kp_new['value'] = kp_value_diff + kp_source['value']\n",
+ "\n",
+ " if use_relative_jacobian:\n",
+ " jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))\n",
+ " kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])\n",
+ "\n",
+ " return kp_new\n",
+ "\n",
+ "def _load_tensor(data):\n",
+ " wave_path = data\n",
+ " wave, sr = sf.read(wave_path)\n",
+ " wave_tensor = torch.from_numpy(wave).float()\n",
+ " return wave_tensor\n",
+ "\n",
+ "def build_model(config, device_ids=[0]):\n",
+ " generator = OcclusionAwareSPADEGeneratorEam(**config['model_params']['generator_params'],\n",
+ " **config['model_params']['common_params'])\n",
+ " if torch.cuda.is_available():\n",
+ " print('cuda is available')\n",
+ " generator.to(device_ids[0])\n",
+ "\n",
+ " kp_detector = KPDetector(**config['model_params']['kp_detector_params'],\n",
+ " **config['model_params']['common_params'])\n",
+ "\n",
+ " if torch.cuda.is_available():\n",
+ " kp_detector.to(device_ids[0])\n",
+ "\n",
+ "\n",
+ " audio2kptransformer = Audio2kpTransformer(**config['model_params']['audio2kp_params'], face_ea=True)\n",
+ "\n",
+ " if torch.cuda.is_available():\n",
+ " audio2kptransformer.to(device_ids[0])\n",
+ "\n",
+ " sidetuning = EmotionalDeformationTransformer(**config['model_params']['audio2kp_params'])\n",
+ "\n",
+ " if torch.cuda.is_available():\n",
+ " sidetuning.to(device_ids[0])\n",
+ "\n",
+ " emotionprompt = EmotionDeepPrompt()\n",
+ "\n",
+ " if torch.cuda.is_available():\n",
+ " emotionprompt.to(device_ids[0])\n",
+ "\n",
+ " return generator, kp_detector, audio2kptransformer, sidetuning, emotionprompt\n",
+ "\n",
+ "\n",
+ "def prepare_test_data(img_path, audio_path, opt, emotype, use_otherimg=True):\n",
+ " # sr,_ = wavfile.read(audio_path)\n",
+ "\n",
+ " if use_otherimg:\n",
+ " source_latent = np.load(img_path.replace('cropped', 'latent')[:-4]+'.npy', allow_pickle=True)\n",
+ " else:\n",
+ " source_latent = np.load(img_path.replace('images', 'latent')[:-9]+'.npy', allow_pickle=True)\n",
+ " he_source = {}\n",
+ " for k in source_latent[1].keys():\n",
+ " he_source[k] = torch.from_numpy(source_latent[1][k][0]).unsqueeze(0).cuda()\n",
+ "\n",
+ " # source images\n",
+ " source_img = img_as_float32(io.imread(img_path)).transpose((2, 0, 1))\n",
+ " asp = os.path.basename(audio_path)[:-4]\n",
+ "\n",
+ " # latent code\n",
+ " y_trg = emo_label.index(emotype)\n",
+ " z_trg = torch.randn(latent_dim)\n",
+ "\n",
+ " # driving latent\n",
+ " latent_path_driving = f'{root_wav}/latent_evp_25/{asp}.npy'\n",
+ " pose_gz = gzip.GzipFile(f'{root_wav}/poseimg/{asp}.npy.gz', 'r')\n",
+ " poseimg = np.load(pose_gz)\n",
+ " deepfeature = np.load(f'{root_wav}/deepfeature32/{asp}.npy')\n",
+ " driving_latent = np.load(latent_path_driving[:-4]+'.npy', allow_pickle=True)\n",
+ " he_driving = driving_latent[1]\n",
+ "\n",
+ " # gt frame number\n",
+ " frames = glob.glob(f'{root_wav}/images_evp_25/cropped/*.jpg')\n",
+ " num_frames = len(frames)\n",
+ "\n",
+ " wave_tensor = _load_tensor(audio_path)\n",
+ " if len(wave_tensor.shape) > 1:\n",
+ " wave_tensor = wave_tensor[:, 0]\n",
+ " mel_tensor = to_melspec(wave_tensor)\n",
+ " mel_tensor = (torch.log(1e-5 + mel_tensor) - mean) / std\n",
+ " name_len = min(mel_tensor.shape[1], poseimg.shape[0], deepfeature.shape[0])\n",
+ "\n",
+ " audio_frames = []\n",
+ " poseimgs = []\n",
+ " deep_feature = []\n",
+ "\n",
+ " pad, deep_pad = np.load('pad.npy', allow_pickle=True)\n",
+ "\n",
+ " if name_len < num_frames:\n",
+ " diff = num_frames - name_len\n",
+ " if diff > 2:\n",
+ " print(f\"Attention: the frames are {diff} more than name_len, we will use name_len to replace num_frames\")\n",
+ " num_frames=name_len\n",
+ " for k in he_driving.keys():\n",
+ " he_driving[k] = he_driving[k][:name_len, :]\n",
+ " for rid in range(0, num_frames):\n",
+ " audio = []\n",
+ " poses = []\n",
+ " deeps = []\n",
+ " for i in range(rid - opt['num_w'], rid + opt['num_w'] + 1):\n",
+ " if i < 0:\n",
+ " audio.append(pad)\n",
+ " poses.append(poseimg[0])\n",
+ " deeps.append(deep_pad)\n",
+ " elif i >= name_len:\n",
+ " audio.append(pad)\n",
+ " poses.append(poseimg[-1])\n",
+ " deeps.append(deep_pad)\n",
+ " else:\n",
+ " audio.append(mel_tensor[:, i])\n",
+ " poses.append(poseimg[i])\n",
+ " deeps.append(deepfeature[i])\n",
+ "\n",
+ " audio_frames.append(torch.stack(audio, dim=1))\n",
+ " poseimgs.append(poses)\n",
+ " deep_feature.append(deeps)\n",
+ " audio_frames = torch.stack(audio_frames, dim=0)\n",
+ " poseimgs = torch.from_numpy(np.array(poseimgs))\n",
+ " deep_feature = torch.from_numpy(np.array(deep_feature)).to(torch.float)\n",
+ " return audio_frames, poseimgs, deep_feature, source_img, he_source, he_driving, num_frames, y_trg, z_trg, latent_path_driving\n",
+ "\n",
+ "def load_ckpt(ckpt, kp_detector, generator, audio2kptransformer, sidetuning, emotionprompt):\n",
+ " checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))\n",
+ " if audio2kptransformer is not None:\n",
+ " audio2kptransformer.load_state_dict(checkpoint['audio2kptransformer'])\n",
+ " if generator is not None:\n",
+ " generator.load_state_dict(checkpoint['generator'])\n",
+ " if kp_detector is not None:\n",
+ " kp_detector.load_state_dict(checkpoint['kp_detector'])\n",
+ " if sidetuning is not None:\n",
+ " sidetuning.load_state_dict(checkpoint['sidetuning'])\n",
+ " if emotionprompt is not None:\n",
+ " emotionprompt.load_state_dict(checkpoint['emotionprompt'])\n",
+ "\n",
+ "import cv2\n",
+ "import dlib\n",
+ "from tqdm import tqdm\n",
+ "from skimage import transform as tf\n",
+ "detector = dlib.get_frontal_face_detector()\n",
+ "predictor = dlib.shape_predictor('/content/EAT_code/demo/shape_predictor_68_face_landmarks.dat')\n",
+ "\n",
+ "def shape_to_np(shape, dtype=\"int\"):\n",
+ " # initialize the list of (x, y)-coordinates\n",
+ " coords = np.zeros((shape.num_parts, 2), dtype=dtype)\n",
+ "\n",
+ " # loop over all facial landmarks and convert them\n",
+ " # to a 2-tuple of (x, y)-coordinates\n",
+ " for i in range(0, shape.num_parts):\n",
+ " coords[i] = (shape.part(i).x, shape.part(i).y)\n",
+ "\n",
+ " # return the list of (x, y)-coordinates\n",
+ " return coords\n",
+ "\n",
+ "def crop_image(image_path, out_path):\n",
+ " template = np.load('./demo/M003_template.npy')\n",
+ " image = cv2.imread(image_path)\n",
+ " gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n",
+ " rects = detector(gray, 1) #detect human face\n",
+ " if len(rects) != 1:\n",
+ " return 0\n",
+ " for (j, rect) in enumerate(rects):\n",
+ " shape = predictor(gray, rect) #detect 68 points\n",
+ " shape = shape_to_np(shape)\n",
+ "\n",
+ " pts2 = np.float32(template[:47,:])\n",
+ " pts1 = np.float32(shape[:47,:]) #eye and nose\n",
+ " tform = tf.SimilarityTransform()\n",
+ " tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters.\n",
+ "\n",
+ " dst = tf.warp(image, tform, output_shape=(256, 256))\n",
+ "\n",
+ " dst = np.array(dst * 255, dtype=np.uint8)\n",
+ "\n",
+ " cv2.imwrite(out_path, dst)\n",
+ "\n",
+ "def preprocess_imgs(allimgs, tmp_allimgs_cropped):\n",
+ " name_cropped = []\n",
+ " for path in tmp_allimgs_cropped:\n",
+ " name_cropped.append(os.path.basename(path))\n",
+ " for path in allimgs:\n",
+ " if os.path.basename(path) in name_cropped:\n",
+ " continue\n",
+ " else:\n",
+ " out_path = path.replace('imgs/', 'imgs_cropped/')\n",
+ " crop_image(path, out_path)\n",
+ "\n",
+ "from sync_batchnorm import DataParallelWithCallback\n",
+ "def load_checkpoints_extractor(config_path, checkpoint_path, cpu=False):\n",
+ "\n",
+ " with open(config_path) as f:\n",
+ " config = yaml.load(f, Loader=yaml.FullLoader)\n",
+ "\n",
+ " kp_detector = KPDetector(**config['model_params']['kp_detector_params'],\n",
+ " **config['model_params']['common_params'])\n",
+ " if not cpu:\n",
+ " kp_detector.cuda()\n",
+ "\n",
+ " he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],\n",
+ " **config['model_params']['common_params'])\n",
+ " if not cpu:\n",
+ " he_estimator.cuda()\n",
+ "\n",
+ " if cpu:\n",
+ " checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))\n",
+ " else:\n",
+ " checkpoint = torch.load(checkpoint_path)\n",
+ "\n",
+ " kp_detector.load_state_dict(checkpoint['kp_detector'])\n",
+ " he_estimator.load_state_dict(checkpoint['he_estimator'])\n",
+ "\n",
+ " if not cpu:\n",
+ " kp_detector = DataParallelWithCallback(kp_detector)\n",
+ " he_estimator = DataParallelWithCallback(he_estimator)\n",
+ "\n",
+ " kp_detector.eval()\n",
+ " he_estimator.eval()\n",
+ "\n",
+ " return kp_detector, he_estimator\n",
+ "\n",
+ "def estimate_latent(driving_video, kp_detector, he_estimator):\n",
+ " with torch.no_grad():\n",
+ " predictions = []\n",
+ " driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).cuda()\n",
+ " kp_canonical = kp_detector(driving[:, :, 0])\n",
+ " he_drivings = {'yaw': [], 'pitch': [], 'roll': [], 't': [], 'exp': []}\n",
+ "\n",
+ " for frame_idx in range(driving.shape[2]):\n",
+ " driving_frame = driving[:, :, frame_idx]\n",
+ " he_driving = he_estimator(driving_frame)\n",
+ " for k in he_drivings.keys():\n",
+ " he_drivings[k].append(he_driving[k])\n",
+ " return [kp_canonical, he_drivings]\n",
+ "\n",
+ "def extract_keypoints(extract_list):\n",
+ " kp_detector, he_estimator = load_checkpoints_extractor(config_path='config/vox-256-spade.yaml', checkpoint_path='./ckpt/pretrain_new_274.pth.tar')\n",
+ " if not os.path.exists('./demo/imgs_latent/'):\n",
+ " os.makedirs('./demo/imgs_latent/')\n",
+ " for imgname in tqdm(extract_list):\n",
+ " path_frames = [imgname]\n",
+ " filesname=os.path.basename(imgname)[:-4]\n",
+ " if os.path.exists(f'./demo/imgs_latent/'+filesname+'.npy'):\n",
+ " continue\n",
+ " driving_frames = []\n",
+ " for im in path_frames:\n",
+ " driving_frames.append(imageio.imread(im))\n",
+ " driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_frames]\n",
+ "\n",
+ " kc, he = estimate_latent(driving_video, kp_detector, he_estimator)\n",
+ " kc = kc['value'].cpu().numpy()\n",
+ " for k in he:\n",
+ " he[k] = torch.cat(he[k]).cpu().numpy()\n",
+ " np.save('./demo/imgs_latent/'+filesname, [kc, he])\n",
+ "\n",
+ "def preprocess_cropped_imgs(allimgs_cropped):\n",
+ " extract_list = []\n",
+ " for img_path in allimgs_cropped:\n",
+ " if not os.path.exists(img_path.replace('cropped', 'latent')[:-4]+'.npy'):\n",
+ " extract_list.append(img_path)\n",
+ " if len(extract_list) > 0:\n",
+ " print('=========', \"Extract latent keypoints from New image\", '======')\n",
+ " extract_keypoints(extract_list)\n",
+ "\n",
+ "def test(ckpt, emotype, save_dir=\" \"):\n",
+ " # with open(\"config/vox-transformer2.yaml\") as f:\n",
+ " with open(\"/content/EAT_code/config/deepprompt_eam3d_st_tanh_304_3090_all.yaml\") as f:\n",
+ " config = yaml.load(f, Loader=yaml.FullLoader)\n",
+ " cur_path = os.getcwd()\n",
+ " generator, kp_detector, audio2kptransformer, sidetuning, emotionprompt = build_model(config)\n",
+ " load_ckpt(ckpt, kp_detector=kp_detector, generator=generator, audio2kptransformer=audio2kptransformer, sidetuning=sidetuning, emotionprompt=emotionprompt)\n",
+ "\n",
+ " audio2kptransformer.eval()\n",
+ " generator.eval()\n",
+ " kp_detector.eval()\n",
+ " sidetuning.eval()\n",
+ " emotionprompt.eval()\n",
+ "\n",
+ " all_wavs2 = [f'{root_wav}/{os.path.basename(root_wav)}.wav']\n",
+ " allimg = glob.glob('./demo/imgs/*.jpg')\n",
+ " tmp_allimg_cropped = glob.glob('./demo/imgs_cropped/*.jpg')\n",
+ " preprocess_imgs(allimg, tmp_allimg_cropped) # crop and align images\n",
+ "\n",
+ " allimg_cropped = glob.glob('./demo/imgs_cropped/*.jpg')\n",
+ " preprocess_cropped_imgs(allimg_cropped) # extract latent keypoints if necessary\n",
+ "\n",
+ " for ind in tqdm(range(len(all_wavs2))):\n",
+ " for img_path in tqdm(allimg_cropped):\n",
+ " audio_path = all_wavs2[ind]\n",
+ " # read in data\n",
+ " audio_frames, poseimgs, deep_feature, source_img, he_source, he_driving, num_frames, y_trg, z_trg, latent_path_driving = prepare_test_data(img_path, audio_path, config['model_params']['audio2kp_params'], emotype)\n",
+ "\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " source_img = torch.from_numpy(source_img).unsqueeze(0).cuda()\n",
+ " kp_canonical = kp_detector(source_img, with_feature=True) # {'value': value, 'jacobian': jacobian}\n",
+ " kp_cano = kp_canonical['value']\n",
+ "\n",
+ " x = {}\n",
+ " x['mel'] = audio_frames.unsqueeze(1).unsqueeze(0).cuda()\n",
+ " x['z_trg'] = z_trg.unsqueeze(0).cuda()\n",
+ " x['y_trg'] = torch.tensor(y_trg, dtype=torch.long).cuda().reshape(1)\n",
+ " x['pose'] = poseimgs.cuda()\n",
+ " x['deep'] = deep_feature.cuda().unsqueeze(0)\n",
+ " x['he_driving'] = {'yaw': torch.from_numpy(he_driving['yaw']).cuda().unsqueeze(0),\n",
+ " 'pitch': torch.from_numpy(he_driving['pitch']).cuda().unsqueeze(0),\n",
+ " 'roll': torch.from_numpy(he_driving['roll']).cuda().unsqueeze(0),\n",
+ " 't': torch.from_numpy(he_driving['t']).cuda().unsqueeze(0),\n",
+ " }\n",
+ "\n",
+ " ### emotion prompt\n",
+ " emoprompt, deepprompt = emotionprompt(x)\n",
+ " a2kp_exps = []\n",
+ " emo_exps = []\n",
+ " T = 5\n",
+ " if T == 1:\n",
+ " for i in range(x['mel'].shape[1]):\n",
+ " xi = {}\n",
+ " xi['mel'] = x['mel'][:,i,:,:,:].unsqueeze(1)\n",
+ " xi['z_trg'] = x['z_trg']\n",
+ " xi['y_trg'] = x['y_trg']\n",
+ " xi['pose'] = x['pose'][i,:,:,:,:].unsqueeze(0)\n",
+ " xi['deep'] = x['deep'][:,i,:,:,:].unsqueeze(1)\n",
+ " xi['he_driving'] = {'yaw': x['he_driving']['yaw'][:,i,:].unsqueeze(0),\n",
+ " 'pitch': x['he_driving']['pitch'][:,i,:].unsqueeze(0),\n",
+ " 'roll': x['he_driving']['roll'][:,i,:].unsqueeze(0),\n",
+ " 't': x['he_driving']['t'][:,i,:].unsqueeze(0),\n",
+ " }\n",
+ " he_driving_emo_xi, input_st_xi = audio2kptransformer(xi, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n",
+ " emo_exp = sidetuning(input_st_xi, emoprompt, deepprompt)\n",
+ " a2kp_exps.append(he_driving_emo_xi['emo'])\n",
+ " emo_exps.append(emo_exp)\n",
+ " elif T is not None:\n",
+ " for i in range(x['mel'].shape[1]//T+1):\n",
+ " if i*T >= x['mel'].shape[1]:\n",
+ " break\n",
+ " xi = {}\n",
+ " xi['mel'] = x['mel'][:,i*T:(i+1)*T,:,:,:]\n",
+ " xi['z_trg'] = x['z_trg']\n",
+ " xi['y_trg'] = x['y_trg']\n",
+ " xi['pose'] = x['pose'][i*T:(i+1)*T,:,:,:,:]\n",
+ " xi['deep'] = x['deep'][:,i*T:(i+1)*T,:,:,:]\n",
+ " xi['he_driving'] = {'yaw': x['he_driving']['yaw'][:,i*T:(i+1)*T,:],\n",
+ " 'pitch': x['he_driving']['pitch'][:,i*T:(i+1)*T,:],\n",
+ " 'roll': x['he_driving']['roll'][:,i*T:(i+1)*T,:],\n",
+ " 't': x['he_driving']['t'][:,i*T:(i+1)*T,:],\n",
+ " }\n",
+ " he_driving_emo_xi, input_st_xi = audio2kptransformer(xi, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n",
+ " emo_exp = sidetuning(input_st_xi, emoprompt, deepprompt)\n",
+ " a2kp_exps.append(he_driving_emo_xi['emo'])\n",
+ " emo_exps.append(emo_exp)\n",
+ "\n",
+ " if T is None:\n",
+ " he_driving_emo, input_st = audio2kptransformer(x, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n",
+ " emo_exps = sidetuning(input_st, emoprompt, deepprompt).reshape(-1, 45)\n",
+ " else:\n",
+ " he_driving_emo = {}\n",
+ " he_driving_emo['emo'] = torch.cat(a2kp_exps, dim=0)\n",
+ " emo_exps = torch.cat(emo_exps, dim=0).reshape(-1, 45)\n",
+ "\n",
+ " exp = he_driving_emo['emo']\n",
+ " device = exp.get_device()\n",
+ " exp = torch.mm(exp, expU.t().to(device))\n",
+ " exp = exp + expmean.expand_as(exp).to(device)\n",
+ " exp = exp + emo_exps\n",
+ "\n",
+ "\n",
+ " source_area = ConvexHull(kp_cano[0].cpu().numpy()).volume\n",
+ " exp = exp * source_area\n",
+ "\n",
+ " he_new_driving = {'yaw': torch.from_numpy(he_driving['yaw']).cuda(),\n",
+ " 'pitch': torch.from_numpy(he_driving['pitch']).cuda(),\n",
+ " 'roll': torch.from_numpy(he_driving['roll']).cuda(),\n",
+ " 't': torch.from_numpy(he_driving['t']).cuda(),\n",
+ " 'exp': exp}\n",
+ " he_driving['exp'] = torch.from_numpy(he_driving['exp']).cuda()\n",
+ "\n",
+ " kp_source = keypoint_transformation(kp_canonical, he_source, False)\n",
+ " mean_source = torch.mean(kp_source['value'], dim=1)[0]\n",
+ " kp_driving = keypoint_transformation(kp_canonical, he_new_driving, False)\n",
+ " mean_driving = torch.mean(torch.mean(kp_driving['value'], dim=1), dim=0)\n",
+ " kp_driving['value'] = kp_driving['value']+(mean_source-mean_driving).unsqueeze(0).unsqueeze(0)\n",
+ " bs = kp_source['value'].shape[0]\n",
+ " predictions_gen = []\n",
+ " for i in tqdm(range(num_frames)):\n",
+ " kp_si = {}\n",
+ " kp_si['value'] = kp_source['value'][0].unsqueeze(0)\n",
+ " kp_di = {}\n",
+ " kp_di['value'] = kp_driving['value'][i].unsqueeze(0)\n",
+ " generated = generator(source_img, kp_source=kp_si, kp_driving=kp_di, prompt=emoprompt)\n",
+ " predictions_gen.append(\n",
+ " (np.transpose(generated['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0] * 255).astype(np.uint8))\n",
+ "\n",
+ " log_dir = save_dir\n",
+ " os.makedirs(os.path.join(log_dir, \"temp\"), exist_ok=True)\n",
+ "\n",
+ " f_name = os.path.basename(img_path[:-4]) + \"_\" + emotype + \"_\" + os.path.basename(latent_path_driving)[:-4] + \".mp4\"\n",
+ " video_path = os.path.join(log_dir, \"temp\", f_name)\n",
+ " imageio.mimsave(video_path, predictions_gen, fps=25.0)\n",
+ "\n",
+ " save_video = os.path.join(log_dir, f_name)\n",
+ " cmd = r'ffmpeg -loglevel error -y -i \"%s\" -i \"%s\" -vcodec copy -shortest \"%s\"' % (video_path, audio_path, save_video)\n",
+ " os.system(cmd)\n",
+ " os.remove(video_path)\n",
+ "\n",
+ "if __name__ == '__main__':\n",
+ " argparser = argparse.ArgumentParser()\n",
+ " argparser.add_argument(\"--save_dir\", type=str, default=\" \", help=\"path of the output video\")\n",
+ " argparser.add_argument(\"--name\", type=str, default=\"deepprompt_eam3d_all_final_313\", help=\"path of the output video\")\n",
+ " argparser.add_argument(\"--emo\", type=str, default=\"hap\", help=\"emotion type ('ang', 'con', 'dis', 'fea', 'hap', 'neu', 'sad', 'sur')\")\n",
+ " argparser.add_argument(\"--root_wav\", type=str, default='./demo/video_processed/M003_neu_1_001', help=\"emotion type ('ang', 'con', 'dis', 'fea', 'hap', 'neu', 'sad', 'sur')\")\n",
+ " args = argparser.parse_args()\n",
+ "\n",
+ " root_wav=args.root_wav\n",
+ "\n",
+ " if len(args.name) > 1:\n",
+ " name = args.name\n",
+ " print(name)\n",
+ " test(f'/content/EAT_code/ckpt/deepprompt_eam3d_all_final_313.pth.tar', args.emo, save_dir=f'./demo/output/{name}/')\n",
+ "\n"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "cMz72QBbgRsc"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title trasnformer.py with fixed paths\n",
+ "import torch.nn as nn\n",
+ "import torch\n",
+ "from modules.util import mydownres2Dblock\n",
+ "import numpy as np\n",
+ "from modules.util import AntiAliasInterpolation2d, make_coordinate_grid\n",
+ "from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d\n",
+ "import torch.nn.functional as F\n",
+ "import copy\n",
+ "\n",
+ "\n",
+ "class PositionalEncoding(nn.Module):\n",
+ "\n",
+ " def __init__(self, d_hid, n_position=200):\n",
+ " super(PositionalEncoding, self).__init__()\n",
+ "\n",
+ " # Not a parameter\n",
+ " self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))\n",
+ "\n",
+ " def _get_sinusoid_encoding_table(self, n_position, d_hid):\n",
+ " ''' Sinusoid position encoding table '''\n",
+ " # TODO: make it with torch instead of numpy\n",
+ "\n",
+ " def get_position_angle_vec(position):\n",
+ " return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]\n",
+ "\n",
+ " sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])\n",
+ " sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i\n",
+ " sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1\n",
+ "\n",
+ " return torch.FloatTensor(sinusoid_table).unsqueeze(0)\n",
+ "\n",
+ " def forward(self, winsize):\n",
+ " return self.pos_table[:, :winsize].clone().detach()\n",
+ "\n",
+ "def _get_activation_fn(activation):\n",
+ " \"\"\"Return an activation function given a string\"\"\"\n",
+ " if activation == \"relu\":\n",
+ " return F.relu\n",
+ " if activation == \"gelu\":\n",
+ " return F.gelu\n",
+ " if activation == \"glu\":\n",
+ " return F.glu\n",
+ " raise RuntimeError(F\"activation should be relu/gelu, not {activation}.\")\n",
+ "\n",
+ "def _get_clones(module, N):\n",
+ " return nn.ModuleList([copy.deepcopy(module) for i in range(N)])\n",
+ "\n",
+ "### light weight transformer encoder\n",
+ "class TransformerST(nn.Module):\n",
+ "\n",
+ " def __init__(self, d_model=128, nhead=8, num_encoder_layers=6,\n",
+ " num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,\n",
+ " activation=\"relu\", normalize_before=False):\n",
+ " super().__init__()\n",
+ "\n",
+ " encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,\n",
+ " dropout, activation, normalize_before)\n",
+ " encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n",
+ " self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n",
+ "\n",
+ " self._reset_parameters()\n",
+ "\n",
+ " self.d_model = d_model\n",
+ " self.nhead = nhead\n",
+ "\n",
+ " def _reset_parameters(self):\n",
+ " for p in self.parameters():\n",
+ " if p.dim() > 1:\n",
+ " nn.init.xavier_uniform_(p)\n",
+ "\n",
+ " def forward(self, src, pos_embed):\n",
+ " # flatten NxCxHxW to HWxNxC\n",
+ "\n",
+ " src = src.permute(1, 0, 2)\n",
+ " pos_embed = pos_embed.permute(1, 0, 2)\n",
+ "\n",
+ " memory = self.encoder(src, pos=pos_embed)\n",
+ "\n",
+ " return memory\n",
+ "\n",
+ "class TransformerEncoder(nn.Module):\n",
+ "\n",
+ " def __init__(self, encoder_layer, num_layers, norm=None):\n",
+ " super().__init__()\n",
+ " self.layers = _get_clones(encoder_layer, num_layers)\n",
+ " self.num_layers = num_layers\n",
+ " self.norm = norm\n",
+ "\n",
+ " def forward(self, src, mask = None, src_key_padding_mask = None, pos = None):\n",
+ " output = src+pos\n",
+ "\n",
+ " for layer in self.layers:\n",
+ " output = layer(output, src_mask=mask,\n",
+ " src_key_padding_mask=src_key_padding_mask, pos=pos)\n",
+ "\n",
+ " if self.norm is not None:\n",
+ " output = self.norm(output)\n",
+ "\n",
+ " return output\n",
+ "\n",
+ "class TransformerDecoder(nn.Module):\n",
+ "\n",
+ " def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):\n",
+ " super().__init__()\n",
+ " self.layers = _get_clones(decoder_layer, num_layers)\n",
+ " self.num_layers = num_layers\n",
+ " self.norm = norm\n",
+ " self.return_intermediate = return_intermediate\n",
+ "\n",
+ " def forward(self, tgt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None,\n",
+ " memory_key_padding_mask = None,\n",
+ " pos = None,\n",
+ " query_pos = None):\n",
+ " output = tgt+pos+query_pos\n",
+ "\n",
+ " intermediate = []\n",
+ "\n",
+ " for layer in self.layers:\n",
+ " output = layer(output, memory, tgt_mask=tgt_mask,\n",
+ " memory_mask=memory_mask,\n",
+ " tgt_key_padding_mask=tgt_key_padding_mask,\n",
+ " memory_key_padding_mask=memory_key_padding_mask,\n",
+ " pos=pos, query_pos=query_pos)\n",
+ " if self.return_intermediate:\n",
+ " intermediate.append(self.norm(output))\n",
+ "\n",
+ " if self.norm is not None:\n",
+ " output = self.norm(output)\n",
+ " if self.return_intermediate:\n",
+ " intermediate.pop()\n",
+ " intermediate.append(output)\n",
+ "\n",
+ " if self.return_intermediate:\n",
+ " return torch.stack(intermediate)\n",
+ "\n",
+ " return output.unsqueeze(0)\n",
+ "\n",
+ "\n",
+ "class Transformer(nn.Module):\n",
+ "\n",
+ " def __init__(self, d_model=128, nhead=8, num_encoder_layers=6,\n",
+ " num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,\n",
+ " activation=\"relu\", normalize_before=False,\n",
+ " return_intermediate_dec=True):\n",
+ " super().__init__()\n",
+ "\n",
+ " encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,\n",
+ " dropout, activation, normalize_before)\n",
+ " encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n",
+ " self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n",
+ "\n",
+ " decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,\n",
+ " dropout, activation, normalize_before)\n",
+ " decoder_norm = nn.LayerNorm(d_model)\n",
+ " self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,\n",
+ " return_intermediate=return_intermediate_dec)\n",
+ "\n",
+ " self._reset_parameters()\n",
+ "\n",
+ " self.d_model = d_model\n",
+ " self.nhead = nhead\n",
+ "\n",
+ " def _reset_parameters(self):\n",
+ " for p in self.parameters():\n",
+ " if p.dim() > 1:\n",
+ " nn.init.xavier_uniform_(p)\n",
+ "\n",
+ " def forward(self, src, query_embed, pos_embed):\n",
+ " # flatten NxCxHxW to HWxNxC\n",
+ "\n",
+ " src = src.permute(1, 0, 2)\n",
+ " pos_embed = pos_embed.permute(1, 0, 2)\n",
+ " query_embed = query_embed.permute(1, 0, 2)\n",
+ "\n",
+ " tgt = torch.zeros_like(query_embed)\n",
+ " memory = self.encoder(src, pos=pos_embed)\n",
+ "\n",
+ " hs = self.decoder(tgt, memory,\n",
+ " pos=pos_embed, query_pos=query_embed)\n",
+ " return hs, memory\n",
+ "\n",
+ "\n",
+ "class TransformerDeep(nn.Module):\n",
+ "\n",
+ " def __init__(self, d_model=128, nhead=8, num_encoder_layers=6,\n",
+ " num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,\n",
+ " activation=\"relu\", normalize_before=False,\n",
+ " return_intermediate_dec=True):\n",
+ " super().__init__()\n",
+ "\n",
+ " encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,\n",
+ " dropout, activation, normalize_before)\n",
+ " encoder_norm = nn.LayerNorm(d_model) if normalize_before else None\n",
+ " self.encoder = TransformerEncoderDeep(encoder_layer, num_encoder_layers, encoder_norm)\n",
+ "\n",
+ " decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,\n",
+ " dropout, activation, normalize_before)\n",
+ " decoder_norm = nn.LayerNorm(d_model)\n",
+ " self.decoder = TransformerDecoderDeep(decoder_layer, num_decoder_layers, decoder_norm,\n",
+ " return_intermediate=return_intermediate_dec)\n",
+ "\n",
+ " self._reset_parameters()\n",
+ "\n",
+ " self.d_model = d_model\n",
+ " self.nhead = nhead\n",
+ "\n",
+ " def _reset_parameters(self):\n",
+ " for p in self.parameters():\n",
+ " if p.dim() > 1:\n",
+ " nn.init.xavier_uniform_(p)\n",
+ "\n",
+ " def forward(self, src, query_embed, pos_embed, deepprompt):\n",
+ " # flatten NxCxHxW to HWxNxC\n",
+ "\n",
+ " # print('src before permute: ', src.shape) # 5, 12, 128\n",
+ " src = src.permute(1, 0, 2)\n",
+ " # print('src after permute: ', src.shape) # 12, 5, 128\n",
+ " pos_embed = pos_embed.permute(1, 0, 2)\n",
+ " query_embed = query_embed.permute(1, 0, 2)\n",
+ "\n",
+ " tgt = torch.zeros_like(query_embed) # actually is tgt + query_embed\n",
+ " memory = self.encoder(src, deepprompt, pos=pos_embed)\n",
+ "\n",
+ " hs = self.decoder(tgt, deepprompt, memory,\n",
+ " pos=pos_embed, query_pos=query_embed)\n",
+ " return hs, memory\n",
+ "\n",
+ "class TransformerEncoderDeep(nn.Module):\n",
+ "\n",
+ " def __init__(self, encoder_layer, num_layers, norm=None):\n",
+ " super().__init__()\n",
+ " self.layers = _get_clones(encoder_layer, num_layers)\n",
+ " self.num_layers = num_layers\n",
+ " self.norm = norm\n",
+ "\n",
+ " def forward(self, src, deepprompt, mask = None, src_key_padding_mask = None, pos = None):\n",
+ " # print('input: ', src.shape) # 12 5 128\n",
+ " # print('deepprompt:', deepprompt.shape) # 1 6 128\n",
+ " ### TODO: add deep prompt in encoder\n",
+ " bs = src.shape[1]\n",
+ " bbs = deepprompt.shape[0]\n",
+ " idx=0\n",
+ " emoprompt = deepprompt[:,idx,:]\n",
+ " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n",
+ " # print(emoprompt.shape) # 1 5 128\n",
+ " src = torch.cat([src, emoprompt], dim=0)\n",
+ " # print(src.shape) # 13 5 128\n",
+ " output = src+pos\n",
+ "\n",
+ " for layer in self.layers:\n",
+ " output = layer(output, src_mask=mask,\n",
+ " src_key_padding_mask=src_key_padding_mask, pos=pos)\n",
+ "\n",
+ " ### deep prompt\n",
+ " if idx+1 < len(self.layers):\n",
+ " idx = idx + 1\n",
+ " # print(idx)\n",
+ " emoprompt = deepprompt[:,idx,:]\n",
+ " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n",
+ " # print(output.shape) # 13 5 128\n",
+ " output = torch.cat([output[:-1], emoprompt], dim=0)\n",
+ " # print(output.shape) # 13 5 128\n",
+ "\n",
+ " if self.norm is not None:\n",
+ " output = self.norm(output)\n",
+ "\n",
+ " return output\n",
+ "\n",
+ "class TransformerDecoderDeep(nn.Module):\n",
+ "\n",
+ " def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):\n",
+ " super().__init__()\n",
+ " self.layers = _get_clones(decoder_layer, num_layers)\n",
+ " self.num_layers = num_layers\n",
+ " self.norm = norm\n",
+ " self.return_intermediate = return_intermediate\n",
+ "\n",
+ " def forward(self, tgt, deepprompt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None,\n",
+ " memory_key_padding_mask = None,\n",
+ " pos = None,\n",
+ " query_pos = None):\n",
+ " # print('input: ', query_pos.shape) 12 5 128\n",
+ " ### TODO: add deep prompt in encoder\n",
+ " bs = query_pos.shape[1]\n",
+ " bbs = deepprompt.shape[0]\n",
+ " idx=0\n",
+ " emoprompt = deepprompt[:,idx,:]\n",
+ " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n",
+ " query_pos = torch.cat([query_pos, emoprompt], dim=0)\n",
+ " # print(query_pos.shape) # 13 5 128\n",
+ " # print(torch.sum(tgt)) # 0\n",
+ " output = pos+query_pos\n",
+ "\n",
+ " intermediate = []\n",
+ "\n",
+ " for layer in self.layers:\n",
+ " output = layer(output, memory, tgt_mask=tgt_mask,\n",
+ " memory_mask=memory_mask,\n",
+ " tgt_key_padding_mask=tgt_key_padding_mask,\n",
+ " memory_key_padding_mask=memory_key_padding_mask,\n",
+ " pos=pos, query_pos=query_pos)\n",
+ " if self.return_intermediate:\n",
+ " intermediate.append(self.norm(output))\n",
+ "\n",
+ " ### deep prompt\n",
+ " if idx+1 < len(self.layers):\n",
+ " idx = idx + 1\n",
+ " # print(idx)\n",
+ " emoprompt = deepprompt[:,idx,:]\n",
+ " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(0)\n",
+ " # print(output.shape) # 13 5 128\n",
+ " output = torch.cat([output[:-1], emoprompt], dim=0)\n",
+ " # print(output.shape) # 13 5 128\n",
+ "\n",
+ " if self.norm is not None:\n",
+ " output = self.norm(output)\n",
+ " if self.return_intermediate:\n",
+ " intermediate.pop()\n",
+ " intermediate.append(output)\n",
+ "\n",
+ " if self.return_intermediate:\n",
+ " return torch.stack(intermediate)\n",
+ "\n",
+ " return output.unsqueeze(0)\n",
+ "\n",
+ "\n",
+ "class TransformerEncoderLayer(nn.Module):\n",
+ "\n",
+ " def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,\n",
+ " activation=\"relu\", normalize_before=False):\n",
+ " super().__init__()\n",
+ " self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n",
+ " # Implementation of Feedforward model\n",
+ " self.linear1 = nn.Linear(d_model, dim_feedforward)\n",
+ " self.dropout = nn.Dropout(dropout)\n",
+ " self.linear2 = nn.Linear(dim_feedforward, d_model)\n",
+ "\n",
+ " self.norm1 = nn.LayerNorm(d_model)\n",
+ " self.norm2 = nn.LayerNorm(d_model)\n",
+ " self.dropout1 = nn.Dropout(dropout)\n",
+ " self.dropout2 = nn.Dropout(dropout)\n",
+ "\n",
+ " self.activation = _get_activation_fn(activation)\n",
+ " self.normalize_before = normalize_before\n",
+ "\n",
+ " def with_pos_embed(self, tensor, pos):\n",
+ " return tensor if pos is None else tensor + pos\n",
+ "\n",
+ " def forward_post(self,\n",
+ " src,\n",
+ " src_mask = None,\n",
+ " src_key_padding_mask = None,\n",
+ " pos = None):\n",
+ " # q = k = self.with_pos_embed(src, pos)\n",
+ " src2 = self.self_attn(src, src, value=src, attn_mask=src_mask,\n",
+ " key_padding_mask=src_key_padding_mask)[0]\n",
+ " src = src + self.dropout1(src2)\n",
+ " src = self.norm1(src)\n",
+ " src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))\n",
+ " src = src + self.dropout2(src2)\n",
+ " src = self.norm2(src)\n",
+ " return src\n",
+ "\n",
+ " def forward_pre(self, src,\n",
+ " src_mask = None,\n",
+ " src_key_padding_mask = None,\n",
+ " pos = None):\n",
+ " src2 = self.norm1(src)\n",
+ " # q = k = self.with_pos_embed(src2, pos)\n",
+ " src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask,\n",
+ " key_padding_mask=src_key_padding_mask)[0]\n",
+ " src = src + self.dropout1(src2)\n",
+ " src2 = self.norm2(src)\n",
+ " src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))\n",
+ " src = src + self.dropout2(src2)\n",
+ " return src\n",
+ "\n",
+ " def forward(self, src,\n",
+ " src_mask = None,\n",
+ " src_key_padding_mask = None,\n",
+ " pos = None):\n",
+ " if self.normalize_before:\n",
+ " return self.forward_pre(src, src_mask, src_key_padding_mask, pos)\n",
+ " return self.forward_post(src, src_mask, src_key_padding_mask, pos)\n",
+ "\n",
+ "class TransformerDecoderLayer(nn.Module):\n",
+ "\n",
+ " def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,\n",
+ " activation=\"relu\", normalize_before=False):\n",
+ " super().__init__()\n",
+ " self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n",
+ " self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n",
+ " # Implementation of Feedforward model\n",
+ " self.linear1 = nn.Linear(d_model, dim_feedforward)\n",
+ " self.dropout = nn.Dropout(dropout)\n",
+ " self.linear2 = nn.Linear(dim_feedforward, d_model)\n",
+ "\n",
+ " self.norm1 = nn.LayerNorm(d_model)\n",
+ " self.norm2 = nn.LayerNorm(d_model)\n",
+ " self.norm3 = nn.LayerNorm(d_model)\n",
+ " self.dropout1 = nn.Dropout(dropout)\n",
+ " self.dropout2 = nn.Dropout(dropout)\n",
+ " self.dropout3 = nn.Dropout(dropout)\n",
+ "\n",
+ " self.activation = _get_activation_fn(activation)\n",
+ " self.normalize_before = normalize_before\n",
+ "\n",
+ " def with_pos_embed(self, tensor, pos):\n",
+ " return tensor if pos is None else tensor + pos\n",
+ "\n",
+ " def forward_post(self, tgt, memory,\n",
+ " tgt_mask = None,\n",
+ " memory_mask = None,\n",
+ " tgt_key_padding_mask = None,\n",
+ " memory_key_padding_mask = None,\n",
+ " pos = None,\n",
+ " query_pos = None):\n",
+ " # q = k = self.with_pos_embed(tgt, query_pos)\n",
+ " tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask,\n",
+ " key_padding_mask=tgt_key_padding_mask)[0]\n",
+ " tgt = tgt + self.dropout1(tgt2)\n",
+ " tgt = self.norm1(tgt)\n",
+ " tgt2 = self.multihead_attn(query=tgt,\n",
+ " key=memory,\n",
+ " value=memory, attn_mask=memory_mask,\n",
+ " key_padding_mask=memory_key_padding_mask)[0]\n",
+ " tgt = tgt + self.dropout2(tgt2)\n",
+ " tgt = self.norm2(tgt)\n",
+ " tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))\n",
+ " tgt = tgt + self.dropout3(tgt2)\n",
+ " tgt = self.norm3(tgt)\n",
+ " return tgt\n",
+ "\n",
+ " def forward_pre(self, tgt, memory,\n",
+ " tgt_mask = None,\n",
+ " memory_mask = None,\n",
+ " tgt_key_padding_mask = None,\n",
+ " memory_key_padding_mask = None,\n",
+ " pos = None,\n",
+ " query_pos = None):\n",
+ " tgt2 = self.norm1(tgt)\n",
+ " # q = k = self.with_pos_embed(tgt2, query_pos)\n",
+ " tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask,\n",
+ " key_padding_mask=tgt_key_padding_mask)[0]\n",
+ " tgt = tgt + self.dropout1(tgt2)\n",
+ " tgt2 = self.norm2(tgt)\n",
+ " tgt2 = self.multihead_attn(query=tgt2,\n",
+ " key=memory,\n",
+ " value=memory, attn_mask=memory_mask,\n",
+ " key_padding_mask=memory_key_padding_mask)[0]\n",
+ " tgt = tgt + self.dropout2(tgt2)\n",
+ " tgt2 = self.norm3(tgt)\n",
+ " tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))\n",
+ " tgt = tgt + self.dropout3(tgt2)\n",
+ " return tgt\n",
+ "\n",
+ " def forward(self, tgt, memory,\n",
+ " tgt_mask = None,\n",
+ " memory_mask = None,\n",
+ " tgt_key_padding_mask = None,\n",
+ " memory_key_padding_mask = None,\n",
+ " pos = None,\n",
+ " query_pos = None):\n",
+ " if self.normalize_before:\n",
+ " return self.forward_pre(tgt, memory, tgt_mask, memory_mask,\n",
+ " tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)\n",
+ " return self.forward_post(tgt, memory, tgt_mask, memory_mask,\n",
+ " tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)\n",
+ "\n",
+ "from Utils.JDC.model import JDCNet\n",
+ "from modules.audioencoder import AudioEncoder, MappingNetwork, StyleEncoder, AdaIN, EAModule\n",
+ "\n",
+ "def headpose_pred_to_degree(pred):\n",
+ " device = pred.device\n",
+ " idx_tensor = [idx for idx in range(66)]\n",
+ " idx_tensor = torch.FloatTensor(idx_tensor).to(device)\n",
+ " pred = F.softmax(pred, dim=1)\n",
+ " degree = torch.sum(pred*idx_tensor, axis=1)\n",
+ " # degree = F.one_hot(degree.to(torch.int64), num_classes=66)\n",
+ " return degree\n",
+ "\n",
+ "def get_rotation_matrix(yaw, pitch, roll):\n",
+ " yaw = yaw / 180 * 3.14\n",
+ " pitch = pitch / 180 * 3.14\n",
+ " roll = roll / 180 * 3.14\n",
+ "\n",
+ " roll = roll.unsqueeze(1)\n",
+ " pitch = pitch.unsqueeze(1)\n",
+ " yaw = yaw.unsqueeze(1)\n",
+ "\n",
+ " pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),\n",
+ " torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),\n",
+ " torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)\n",
+ "\n",
+ " pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)\n",
+ "\n",
+ " yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw),\n",
+ " torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),\n",
+ " -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)\n",
+ " yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)\n",
+ "\n",
+ " roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),\n",
+ " torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),\n",
+ " torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)\n",
+ " roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)\n",
+ "\n",
+ " rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)\n",
+ "\n",
+ " return yaw, pitch, roll, yaw_mat.view(yaw_mat.shape[0], 9), pitch_mat.view(pitch_mat.shape[0], 9), roll_mat.view(roll_mat.shape[0], 9), rot_mat.view(rot_mat.shape[0], 9)\n",
+ "\n",
+ "class Audio2kpTransformerBBoxQDeep(nn.Module):\n",
+ " def __init__(self, embedding_dim, num_kp, num_w, face_adain=False):\n",
+ " super(Audio2kpTransformerBBoxQDeep, self).__init__()\n",
+ " self.embedding_dim = embedding_dim\n",
+ " self.num_kp = num_kp\n",
+ " self.num_w = num_w\n",
+ "\n",
+ "\n",
+ " self.embedding = nn.Embedding(41, embedding_dim)\n",
+ "\n",
+ " self.face_shrink = nn.Linear(240, 32)\n",
+ " self.hp_extractor = nn.Linear(45, 128)\n",
+ "\n",
+ " self.pos_enc = PositionalEncoding(128,20)\n",
+ " input_dim = 1\n",
+ "\n",
+ " self.decode_dim = 64\n",
+ " self.audio_embedding = nn.Sequential( # n x 29 x 16\n",
+ " nn.Conv1d(29, 32, kernel_size=3, stride=2,\n",
+ " padding=1, bias=True), # n x 32 x 8\n",
+ " nn.LeakyReLU(0.02, True),\n",
+ " nn.Conv1d(32, 32, kernel_size=3, stride=2,\n",
+ " padding=1, bias=True), # n x 32 x 4\n",
+ " nn.LeakyReLU(0.02, True),\n",
+ " nn.Conv1d(32, 64, kernel_size=3, stride=2,\n",
+ " padding=1, bias=True), # n x 64 x 2\n",
+ " nn.LeakyReLU(0.02, True),\n",
+ " nn.Conv1d(64, 64, kernel_size=3, stride=2,\n",
+ " padding=1, bias=True), # n x 64 x 1\n",
+ " nn.LeakyReLU(0.02, True),\n",
+ " )\n",
+ " self.encoder_fc1 = nn.Sequential(\n",
+ " nn.Linear(192, 128),\n",
+ " nn.LeakyReLU(0.02, True),\n",
+ " nn.Linear(128, 128),\n",
+ " )\n",
+ "\n",
+ " self.audio_embedding2 = nn.Sequential(nn.Conv2d(1, 8, (3, 17), stride=(1, 1), padding=(1, 0)),\n",
+ " # nn.GroupNorm(4, 8, affine=True),\n",
+ " BatchNorm2d(8),\n",
+ " nn.ReLU(inplace=True),\n",
+ " nn.Conv2d(8, 32, (13, 13), stride=(1, 1), padding=(6, 6)))\n",
+ "\n",
+ " self.audioencoder = AudioEncoder(dim_in=64, style_dim=128, max_conv_dim=512, w_hpf=0, F0_channel=256)\n",
+ " # self.mappingnet = MappingNetwork(latent_dim=16, style_dim=128, num_domains=8, hidden_dim=512)\n",
+ " # self.stylenet = StyleEncoder(dim_in=64, style_dim=64, num_domains=8, max_conv_dim=512)\n",
+ " self.face_adain = face_adain\n",
+ " if self.face_adain:\n",
+ " self.fadain = AdaIN(style_dim=128, num_features=32)\n",
+ " # norm = 'layer_2d' #\n",
+ " norm = 'batch'\n",
+ "\n",
+ " self.decodefeature_extract = nn.Sequential(mydownres2Dblock(self.decode_dim,32, normalize = norm),\n",
+ " mydownres2Dblock(32,48, normalize = norm),\n",
+ " mydownres2Dblock(48,64, normalize = norm),\n",
+ " mydownres2Dblock(64,96, normalize = norm),\n",
+ " mydownres2Dblock(96,128, normalize = norm),\n",
+ " nn.AvgPool2d(2))\n",
+ "\n",
+ " self.feature_extract = nn.Sequential(mydownres2Dblock(input_dim,32),\n",
+ " mydownres2Dblock(32,64),\n",
+ " mydownres2Dblock(64,128),\n",
+ " mydownres2Dblock(128,128),\n",
+ " mydownres2Dblock(128,128),\n",
+ " nn.AvgPool2d(2))\n",
+ " self.transformer = Transformer()\n",
+ " self.kp = nn.Linear(128, 32)\n",
+ "\n",
+ " # for m in self.modules():\n",
+ " # if isinstance(m, nn.Conv2d):\n",
+ " # # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
+ " # nn.init.xavier_uniform_(m.weight, gain=1)\n",
+ "\n",
+ " # if isinstance(m, nn.Linear):\n",
+ " # # trunc_normal_(m.weight, std=.03)\n",
+ " # nn.init.xavier_uniform_(m.weight, gain=1)\n",
+ " # if isinstance(m, nn.Linear) and m.bias is not None:\n",
+ " # nn.init.constant_(m.bias, 0)\n",
+ "\n",
+ " F0_path = './Utils/JDC/bst.t7'\n",
+ " F0_model = JDCNet(num_class=1, seq_len=32)\n",
+ " params = torch.load(F0_path, map_location='cpu')['net']\n",
+ " F0_model.load_state_dict(params)\n",
+ " self.f0_model = F0_model\n",
+ "\n",
+ " def rotation_and_translation(self, headpose, bbs, bs):\n",
+ " # print(headpose['roll'].shape, headpose['yaw'].shape, headpose['pitch'].shape, headpose['t'].shape)\n",
+ "\n",
+ " yaw = headpose_pred_to_degree(headpose['yaw'].reshape(bbs*bs, -1))\n",
+ " pitch = headpose_pred_to_degree(headpose['pitch'].reshape(bbs*bs, -1))\n",
+ " roll = headpose_pred_to_degree(headpose['roll'].reshape(bbs*bs, -1))\n",
+ " yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v = get_rotation_matrix(yaw, pitch, roll)\n",
+ " t = headpose['t'].reshape(bbs*bs, -1)\n",
+ " # hp = torch.cat([yaw, pitch, roll, yaw_v, pitch_v, roll_v, t], dim=1)\n",
+ " hp = torch.cat([yaw.unsqueeze(1), pitch.unsqueeze(1), roll.unsqueeze(1), yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v, t], dim=1)\n",
+ " # hp = torch.cat([yaw, pitch, roll, torch.sin(yaw), torch.sin(pitch), torch.sin(roll), torch.cos(yaw), torch.cos(pitch), torch.cos(roll), t], dim=1)\n",
+ " return hp\n",
+ "\n",
+ " def forward(self, x, initial_kp = None, return_strg=False, emoprompt=None, hp=None, side=False):\n",
+ " bbs, bs, seqlen, _, _ = x['deep'].shape\n",
+ " # ph = x[\"pho\"].reshape(bbs*bs*seqlen, 1)\n",
+ " if hp is None:\n",
+ " hp = self.rotation_and_translation(x['he_driving'], bbs, bs)\n",
+ " hp = self.hp_extractor(hp)\n",
+ "\n",
+ " pose_feature = x[\"pose\"].reshape(bbs*bs*seqlen,1,64,64)\n",
+ " # pose_feature = self.down_pose(pose).contiguous()\n",
+ " ### phoneme input feature\n",
+ " # phoneme_embedding = self.embedding(ph.long())\n",
+ " # phoneme_embedding = phoneme_embedding.reshape(bbs*bs*seqlen, 1, 16, 16)\n",
+ " # phoneme_embedding = F.interpolate(phoneme_embedding, scale_factor=4)\n",
+ " # input_feature = torch.cat((pose_feature, phoneme_embedding), dim=1)\n",
+ " # print('input_feature: ', input_feature.shape)\n",
+ " # input_feature = phoneme_embedding\n",
+ "\n",
+ " audio = x['deep'].reshape(bbs*bs*seqlen, 16, 29).permute(0, 2, 1)\n",
+ " deep_feature = self.audio_embedding(audio).squeeze(-1)# ([264, 32, 16, 16])\n",
+ " # print(deep_feature.shape)\n",
+ "\n",
+ " input_feature = pose_feature\n",
+ " # print(input_feature.shape)\n",
+ " # assert(0)\n",
+ " input_feature = self.feature_extract(input_feature).reshape(bbs*bs*seqlen, 128)\n",
+ " input_feature = torch.cat([input_feature, deep_feature], dim=1)\n",
+ " input_feature = self.encoder_fc1(input_feature).reshape(bbs*bs, seqlen, 128)\n",
+ " # phoneme_embedding = self.phoneme_shrink(phoneme_embedding.squeeze(1))# 24*11, 128\n",
+ " input_feature = torch.cat([input_feature, hp.unsqueeze(1)], dim=1)\n",
+ "\n",
+ " ### decode audio feature\n",
+ " ### use iteration to avoid batchnorm2d in different audio sequence\n",
+ " decoder_features = []\n",
+ " for i in range(bbs):\n",
+ " F0 = self.f0_model.get_feature_GAN(x['mel'][i].reshape(bs, 1, 80, seqlen))\n",
+ " if emoprompt is None:\n",
+ " audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=None, masks=None, F0=F0))\n",
+ " else:\n",
+ " audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=emoprompt[i].unsqueeze(0), masks=None, F0=F0))\n",
+ " audio2 = torch.permute(audio_feature, (0, 3, 1, 2)).reshape(bs*seqlen, 1, 64, 80)\n",
+ " decoder_feature = self.audio_embedding2(audio2)\n",
+ "\n",
+ " # decoder_feature = torch.cat([decoder_feature, audio2], dim=1)\n",
+ " # decoder_feature = F.interpolate(decoder_feature, scale_factor=2)# ([264, 35, 64, 64])\n",
+ " face_map = initial_kp[\"prediction_map\"][i].reshape(15*16, 64*64).permute(1, 0).reshape(64*64, 15*16)\n",
+ " feature_map = self.face_shrink(face_map).permute(1, 0).reshape(1, 32, 64, 64)\n",
+ " if self.face_adain:\n",
+ " feature_map = self.fadain(feature_map, emoprompt)\n",
+ " decoder_feature = self.decodefeature_extract(torch.cat(\n",
+ " (decoder_feature,\n",
+ " feature_map.repeat(bs, seqlen, 1, 1, 1).reshape(bs * seqlen, 32, 64, 64)),\n",
+ " dim=1)).reshape(bs, seqlen, 128)\n",
+ " decoder_features.append(decoder_feature)\n",
+ " decoder_feature = torch.cat(decoder_features, dim=0)\n",
+ "\n",
+ " decoder_feature = torch.cat([decoder_feature, hp.unsqueeze(1)], dim=1)\n",
+ "\n",
+ " # decoder_feature = torch.cat([decoder_feature], dim=1)\n",
+ "\n",
+ " # a2kp transformer\n",
+ " # position embedding\n",
+ " if emoprompt is None:\n",
+ " posi_em = self.pos_enc(self.num_w*2+1+1) # 11 + headpose token\n",
+ " else:\n",
+ " posi_em = self.pos_enc(self.num_w*2+1+1+1) # 11 + headpose token + emotion prompt\n",
+ " emoprompt = emoprompt.unsqueeze(1).tile(1, bs, 1).reshape(bbs*bs, 128).unsqueeze(1)\n",
+ " input_feature = torch.cat([input_feature, emoprompt], dim=1)\n",
+ " decoder_feature = torch.cat([decoder_feature, emoprompt], dim=1)\n",
+ " out = {}\n",
+ " output_feature, memory = self.transformer(input_feature, decoder_feature, posi_em, )\n",
+ " output_feature = output_feature[-1, self.num_w] # returned intermediate output [6, 13, bbs*bs, 128]\n",
+ " out[\"emo\"] = self.kp(output_feature)\n",
+ " if side:\n",
+ " input_st = {}\n",
+ " input_st['hp'] = hp\n",
+ " input_st['decoder_feature'] = decoder_feature\n",
+ " input_st['memory'] = memory\n",
+ " return out, input_st\n",
+ " else:\n",
+ " return out\n",
+ "\n",
+ "\n",
+ "class Audio2kpTransformerBBoxQDeepPrompt(nn.Module):\n",
+ " def __init__(self, embedding_dim, num_kp, num_w, face_ea=False):\n",
+ " super(Audio2kpTransformerBBoxQDeepPrompt, self).__init__()\n",
+ " self.embedding_dim = embedding_dim\n",
+ " self.num_kp = num_kp\n",
+ " self.num_w = num_w\n",
+ "\n",
+ "\n",
+ " self.embedding = nn.Embedding(41, embedding_dim)\n",
+ "\n",
+ " self.face_shrink = nn.Linear(240, 32)\n",
+ " self.hp_extractor = nn.Linear(45, 128)\n",
+ "\n",
+ " self.pos_enc = PositionalEncoding(128,20)\n",
+ " input_dim = 1\n",
+ "\n",
+ " self.decode_dim = 64\n",
+ " self.audio_embedding = nn.Sequential( # n x 29 x 16\n",
+ " nn.Conv1d(29, 32, kernel_size=3, stride=2,\n",
+ " padding=1, bias=True), # n x 32 x 8\n",
+ " nn.LeakyReLU(0.02, True),\n",
+ " nn.Conv1d(32, 32, kernel_size=3, stride=2,\n",
+ " padding=1, bias=True), # n x 32 x 4\n",
+ " nn.LeakyReLU(0.02, True),\n",
+ " nn.Conv1d(32, 64, kernel_size=3, stride=2,\n",
+ " padding=1, bias=True), # n x 64 x 2\n",
+ " nn.LeakyReLU(0.02, True),\n",
+ " nn.Conv1d(64, 64, kernel_size=3, stride=2,\n",
+ " padding=1, bias=True), # n x 64 x 1\n",
+ " nn.LeakyReLU(0.02, True),\n",
+ " )\n",
+ " self.encoder_fc1 = nn.Sequential(\n",
+ " nn.Linear(192, 128),\n",
+ " nn.LeakyReLU(0.02, True),\n",
+ " nn.Linear(128, 128),\n",
+ " )\n",
+ "\n",
+ " self.audio_embedding2 = nn.Sequential(nn.Conv2d(1, 8, (3, 17), stride=(1, 1), padding=(1, 0)),\n",
+ " # nn.GroupNorm(4, 8, affine=True),\n",
+ " BatchNorm2d(8),\n",
+ " nn.ReLU(inplace=True),\n",
+ " nn.Conv2d(8, 32, (13, 13), stride=(1, 1), padding=(6, 6)))\n",
+ "\n",
+ " self.audioencoder = AudioEncoder(dim_in=64, style_dim=128, max_conv_dim=512, w_hpf=0, F0_channel=256)\n",
+ " self.face_ea = face_ea\n",
+ " if self.face_ea:\n",
+ " self.fea = EAModule(style_dim=128, num_features=32)\n",
+ " norm = 'batch'\n",
+ "\n",
+ " self.decodefeature_extract = nn.Sequential(mydownres2Dblock(self.decode_dim,32, normalize = norm),\n",
+ " mydownres2Dblock(32,48, normalize = norm),\n",
+ " mydownres2Dblock(48,64, normalize = norm),\n",
+ " mydownres2Dblock(64,96, normalize = norm),\n",
+ " mydownres2Dblock(96,128, normalize = norm),\n",
+ " nn.AvgPool2d(2))\n",
+ "\n",
+ " self.feature_extract = nn.Sequential(mydownres2Dblock(input_dim,32),\n",
+ " mydownres2Dblock(32,64),\n",
+ " mydownres2Dblock(64,128),\n",
+ " mydownres2Dblock(128,128),\n",
+ " mydownres2Dblock(128,128),\n",
+ " nn.AvgPool2d(2))\n",
+ " self.transformer = TransformerDeep()\n",
+ " self.kp = nn.Linear(128, 32)\n",
+ "\n",
+ " F0_path = '/content/EAT_code/Utils/JDC/bst.t7'\n",
+ " F0_model = JDCNet(num_class=1, seq_len=32)\n",
+ " params = torch.load(F0_path, map_location='cpu')['net']\n",
+ " F0_model.load_state_dict(params)\n",
+ " self.f0_model = F0_model\n",
+ "\n",
+ " def rotation_and_translation(self, headpose, bbs, bs):\n",
+ " yaw = headpose_pred_to_degree(headpose['yaw'].reshape(bbs*bs, -1))\n",
+ " pitch = headpose_pred_to_degree(headpose['pitch'].reshape(bbs*bs, -1))\n",
+ " roll = headpose_pred_to_degree(headpose['roll'].reshape(bbs*bs, -1))\n",
+ " yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v = get_rotation_matrix(yaw, pitch, roll)\n",
+ " t = headpose['t'].reshape(bbs*bs, -1)\n",
+ " hp = torch.cat([yaw.unsqueeze(1), pitch.unsqueeze(1), roll.unsqueeze(1), yaw_2, pitch_2, roll_2, yaw_v, pitch_v, roll_v, rot_v, t], dim=1)\n",
+ " return hp\n",
+ "\n",
+ " def forward(self, x, initial_kp = None, return_strg=False, emoprompt=None, deepprompt=None, hp=None, side=False):\n",
+ " bbs, bs, seqlen, _, _ = x['deep'].shape\n",
+ " # ph = x[\"pho\"].reshape(bbs*bs*seqlen, 1)\n",
+ " if hp is None:\n",
+ " hp = self.rotation_and_translation(x['he_driving'], bbs, bs)\n",
+ " hp = self.hp_extractor(hp)\n",
+ "\n",
+ " pose_feature = x[\"pose\"].reshape(bbs*bs*seqlen,1,64,64)\n",
+ "\n",
+ " audio = x['deep'].reshape(bbs*bs*seqlen, 16, 29).permute(0, 2, 1)\n",
+ " deep_feature = self.audio_embedding(audio).squeeze(-1)# ([264, 32, 16, 16])\n",
+ "\n",
+ " input_feature = pose_feature\n",
+ " input_feature = self.feature_extract(input_feature).reshape(bbs*bs*seqlen, 128)\n",
+ " input_feature = torch.cat([input_feature, deep_feature], dim=1)\n",
+ " input_feature = self.encoder_fc1(input_feature).reshape(bbs*bs, seqlen, 128)\n",
+ " input_feature = torch.cat([input_feature, hp.unsqueeze(1)], dim=1)\n",
+ "\n",
+ " ### decode audio feature\n",
+ " ### use iteration to avoid batchnorm2d in different audio sequence\n",
+ " decoder_features = []\n",
+ " for i in range(bbs):\n",
+ " F0 = self.f0_model.get_feature_GAN(x['mel'][i].reshape(bs, 1, 80, seqlen))\n",
+ " if emoprompt is None:\n",
+ " audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=None, masks=None, F0=F0))\n",
+ " else:\n",
+ " audio_feature = (self.audioencoder(x['mel'][i].reshape(bs, 1, 80, seqlen), s=emoprompt[i].unsqueeze(0), masks=None, F0=F0))\n",
+ " audio2 = torch.permute(audio_feature, (0, 3, 1, 2)).reshape(bs*seqlen, 1, 64, 80)\n",
+ " decoder_feature = self.audio_embedding2(audio2)\n",
+ "\n",
+ " face_map = initial_kp[\"prediction_map\"][i].reshape(15*16, 64*64).permute(1, 0).reshape(64*64, 15*16)\n",
+ " face_feature_map = self.face_shrink(face_map).permute(1, 0).reshape(1, 32, 64, 64)\n",
+ " if self.face_ea:\n",
+ " face_feature_map = self.fea(face_feature_map, emoprompt)\n",
+ " decoder_feature = self.decodefeature_extract(torch.cat(\n",
+ " (decoder_feature,\n",
+ " face_feature_map.repeat(bs, seqlen, 1, 1, 1).reshape(bs * seqlen, 32, 64, 64)),\n",
+ " dim=1)).reshape(bs, seqlen, 128)\n",
+ " decoder_features.append(decoder_feature)\n",
+ " decoder_feature = torch.cat(decoder_features, dim=0)\n",
+ "\n",
+ " decoder_feature = torch.cat([decoder_feature, hp.unsqueeze(1)], dim=1)\n",
+ "\n",
+ " # a2kp transformer\n",
+ " # position embedding\n",
+ " if emoprompt is None:\n",
+ " posi_em = self.pos_enc(self.num_w*2+1+1) # 11 + headpose token\n",
+ " else:\n",
+ " posi_em = self.pos_enc(self.num_w*2+1+1+1) # 11 + headpose token + deep emotion prompt\n",
+ " out = {}\n",
+ " output_feature, memory = self.transformer(input_feature, decoder_feature, posi_em, deepprompt)\n",
+ " output_feature = output_feature[-1, self.num_w] # returned intermediate output [6, 13, bbs*bs, 128]\n",
+ " out[\"emo\"] = self.kp(output_feature)\n",
+ " if side:\n",
+ " input_st = {}\n",
+ " input_st['hp'] = hp\n",
+ " input_st['face_feature_map'] = face_feature_map\n",
+ " input_st['bs'] = bs\n",
+ " input_st['bbs'] = bbs\n",
+ " return out, input_st\n",
+ " else:\n",
+ " return out\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "DZwVMAPDgZvO"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title deepspeechfeatures.py fixed\n",
+ "\n",
+ "\"\"\"\n",
+ " DeepSpeech features processing routines.\n",
+ " NB: Based on VOCA code. See the corresponding license restrictions.\n",
+ "\"\"\"\n",
+ "\n",
+ "__all__ = ['conv_audios_to_deepspeech']\n",
+ "\n",
+ "import numpy as np\n",
+ "import warnings\n",
+ "import resampy\n",
+ "from scipy.io import wavfile\n",
+ "from python_speech_features import mfcc\n",
+ "import tensorflow as tf\n",
+ "\n",
+ "\n",
+ "def conv_audios_to_deepspeech(audios,\n",
+ " out_files,\n",
+ " num_frames_info,\n",
+ " deepspeech_pb_path,\n",
+ " audio_window_size=16,\n",
+ " audio_window_stride=1):\n",
+ " \"\"\"\n",
+ " Convert list of audio files into files with DeepSpeech features.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " audios : list of str or list of None\n",
+ " Paths to input audio files.\n",
+ " out_files : list of str\n",
+ " Paths to output files with DeepSpeech features.\n",
+ " num_frames_info : list of int\n",
+ " List of numbers of frames.\n",
+ " deepspeech_pb_path : str\n",
+ " Path to DeepSpeech 0.1.0 frozen model.\n",
+ " audio_window_size : int, default 16\n",
+ " Audio window size.\n",
+ " audio_window_stride : int, default 1\n",
+ " Audio window stride.\n",
+ " \"\"\"\n",
+ " graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(deepspeech_pb_path)\n",
+ "\n",
+ " with tf.compat.v1.Session(graph=graph) as sess:\n",
+ " for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):\n",
+ " audio_sample_rate, audio = wavfile.read(audio_file_path)\n",
+ " if audio.ndim != 1:\n",
+ " warnings.warn(\"Audio has multiple channels, the first channel is used\")\n",
+ " audio = audio[:, 0]\n",
+ " ds_features = pure_conv_audio_to_deepspeech(\n",
+ " audio=audio,\n",
+ " audio_sample_rate=audio_sample_rate,\n",
+ " audio_window_size=audio_window_size,\n",
+ " audio_window_stride=audio_window_stride,\n",
+ " num_frames=num_frames,\n",
+ " net_fn=lambda x: sess.run(\n",
+ " logits_ph,\n",
+ " feed_dict={\n",
+ " input_node_ph: x[np.newaxis, ...],\n",
+ " input_lengths_ph: [x.shape[0]]}))\n",
+ " np.save(out_file_path, ds_features)\n",
+ "\n",
+ "\n",
+ "# data_util/deepspeech_features/deepspeech_features.py\n",
+ "def prepare_deepspeech_net(deepspeech_pb_path):\n",
+ " # Load graph and place_holders:\n",
+ " with tf.io.gfile.GFile(deepspeech_pb_path, \"rb\") as f:\n",
+ " graph_def = tf.compat.v1.GraphDef()\n",
+ " graph_def.ParseFromString(f.read())\n",
+ "\n",
+ " graph = tf.compat.v1.get_default_graph()\n",
+ "\n",
+ " tf.import_graph_def(graph_def, name=\"deepspeech\")\n",
+ " # check all graphs\n",
+ " # print('~'*50, [tensor for tensor in graph._nodes_by_name], '~'*50)\n",
+ " # print('~'*50, [tensor.name for tensor in graph.get_operations()], '~'*50)\n",
+ " # i modified\n",
+ " logits_ph = graph.get_tensor_by_name(\"logits:0\")\n",
+ " input_node_ph = graph.get_tensor_by_name(\"input_node:0\")\n",
+ " input_lengths_ph = graph.get_tensor_by_name(\"input_lengths:0\")\n",
+ " # original\n",
+ " # logits_ph = graph.get_tensor_by_name(\"deepspeech/logits:0\")\n",
+ " # input_node_ph = graph.get_tensor_by_name(\"deepspeech/input_node:0\")\n",
+ " # input_lengths_ph = graph.get_tensor_by_name(\"deepspeech/input_lengths:0\")\n",
+ "\n",
+ " return graph, logits_ph, input_node_ph, input_lengths_ph\n",
+ "\n",
+ "\n",
+ "def pure_conv_audio_to_deepspeech(audio,\n",
+ " audio_sample_rate,\n",
+ " audio_window_size,\n",
+ " audio_window_stride,\n",
+ " num_frames,\n",
+ " net_fn):\n",
+ " \"\"\"\n",
+ " Core routine for converting audion into DeepSpeech features.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " audio : np.array\n",
+ " Audio data.\n",
+ " audio_sample_rate : int\n",
+ " Audio sample rate.\n",
+ " audio_window_size : int\n",
+ " Audio window size.\n",
+ " audio_window_stride : int\n",
+ " Audio window stride.\n",
+ " num_frames : int or None\n",
+ " Numbers of frames.\n",
+ " net_fn : func\n",
+ " Function for DeepSpeech model call.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " np.array\n",
+ " DeepSpeech features.\n",
+ " \"\"\"\n",
+ " target_sample_rate = 16000\n",
+ " if audio_sample_rate != target_sample_rate:\n",
+ " resampled_audio = resampy.resample(\n",
+ " x=audio.astype(np.float),\n",
+ " sr_orig=audio_sample_rate,\n",
+ " sr_new=target_sample_rate)\n",
+ " else:\n",
+ " resampled_audio = audio.astype(np.float32)\n",
+ " input_vector = conv_audio_to_deepspeech_input_vector(\n",
+ " audio=resampled_audio.astype(np.int16),\n",
+ " sample_rate=target_sample_rate,\n",
+ " num_cepstrum=26,\n",
+ " num_context=9)\n",
+ "\n",
+ " network_output = net_fn(input_vector)\n",
+ "\n",
+ " deepspeech_fps = 50\n",
+ " video_fps = 60\n",
+ " audio_len_s = float(audio.shape[0]) / audio_sample_rate\n",
+ " if num_frames is None:\n",
+ " num_frames = int(round(audio_len_s * video_fps))\n",
+ " else:\n",
+ " video_fps = num_frames / audio_len_s\n",
+ " network_output = interpolate_features(\n",
+ " features=network_output[:, 0],\n",
+ " input_rate=deepspeech_fps,\n",
+ " output_rate=video_fps,\n",
+ " output_len=num_frames)\n",
+ "\n",
+ " # Make windows:\n",
+ " zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))\n",
+ " network_output = np.concatenate((zero_pad, network_output, zero_pad), axis=0)\n",
+ " windows = []\n",
+ " for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):\n",
+ " windows.append(network_output[window_index:window_index + audio_window_size])\n",
+ "\n",
+ " return np.array(windows)\n",
+ "\n",
+ "\n",
+ "def conv_audio_to_deepspeech_input_vector(audio,\n",
+ " sample_rate,\n",
+ " num_cepstrum,\n",
+ " num_context):\n",
+ " \"\"\"\n",
+ " Convert audio raw data into DeepSpeech input vector.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " audio : np.array\n",
+ " Audio data.\n",
+ " audio_sample_rate : int\n",
+ " Audio sample rate.\n",
+ " num_cepstrum : int\n",
+ " Number of cepstrum.\n",
+ " num_context : int\n",
+ " Number of context.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " np.array\n",
+ " DeepSpeech input vector.\n",
+ " \"\"\"\n",
+ " # Get mfcc coefficients:\n",
+ " features = mfcc(\n",
+ " signal=audio,\n",
+ " samplerate=sample_rate,\n",
+ " numcep=num_cepstrum)\n",
+ "\n",
+ " # We only keep every second feature (BiRNN stride = 2):\n",
+ " features = features[::2]\n",
+ "\n",
+ " # One stride per time step in the input:\n",
+ " num_strides = len(features)\n",
+ "\n",
+ " # Add empty initial and final contexts:\n",
+ " empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)\n",
+ " features = np.concatenate((empty_context, features, empty_context))\n",
+ "\n",
+ " # Create a view into the array with overlapping strides of size\n",
+ " # numcontext (past) + 1 (present) + numcontext (future):\n",
+ " window_size = 2 * num_context + 1\n",
+ " train_inputs = np.lib.stride_tricks.as_strided(\n",
+ " features,\n",
+ " shape=(num_strides, window_size, num_cepstrum),\n",
+ " strides=(features.strides[0], features.strides[0], features.strides[1]),\n",
+ " writeable=False)\n",
+ "\n",
+ " # Flatten the second and third dimensions:\n",
+ " train_inputs = np.reshape(train_inputs, [num_strides, -1])\n",
+ "\n",
+ " train_inputs = np.copy(train_inputs)\n",
+ " train_inputs = (train_inputs - np.mean(train_inputs)) / np.std(train_inputs)\n",
+ "\n",
+ " return train_inputs\n",
+ "\n",
+ "\n",
+ "def interpolate_features(features,\n",
+ " input_rate,\n",
+ " output_rate,\n",
+ " output_len):\n",
+ " \"\"\"\n",
+ " Interpolate DeepSpeech features.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " features : np.array\n",
+ " DeepSpeech features.\n",
+ " input_rate : int\n",
+ " input rate (FPS).\n",
+ " output_rate : int\n",
+ " Output rate (FPS).\n",
+ " output_len : int\n",
+ " Output data length.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " np.array\n",
+ " Interpolated data.\n",
+ " \"\"\"\n",
+ " input_len = features.shape[0]\n",
+ " num_features = features.shape[1]\n",
+ " input_timestamps = np.arange(input_len) / float(input_rate)\n",
+ " output_timestamps = np.arange(output_len) / float(output_rate)\n",
+ " output_features = np.zeros((output_len, num_features))\n",
+ " for feature_idx in range(num_features):\n",
+ " output_features[:, feature_idx] = np.interp(\n",
+ " x=output_timestamps,\n",
+ " xp=input_timestamps,\n",
+ " fp=features[:, feature_idx])\n",
+ " return output_features\n"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "Jq0kqup1Ogqd"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title deepspeechstore fixed\n",
+ "\n",
+ "\"\"\"\n",
+ " Routines for loading DeepSpeech model.\n",
+ "\"\"\"\n",
+ "\n",
+ "__all__ = ['get_deepspeech_model_file']\n",
+ "\n",
+ "import os\n",
+ "import zipfile\n",
+ "import logging\n",
+ "import hashlib\n",
+ "\n",
+ "\n",
+ "deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'\n",
+ "\n",
+ "\n",
+ "def get_deepspeech_model_file(local_model_store_dir_path=os.path.join(\"~\", \"/content/EAT_code/tensorflow\", \"models\")):\n",
+ " \"\"\"\n",
+ " Return location for the pretrained on local file system. This function will download from online model zoo when\n",
+ " model cannot be found or has mismatch. The root directory will be created if it doesn't exist.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " local_model_store_dir_path : str, default $TENSORFLOW_HOME/models\n",
+ " Location for keeping the model parameters.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " file_path\n",
+ " Path to the requested pretrained model file.\n",
+ " \"\"\"\n",
+ " sha1_hash = \"b90017e816572ddce84f5843f1fa21e6a377975e\"\n",
+ " file_name = \"deepspeech-0_1_0-b90017e8.pb\"\n",
+ " local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)\n",
+ " file_path = os.path.join(local_model_store_dir_path, file_name)\n",
+ " if os.path.exists(file_path):\n",
+ " if _check_sha1(file_path, sha1_hash):\n",
+ " return file_path\n",
+ " else:\n",
+ " logging.warning(\"Mismatch in the content of model file detected. Downloading again.\")\n",
+ " else:\n",
+ " logging.info(\"Model file not found. Downloading to {}.\".format(file_path))\n",
+ "\n",
+ " if not os.path.exists(local_model_store_dir_path):\n",
+ " os.makedirs(local_model_store_dir_path)\n",
+ "\n",
+ " zip_file_path = file_path + \".zip\"\n",
+ " _download(\n",
+ " url=\"{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip\".format(\n",
+ " repo_url=deepspeech_features_repo_url,\n",
+ " repo_release_tag=\"v0.0.1\",\n",
+ " file_name=file_name),\n",
+ " path=zip_file_path,\n",
+ " overwrite=False)\n",
+ " with zipfile.ZipFile(zip_file_path) as zf:\n",
+ " zf.extractall(local_model_store_dir_path)\n",
+ " os.remove(zip_file_path)\n",
+ "\n",
+ " if _check_sha1(file_path, sha1_hash):\n",
+ " return file_path\n",
+ " else:\n",
+ " raise ValueError(\"Downloaded file has different hash. Please try again.\")\n",
+ "\n",
+ "\n",
+ "def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):\n",
+ " \"\"\"\n",
+ " Download an given URL\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " url : str\n",
+ " URL to download\n",
+ " path : str, optional\n",
+ " Destination path to store downloaded file. By default stores to the\n",
+ " current directory with same name as in url.\n",
+ " overwrite : bool, optional\n",
+ " Whether to overwrite destination file if already exists.\n",
+ " sha1_hash : str, optional\n",
+ " Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified\n",
+ " but doesn't match.\n",
+ " retries : integer, default 5\n",
+ " The number of times to attempt the download in case of failure or non 200 return codes\n",
+ " verify_ssl : bool, default True\n",
+ " Verify SSL certificates.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " str\n",
+ " The file path of the downloaded file.\n",
+ " \"\"\"\n",
+ " import warnings\n",
+ " try:\n",
+ " import requests\n",
+ " except ImportError:\n",
+ " class requests_failed_to_import(object):\n",
+ " pass\n",
+ " requests = requests_failed_to_import\n",
+ "\n",
+ " if path is None:\n",
+ " fname = url.split(\"/\")[-1]\n",
+ " # Empty filenames are invalid\n",
+ " assert fname, \"Can't construct file-name from this URL. Please set the `path` option manually.\"\n",
+ " else:\n",
+ " path = os.path.expanduser(path)\n",
+ " if os.path.isdir(path):\n",
+ " fname = os.path.join(path, url.split(\"/\")[-1])\n",
+ " else:\n",
+ " fname = path\n",
+ " assert retries >= 0, \"Number of retries should be at least 0\"\n",
+ "\n",
+ " if not verify_ssl:\n",
+ " warnings.warn(\n",
+ " \"Unverified HTTPS request is being made (verify_ssl=False). \"\n",
+ " \"Adding certificate verification is strongly advised.\")\n",
+ "\n",
+ " if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):\n",
+ " dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))\n",
+ " if not os.path.exists(dirname):\n",
+ " os.makedirs(dirname)\n",
+ " while retries + 1 > 0:\n",
+ " # Disable pyling too broad Exception\n",
+ " # pylint: disable=W0703\n",
+ " try:\n",
+ " print(\"Downloading {} from {}...\".format(fname, url))\n",
+ " #r = requests.get(url, stream=True, verify=verify_ssl)\n",
+ " #if r.status_code != 200:\n",
+ " # raise RuntimeError(\"Failed downloading url {}\".format(url))\n",
+ " #with open(fname, \"wb\") as f:\n",
+ " # for chunk in r.iter_content(chunk_size=1024):\n",
+ " # if chunk: # filter out keep-alive new chunks\n",
+ " # f.write(chunk)\n",
+ " if sha1_hash and not _check_sha1(fname, sha1_hash):\n",
+ " raise UserWarning(\"File {} is downloaded but the content hash does not match.\"\n",
+ " \" The repo may be outdated or download may be incomplete. \"\n",
+ " \"If the `repo_url` is overridden, consider switching to \"\n",
+ " \"the default repo.\".format(fname))\n",
+ " break\n",
+ " except Exception as e:\n",
+ " retries -= 1\n",
+ " if retries <= 0:\n",
+ " raise e\n",
+ " else:\n",
+ " print(\"download failed, retrying, {} attempt{} left\"\n",
+ " .format(retries, \"s\" if retries > 1 else \"\"))\n",
+ "\n",
+ " return fname\n",
+ "\n",
+ "\n",
+ "def _check_sha1(filename, sha1_hash):\n",
+ " \"\"\"\n",
+ " Check whether the sha1 hash of the file content matches the expected hash.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " filename : str\n",
+ " Path to the file.\n",
+ " sha1_hash : str\n",
+ " Expected sha1 hash in hexadecimal digits.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " bool\n",
+ " Whether the file content matches the expected hash.\n",
+ " \"\"\"\n",
+ " sha1 = hashlib.sha1()\n",
+ " with open(filename, \"rb\") as f:\n",
+ " while True:\n",
+ " data = f.read(1048576)\n",
+ " if not data:\n",
+ " break\n",
+ " sha1.update(data)\n",
+ "\n",
+ " return sha1.hexdigest() == sha1_hash\n"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "BaZ1iwyuO_bl"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
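+ {
+ "cell_type": "code",
+ "source": [
+ "#@title (added) hedged usage sketch for the DeepSpeech helpers\n",
+ "# Minimal sketch, not part of the original notebook. It assumes the two cells\n",
+ "# above were run in this kernel (so conv_audios_to_deepspeech and\n",
+ "# get_deepspeech_model_file are defined) and that 'example.wav' is a wav file\n",
+ "# you provide; the audio path and output name below are placeholders.\n",
+ "run_sketch = False  # flip to True to try it\n",
+ "if run_sketch:\n",
+ "    pb_path = get_deepspeech_model_file()  # extracts/verifies the .pb zip fetched earlier\n",
+ "    conv_audios_to_deepspeech(\n",
+ "        audios=['example.wav'],\n",
+ "        out_files=['example_deepspeech.npy'],\n",
+ "        num_frames_info=[None],  # None -> frame count inferred from the audio length\n",
+ "        deepspeech_pb_path=pb_path)\n",
+ "    feats = np.load('example_deepspeech.npy')\n",
+ "    print(feats.shape)  # (num_frames, 16, 29) windows of DeepSpeech logits\n"
+ ],
+ "metadata": {
+ "id": "sketchDS01"
+ },
+ "execution_count": null,
+ "outputs": []
+ },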
+ {
+ "cell_type": "code",
+ "source": [
+ "!zip -r EAT_code /content/EAT_code"
+ ],
+ "metadata": {
+ "id": "wAFm27NITV_C"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file