diff --git "a/running_tunisian.ipynb" "b/running_tunisian.ipynb" new file mode 100644--- /dev/null +++ "b/running_tunisian.ipynb" @@ -0,0 +1,972 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import torch\n", + "import logging\n", + "import speechbrain as sb\n", + "from speechbrain.utils.distributed import run_on_main\n", + "from hyperpyyaml import load_hyperpyyaml\n", + "from pathlib import Path\n", + "import torchaudio.transforms as T\n", + "import torchaudio\n", + "import numpy as np\n", + "\n", + "from pyctcdecode import build_ctcdecoder\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "speechbrain.core - Beginning experiment!\n", + "speechbrain.core - Experiment folder: partly_frozen_splitted_wavlm/1986/\n", + "['', '', '', 'ق', 'ا', 'ع', 'د', 'ة', 'ت', 'ش', 'ي', 'ك', 'ه', 'ل', 'ح', 'ب', 'ن', 'ى', 'ر', 'ف', 'إ', 'س', 'أ', 'ض', 'ص', 'ط', 'خ', 'ج', 'ظ', 'ز', 'آ', 'ذ', 'غ', 'ث', 'ئ', 'ء', 'ؤ', 'ٱ', 'م', 'و', ' ']\n", + "41\n" + ] + } + ], + "source": [ + "hparams_file, run_opts, overrides = sb.parse_arguments([\"wavlm_partly_frozen.yaml\"])\n", + "\n", + "# If distributed_launch=True then\n", + "# create ddp_group with the right communication protocol\n", + "sb.utils.distributed.ddp_init_group(run_opts)\n", + "\n", + "with open(hparams_file) as fin:\n", + " hparams = load_hyperpyyaml(fin, overrides)\n", + "\n", + "# Create experiment directory\n", + "sb.create_experiment_directory(\n", + " experiment_directory=hparams[\"output_folder\"],\n", + " hyperparams_to_save=hparams_file,\n", + " overrides=overrides,\n", + ")\n", + "def read_labels_file(labels_file): \n", + " with open(labels_file, \"r\") as lf: \n", + " lines = lf.read().splitlines()\n", + " division = \"===\"\n", + " numbers = {}\n", + " for line in lines : \n", + " if division in line 
: \n", + " break\n", + " string, number = line.split(\"=>\")\n", + " number = int(number)\n", + " string = string[1:-2]\n", + " numbers[number] = string\n", + " return [numbers[x] for x in range(len(numbers))]\n", + "labels = read_labels_file(os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\"))\n", + "print(labels)\n", + "labels = [\"\"] + labels[1:]\n", + "print(len(labels))\n", + "\n", + "# Resamplers converting 8 kHz, 44.1 kHz and 48 kHz audio to 16 kHz\n", + "\n", + "resampler_8000 = T.Resample(8000, 16000, dtype=torch.float)\n", + "\n", + "resampler_44100 =T.Resample(44100, 16000, dtype=torch.float)\n", + "resampler_48000 =T.Resample(48000, 16000, dtype=torch.float)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "speechbrain.dataio.encoder - Load called, but CTCTextEncoder is not empty. Loaded data will overwrite everything. This is normal if there is e.g. an unk label defined at init.\n", + "pyctcdecode.decoder - Using arpa instead of binary LM file, decoder instantiation might be slow.\n", + "pyctcdecode.alphabet - Alphabet determined to be of regular style.\n", + "pyctcdecode.alphabet - Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type.
Is this correct?\n", + "speechbrain.core - Info: auto_mix_prec arg from hparam file is used\n", + "speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used\n", + "speechbrain.core - 313.4M trainable parameters in ASR\n", + "we here\n", + "speechbrain.utils.checkpoints - Loading a checkpoint from partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00\n", + "لحم ماكلة بنينة كسكروت نظيف ورخيص\n" + ] + } + ], + "source": [ + "resamplers = {\"8000\": resampler_8000, \"44100\":resampler_44100, \"48000\": resampler_48000}\n", + "def dataio_prepare(hparams):\n", + " \"\"\"This function prepares the datasets to be used in the brain class.\n", + " It also defines the data processing pipeline through user-defined functions.\"\"\"\n", + " data_folder = hparams[\"data_folder\"]\n", + "\n", + " train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(\n", + " csv_path=hparams[\"train_csv\"], replacements={\"data_root\": data_folder},\n", + " )\n", + "\n", + " if hparams[\"sorting\"] == \"ascending\":\n", + " # we sort training data to speed up training and get better results.\n", + " train_data = train_data.filtered_sorted(sort_key=\"duration\")\n", + " # when sorting do not shuffle in dataloader ! otherwise is pointless\n", + " hparams[\"train_dataloader_opts\"][\"shuffle\"] = False\n", + "\n", + " elif hparams[\"sorting\"] == \"descending\":\n", + " train_data = train_data.filtered_sorted(\n", + " sort_key=\"duration\", reverse=True\n", + " )\n", + " # when sorting do not shuffle in dataloader ! 
otherwise is pointless\n", + " hparams[\"train_dataloader_opts\"][\"shuffle\"] = False\n", + "\n", + " elif hparams[\"sorting\"] == \"random\":\n", + " pass\n", + "\n", + " else:\n", + " raise NotImplementedError(\n", + " \"sorting must be random, ascending or descending\"\n", + " )\n", + "\n", + " valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(\n", + " csv_path=hparams[\"valid_csv\"], replacements={\"data_root\": data_folder},\n", + " )\n", + " valid_data = valid_data.filtered_sorted(sort_key=\"duration\")\n", + "\n", + " # test is separate\n", + " test_datasets = {}\n", + " for csv_file in hparams[\"test_csv\"]:\n", + " name = Path(csv_file).stem\n", + " test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(\n", + " csv_path=csv_file, replacements={\"data_root\": data_folder}\n", + " )\n", + " test_datasets[name] = test_datasets[name].filtered_sorted(\n", + " sort_key=\"duration\"\n", + " )\n", + "\n", + " datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]\n", + "\n", + " # 2. Define audio pipeline:\n", + " @sb.utils.data_pipeline.takes(\"wav\", \"sr\")\n", + " @sb.utils.data_pipeline.provides(\"sig\")\n", + " def audio_pipeline(wav, sr):\n", + " sig = sb.dataio.dataio.read_audio(wav)\n", + " sig = resamplers[sr](sig)\n", + " return sig\n", + "\n", + " sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)\n", + " label_encoder = sb.dataio.encoder.CTCTextEncoder()\n", + "\n", + " # 3. 
Define text pipeline:\n", + " @sb.utils.data_pipeline.takes(\"wrd\")\n", + " @sb.utils.data_pipeline.provides(\n", + " \"wrd\", \"char_list\", \"tokens_list\", \"tokens_bos\", \"tokens_eos\", \"tokens\"\n", + " )\n", + " def text_pipeline(wrd):\n", + " yield wrd\n", + " char_list = list(wrd)\n", + " yield char_list\n", + " tokens_list = label_encoder.encode_sequence(char_list)\n", + " yield tokens_list\n", + " tokens_bos = torch.LongTensor([hparams[\"bos_index\"]] + (tokens_list))\n", + " yield tokens_bos\n", + " tokens_eos = torch.LongTensor(tokens_list + [hparams[\"eos_index\"]])\n", + " yield tokens_eos\n", + " tokens = torch.LongTensor(tokens_list)\n", + " yield tokens\n", + "\n", + " sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)\n", + "\n", + " lab_enc_file = os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\")\n", + " special_labels = {\n", + " \"bos_label\": hparams[\"bos_index\"],\n", + " \"eos_label\": hparams[\"eos_index\"],\n", + " \"blank_label\": hparams[\"blank_index\"],\n", + " }\n", + " label_encoder.load_or_create(\n", + " path=lab_enc_file,\n", + " from_didatasets=[train_data],\n", + " output_key=\"char_list\",\n", + " special_labels=special_labels,\n", + " sequence_input=True,\n", + " )\n", + "\n", + " # 4. 
Set output:\n", + " sb.dataio.dataset.set_output_keys(\n", + " datasets,\n", + " [\"id\", \"sig\", \"wrd\", \"char_list\", \"tokens_bos\", \"tokens_eos\", \"tokens\"],\n", + " )\n", + " return train_data, valid_data, test_datasets, label_encoder\n", + "\n", + "\n", + "class ASR(sb.Brain):\n", + " def compute_forward(self, batch, stage):\n", + " \"\"\"Forward computations from the waveform batches to the output probabilities.\"\"\"\n", + " batch = batch.to(self.device)\n", + " wavs, wav_lens = batch.sig\n", + " print(wavs)\n", + " tokens_bos, _ = batch.tokens_bos\n", + " wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)\n", + "\n", + " # Forward pass\n", + " feats = self.modules.wav2vec2(wavs)\n", + " x = self.modules.enc(feats)\n", + " # Compute outputs\n", + " p_tokens = None\n", + " logits = self.modules.ctc_lin(x)\n", + " p_ctc = self.hparams.log_softmax(logits)\n", + " if stage != sb.Stage.TRAIN:\n", + " p_tokens = sb.decoders.ctc_greedy_decode(\n", + " p_ctc, wav_lens, blank_id=self.hparams.blank_index\n", + " )\n", + " return p_ctc, wav_lens, p_tokens\n", + " \n", + " def treat_wav(self,sig): \n", + " feats = self.modules.wav2vec2(sig.to(self.device))\n", + " x = self.modules.enc(feats)\n", + " p_tokens = None\n", + " logits = self.modules.ctc_lin(x)\n", + " p_ctc = self.hparams.log_softmax(logits)\n", + " predicted_words =[]\n", + " for logs in p_ctc: \n", + " text = decoder.decode(logs.detach().cpu().numpy())\n", + " predicted_words.append(text.split(\" \"))\n", + " return \" \".join(predicted_words[0])\n", + "\n", + "\n", + "\n", + "\n", + " def compute_objectives(self, predictions, batch, stage):\n", + " \"\"\"Computes the loss (CTC+NLL) given predictions and targets.\"\"\"\n", + "\n", + " p_ctc, wav_lens, predicted_tokens = predictions\n", + "\n", + " ids = batch.id\n", + " tokens_eos, tokens_eos_lens = batch.tokens_eos\n", + " tokens, tokens_lens = batch.tokens\n", + "\n", + " if hasattr(self.modules, \"env_corrupt\") and stage == 
sb.Stage.TRAIN:\n", + " tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)\n", + " tokens_eos_lens = torch.cat(\n", + " [tokens_eos_lens, tokens_eos_lens], dim=0\n", + " )\n", + " tokens = torch.cat([tokens, tokens], dim=0)\n", + " tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)\n", + "\n", + " loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)\n", + " loss = loss_ctc\n", + " if stage != sb.Stage.TRAIN:\n", + " # Decode token terms to words\n", + " predicted_words =[]\n", + " for logs in p_ctc: \n", + " text = decoder.decode(logs.detach().cpu().numpy())\n", + " predicted_words.append(text.split(\" \"))\n", + "\n", + " target_words = [wrd.split(\" \") for wrd in batch.wrd]\n", + " self.wer_metric.append(ids, predicted_words, target_words)\n", + " self.cer_metric.append(ids, predicted_words, target_words)\n", + "\n", + " return loss\n", + "\n", + " def fit_batch(self, batch):\n", + " \"\"\"Train the parameters given a single batch in input\"\"\"\n", + " predictions = self.compute_forward(batch, sb.Stage.TRAIN)\n", + " loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)\n", + " loss.backward()\n", + " if self.check_gradients(loss):\n", + " self.wav2vec_optimizer.step()\n", + " self.model_optimizer.step()\n", + "\n", + " self.wav2vec_optimizer.zero_grad()\n", + " self.model_optimizer.zero_grad()\n", + "\n", + " return loss.detach()\n", + "\n", + " def evaluate_batch(self, batch, stage):\n", + " \"\"\"Computations needed for validation/test batches\"\"\"\n", + " predictions = self.compute_forward(batch, stage=stage)\n", + " with torch.no_grad():\n", + " loss = self.compute_objectives(predictions, batch, stage=stage)\n", + " return loss.detach()\n", + "\n", + " def on_stage_start(self, stage, epoch):\n", + " \"\"\"Gets called at the beginning of each epoch\"\"\"\n", + " if stage != sb.Stage.TRAIN:\n", + " self.cer_metric = self.hparams.cer_computer()\n", + " self.wer_metric = self.hparams.error_rate_computer()\n", + 
"\n", + " def on_stage_end(self, stage, stage_loss, epoch):\n", + " \"\"\"Gets called at the end of an epoch.\"\"\"\n", + " # Compute/store important stats\n", + " stage_stats = {\"loss\": stage_loss}\n", + " if stage == sb.Stage.TRAIN:\n", + " self.train_stats = stage_stats\n", + " else:\n", + " stage_stats[\"CER\"] = self.cer_metric.summarize(\"error_rate\")\n", + " stage_stats[\"WER\"] = self.wer_metric.summarize(\"error_rate\")\n", + "\n", + " # Perform end-of-iteration things, like annealing, logging, etc.\n", + " if stage == sb.Stage.VALID:\n", + " old_lr_model, new_lr_model = self.hparams.lr_annealing_model(\n", + " stage_stats[\"loss\"]\n", + " )\n", + " old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(\n", + " stage_stats[\"loss\"]\n", + " )\n", + " sb.nnet.schedulers.update_learning_rate(\n", + " self.model_optimizer, new_lr_model\n", + " )\n", + " sb.nnet.schedulers.update_learning_rate(\n", + " self.wav2vec_optimizer, new_lr_wav2vec\n", + " )\n", + " self.hparams.train_logger.log_stats(\n", + " stats_meta={\n", + " \"epoch\": epoch,\n", + " \"lr_model\": old_lr_model,\n", + " \"lr_wav2vec\": old_lr_wav2vec,\n", + " },\n", + " train_stats=self.train_stats,\n", + " valid_stats=stage_stats,\n", + " )\n", + " self.checkpointer.save_and_keep_only(\n", + " meta={\"WER\": stage_stats[\"WER\"]}, min_keys=[\"WER\"],\n", + " )\n", + " elif stage == sb.Stage.TEST:\n", + " self.hparams.train_logger.log_stats(\n", + " stats_meta={\"Epoch loaded\": self.hparams.epoch_counter.current},\n", + " test_stats=stage_stats,\n", + " )\n", + " with open(self.hparams.wer_file, \"w\") as w:\n", + " self.wer_metric.write_stats(w)\n", + "\n", + " def init_optimizers(self):\n", + " \"Initializes the wav2vec2 optimizer and model optimizer\"\n", + " self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(\n", + " self.modules.wav2vec2.parameters()\n", + " )\n", + " self.model_optimizer = self.hparams.model_opt_class(\n", + " self.hparams.model.parameters()\n", + " 
)\n", + "\n", + " if self.checkpointer is not None:\n", + " self.checkpointer.add_recoverable(\n", + " \"wav2vec_opt\", self.wav2vec_optimizer\n", + " )\n", + " self.checkpointer.add_recoverable(\"modelopt\", self.model_optimizer)\n", + "\n", + "label_encoder = sb.dataio.encoder.CTCTextEncoder()\n", + "\n", + "train_data, valid_data, test_datasets, label_encoder = dataio_prepare(\n", + " hparams\n", + " )\n", + "\n", + "\n", + "# We dynamically add the tokenizer to our brain class.\n", + "# NB: This tokenizer corresponds to the one used for the LM!!\n", + "decoder = build_ctcdecoder(\n", + " labels,\n", + " kenlm_model_path=\"tunisian.arpa\", # either .arpa or .bin file\n", + " alpha=0.5, # tuned on a val set\n", + " beta=1, # tuned on a val set\n", + ")\n", + "\n", + "asr_brain = ASR(\n", + " modules=hparams[\"modules\"],\n", + " hparams=hparams,\n", + " run_opts=run_opts,\n", + " checkpointer=hparams[\"checkpointer\"],\n", + ")\n", + "asr_brain.device= \"cpu\"\n", + "asr_brain.modules.to(\"cpu\")\n", + "asr_brain.tokenizer = label_encoder\n", + "\n", + "# Testing\n", + "real = False\n", + "if real : \n", + " for k in test_datasets.keys(): # keys are test_clean, test_other etc\n", + " asr_brain.hparams.wer_file = os.path.join(\n", + " hparams[\"output_folder\"], \"wer_{}.txt\".format(k)\n", + " )\n", + " asr_brain.evaluate(\n", + " test_datasets[k], test_loader_kwargs=hparams[\"test_dataloader_opts\"]\n", + " )\n", + "\n", + "else : \n", + " print(\"we here\")\n", + " def treat_wav_file(wav, sr=48000, resamplers = resamplers,asr=asr_brain, device=\"cpu\") :\n", + " tensor_wav = torch.tensor(wav).to(device)\n", + " resampled = resamplers[sr](tensor_wav)\n", + " sentence = asr_brain.treat_wav(resampled)\n", + " return sentence\n", + " from enum import Enum, auto\n", + " class Stage(Enum):\n", + " TRAIN = auto()\n", + " VALID = auto()\n", + " TEST = auto()\n", + "\n", + " asr_brain.on_evaluate_start()\n", + " asr_brain.modules.eval()\n", + " avg_test_loss = 0.0\n", +
" with torch.no_grad():\n", + "\n", + " sig = sb.dataio.dataio.read_audio(\"samples/Salah1.wav\")\n", + " sr=48000\n", + " x=np.expand_dims(sig,0)\n", + " print(treat_wav_file(x, str(sr)))\n", + "\n", + "\n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "abfcbf3f046645e184755f4b492649b9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "AudioRecorder(audio=Audio(value=b'', format='webm'), stream=CameraStream(constraints={'audio': True, 'video': …" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from ipywebrtc import AudioRecorder, CameraStream\n", + "import torchaudio\n", + "from IPython.display import Audio\n", + "camera = CameraStream(constraints={'audio': True,'video':False})\n", + "recorder = AudioRecorder(stream=camera)\n", + "recorder\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gr.Interface(\n", + " fn=transcribe, \n", + " inputs=gr.Audio(source=\"microphone\", type=\"filepath\"), \n", + " outputs=\"text\").launch()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 
4.8828125e-04\n", + " 3.6621094e-04 -6.1035156e-05]]\n" + ] + }, + { + "ename": "TypeError", + "evalue": "expected str, bytes or os.PathLike object, not numpy.ndarray", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0msr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m48000\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexpand_dims\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtreat_wav_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36mtreat_wav_file\u001b[0;34m(wav, resamplers, asr, device)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtreat_wav_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwav\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresamplers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresamplers\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0masr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0masr_brain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m 
\u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwav\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0msig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorchaudio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwav\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mtensor_wav\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mresampled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresamplers\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor_wav\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/envs/salah/lib/python3.8/site-packages/torchaudio/backend/sox_io_backend.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(filepath, frame_offset, num_frames, normalize, channels_first, format)\u001b[0m\n\u001b[1;32m 238\u001b[0m \u001b[0mbuffer_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 239\u001b[0m )\n\u001b[0;32m--> 240\u001b[0;31m \u001b[0mfilepath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfspath\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 241\u001b[0m ret = torch.ops.torchaudio.sox_io_load_audio_file(\n\u001b[1;32m 242\u001b[0m \u001b[0mfilepath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mframe_offset\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mnum_frames\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnormalize\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchannels_first\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: expected str, bytes or os.PathLike object, not numpy.ndarray" + ] + } + ], + "source": [ + "\n", + "with open('recording.webm', 'wb') as f:\n", + " f.write(recorder.audio.value)\n", + "!ffmpeg -i recording.webm -ac 1 -f wav file.wav -y -hide_banner -loglevel panic\n", + "sig, sr = torchaudio.load(\"file.wav\")\n", + "\n", + "sig = sb.dataio.dataio.read_audio(\"file.wav\")\n", + "sr=48000\n", + "x=np.expand_dims(sig,0)\n", + "print(treat_wav_file(x, str(sr)))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on local URL: http://127.0.0.1:7861\n", + "Running on public URL: https://383249a2-2010-4891.gradio.live\n", + "\n", + "This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/audiovtx9sb1u.wav\n" + ] + } + ], + "source": [ + "import gradio as gr\n", + "def treat_wav_file(wav, resamplers = resamplers,asr=asr_brain, device=\"cpu\") :\n", + " print(wav)\n", + " sig, sr = torchaudio.load(wav)\n", + " tensor_wav = sig.to(device)\n", + " resampled = resamplers[str(sr)](tensor_wav)\n", + " sentence = asr_brain.treat_wav(resampled)\n", + " return sentence\n", + "\n", + "gr.Interface(\n", + " fn=treat_wav_file, \n", + " inputs=gr.Audio(source=\"microphone\", type=\"filepath\"), \n", + " outputs=\"text\").launch(share=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on local URL: http://127.0.0.1:7871\n", + "\n", + "To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/Salah9c_xj74dq.wav\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/audiowh6q1wxw.wav\n" + ] + } + ], + "source": [ + "import gradio as gr\n", + "def treat_wav_file(file_mic, file_upload, resamplers = resamplers,asr=asr_brain, device=\"cpu\") :\n", + " \n", + " if (file_mic is not None) and (file_upload is not None):\n", + " warn_output = \"WARNING: You've uploaded an audio file and used the microphone. 
The recorded file from the microphone will be used and the uploaded audio will be discarded.\\n\"\n", + " wav = file_mic\n", + " elif (file_mic is None) and (file_upload is None):\n", + " return \"ERROR: You have to either use the microphone or upload an audio file\"\n", + " elif file_mic is not None:\n", + " wav = file_mic\n", + " else:\n", + " wav = file_upload\n", + " print(wav)\n", + " sig, sr = torchaudio.load(wav)\n", + " tensor_wav = sig.to(device)\n", + " resampled = resamplers[str(sr)](tensor_wav)\n", + " sentence = asr_brain.treat_wav(resampled)\n", + " return sentence\n", + "\n", + "gr.Interface(\n", + " fn=treat_wav_file, \n", + " inputs=[gr.inputs.Audio(source=\"microphone\", type='filepath', optional=True),\n", + " gr.inputs.Audio(source=\"upload\", type='filepath', optional=True)]\n", + " ,outputs=\"text\").launch()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "44100\n", + "torch.Size([2, 168521728])\n", + "torch.Size([168521728])\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n", + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/audiowj68eq5y.wav\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/temp_chunk5czavqiw.wav\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/audiobwixndcj.wav\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/audiouub9zh5g.wav\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/audio0wmnumvp.wav\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/audiozzjhuq_i.wav\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/audio7zdmx3nh.wav\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/audiowqco3s07.wav\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:236: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", + " warnings.warn(warning.format(data.dtype))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/audiollawbske.wav\n" + ] + } + ], + "source": [ + "import librosa\n", + "import soundfile as sf\n", + "\n", + "filein = \"../../Tunisian_data/saffi/Saffi Kalbek S02 Episode 03 30-09-2020 Partie 03.mp4\"\n", + "f, sr = torchaudio.load(filein)\n", + "print(sr)\n", + "print(f.shape)\n", + "mono_audio =torch.mean(f, dim=0)\n", + "x= np.random.randint(1000000,57000000)\n", + "print(mono_audio.shape)\n", + "chunk = mono_audio[x: x+44100*30]\n", + "torchaudio.save(\"temp_chunk.wav\", chunk.unsqueeze(0), sr)\n", + "import IPython\n", + "IPython.display.Audio(\"temp_chunk.wav\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/inputs.py:319: UserWarning: 
Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components\n", + " warnings.warn(\n", + "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/deprecation.py:40: UserWarning: `optional` parameter is deprecated, and it has no effect\n", + " warnings.warn(value)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on local URL: http://127.0.0.1:7876\n", + "Running on public URL: https://cd65009e-8daa-4c22.gradio.live\n", + "\n", + "This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/audioxbdtypb5.wav\n" + ] + } + ], + "source": [ + "import gradio as gr\n", + "def treat_wav_file(file_mic, file_upload, resamplers = resamplers,asr=asr_brain, device=\"cpu\") :\n", + " \n", + " if (file_mic is not None) and (file_upload is not None):\n", + " warn_output = \"WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used and the uploaded audio will be discarded.\\n\"\n", + " wav = file_mic\n", + " elif (file_mic is None) and (file_upload is None):\n", + " return \"ERROR: You have to either use the microphone or upload an audio file\"\n", + " elif file_mic is not None:\n", + " wav = file_mic\n", + " else:\n", + " wav = file_upload\n", + " print(wav)\n", + " sig, sr = torchaudio.load(wav)\n", + " tensor_wav = sig.to(device)\n", + " resampled = resamplers[str(sr)](tensor_wav)\n", + " sentence = asr_brain.treat_wav(resampled)\n", + " return sentence\n", + "\n", + "gr.Interface(\n", + " fn=treat_wav_file, \n", + " inputs=[gr.inputs.Audio(source=\"microphone\", type='filepath', optional=True),\n", + " gr.inputs.Audio(source=\"upload\", type='filepath', optional=True)]\n", + " ,outputs=\"text\").launch(share=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + 
"nbformat_minor": 5 +}