{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup & Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting requirements.txt\n"
     ]
    }
   ],
   "source": [
    "%%writefile requirements.txt\n",
    "torchaudio==0.11.*\n",
    "git+https://github.com/philschmid/pyannote-audio.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -r requirements.txt --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create Custom Handler for Inference Endpoints\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting handler.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile handler.py\n",
    "from typing import  Dict\n",
    "from pyannote.audio import Pipeline\n",
    "from transformers.pipelines.audio_utils import ffmpeg_read\n",
    "import torch \n",
    "\n",
    "SAMPLE_RATE = 16000\n",
    "\n",
    "\n",
    "\n",
    "class EndpointHandler():\n",
    "    \"\"\"Inference Endpoints handler that diarizes audio with a pretrained pyannote pipeline.\"\"\"\n",
    "\n",
    "    def __init__(self, path=\"\"):\n",
    "        # load the model; `path` is accepted but unused because the pipeline\n",
    "        # is loaded by its model id\n",
    "        self.pipeline = Pipeline.from_pretrained(\"pyannote/speaker-diarization\")\n",
    "\n",
    "    def __call__(self, data: Dict[str, bytes]) -> Dict[str, list]:\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            data (:obj:`dict`):\n",
    "                request payload; \"inputs\" holds the audio file as bytes and the\n",
    "                optional \"parameters\" dict is forwarded to the pipeline as keyword\n",
    "                arguments (e.g. min_speakers=2, max_speakers=5)\n",
    "        Return:\n",
    "            A :obj:`dict` with the single key \"diarization\", mapping to a list of\n",
    "            dicts with the keys \"label\", \"start\" and \"stop\" (all values are\n",
    "            strings), one per track yielded by the diarization result\n",
    "        \"\"\"\n",
    "        # process input; use .get() instead of .pop() so the caller's request\n",
    "        # dict is not mutated and the same request can be sent again\n",
    "        inputs = data.get(\"inputs\", data)\n",
    "        parameters = data.get(\"parameters\") or {}\n",
    "\n",
    "        # prepare pyannote input: decode the bytes at SAMPLE_RATE and add a\n",
    "        # leading dimension to the waveform\n",
    "        audio_nparray = ffmpeg_read(inputs, SAMPLE_RATE)\n",
    "        audio_tensor = torch.from_numpy(audio_nparray).unsqueeze(0)\n",
    "        pyannote_input = {\"waveform\": audio_tensor, \"sample_rate\": SAMPLE_RATE}\n",
    "\n",
    "        # apply pretrained pipeline, forwarding any request parameters as kwargs\n",
    "        # (an empty dict is equivalent to calling it without extra arguments)\n",
    "        diarization = self.pipeline(pyannote_input, **parameters)\n",
    "\n",
    "        # postprocess the prediction into JSON-serializable strings\n",
    "        processed_diarization = [\n",
    "            {\"label\": str(label), \"start\": str(segment.start), \"stop\": str(segment.end)}\n",
    "            for segment, _, label in diarization.itertracks(yield_label=True)\n",
    "        ]\n",
    "\n",
    "        return {\"diarization\": processed_diarization}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test the custom handler locally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from handler import EndpointHandler\n",
    "\n",
    "# init handler: the constructor loads the pretrained pyannote pipeline (`path` is accepted but not used by it)\n",
    "my_handler = EndpointHandler(path=\".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import base64\n",
    "from PIL import Image\n",
    "from io import BytesIO\n",
    "import json\n",
    "\n",
    "# read the sample audio file as raw bytes\n",
    "with open(\"sample.wav\", \"rb\") as audio_file:\n",
    "    audio_bytes = audio_file.read()\n",
    "\n",
    "# wrap the bytes in the request payload and run the handler on it\n",
    "request = {\"inputs\": audio_bytes}\n",
    "pred = my_handler(request)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'diarization': [{'label': 'SPEAKER_01',\n",
       "   'start': '0.4978125',\n",
       "   'stop': '1.3921875'},\n",
       "  {'label': 'SPEAKER_01', 'start': '1.8984375', 'stop': '2.7590624999999998'},\n",
       "  {'label': 'SPEAKER_02', 'start': '2.9953125', 'stop': '3.5015625000000004'},\n",
       "  {'label': 'SPEAKER_01',\n",
       "   'start': '3.5690625000000002',\n",
       "   'stop': '4.311562500000001'},\n",
       "  {'label': 'SPEAKER_02', 'start': '4.6153125', 'stop': '6.7753125'},\n",
       "  {'label': 'SPEAKER_00', 'start': '7.1128125', 'stop': '7.551562500000001'},\n",
       "  {'label': 'SPEAKER_02',\n",
       "   'start': '7.551562500000001',\n",
       "   'stop': '9.475312500000001'},\n",
       "  {'label': 'SPEAKER_02',\n",
       "   'start': '9.812812500000003',\n",
       "   'stop': '10.555312500000003'},\n",
       "  {'label': 'SPEAKER_00',\n",
       "   'start': '9.863437500000003',\n",
       "   'stop': '10.420312500000001'},\n",
       "  {'label': 'SPEAKER_03', 'start': '12.411562500000002', 'stop': '15.5503125'},\n",
       "  {'label': 'SPEAKER_00', 'start': '15.786562500000002', 'stop': '16.1409375'},\n",
       "  {'label': 'SPEAKER_01', 'start': '16.1409375', 'stop': '16.1578125'},\n",
       "  {'label': 'SPEAKER_00', 'start': '17.1534375', 'stop': '17.4234375'},\n",
       "  {'label': 'SPEAKER_01', 'start': '17.7440625', 'stop': '20.3596875'},\n",
       "  {'label': 'SPEAKER_01', 'start': '20.6128125', 'stop': '20.6634375'},\n",
       "  {'label': 'SPEAKER_00', 'start': '20.6634375', 'stop': '20.8490625'},\n",
       "  {'label': 'SPEAKER_01', 'start': '20.8490625', 'stop': '20.8828125'},\n",
       "  {'label': 'SPEAKER_01', 'start': '21.1021875', 'stop': '22.1315625'},\n",
       "  {'label': 'SPEAKER_02', 'start': '22.4521875', 'stop': '22.7053125'},\n",
       "  {'label': 'SPEAKER_02', 'start': '23.2115625', 'stop': '23.4815625'},\n",
       "  {'label': 'SPEAKER_01', 'start': '23.4815625', 'stop': '24.0215625'},\n",
       "  {'label': 'SPEAKER_02', 'start': '24.3253125', 'stop': '25.5065625'},\n",
       "  {'label': 'SPEAKER_01', 'start': '25.8440625', 'stop': '27.3121875'},\n",
       "  {'label': 'SPEAKER_02', 'start': '27.3121875', 'stop': '27.4978125'},\n",
       "  {'label': 'SPEAKER_01', 'start': '29.7253125', 'stop': '29.9615625'}]}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('dev': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}