from model import Wav2VecModel
from dataset import S2IDataset, collate_fn

import requests
requests.packages.urllib3.disable_warnings()

import gradio as gr

import torch
import torch.nn as nn
import torchaudio
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
# SEED
SEED = 100
pl.seed_everything(SEED)
torch.manual_seed(SEED)

import os
os.environ['WANDB_MODE'] = 'online'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
class LightningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = Wav2VecModel()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return [optimizer]

    def loss_fn(self, prediction, targets):
        return nn.CrossEntropyLoss()(prediction, targets)
    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y.view(-1)
        logits = self(x)
        probs = F.softmax(logits, dim=1)  # softmax probabilities (not needed for the loss below)
        loss = self.loss_fn(logits, y)
        # accuracy: fraction of utterances whose argmax class matches the label
        winners = logits.argmax(dim=1)
        corrects = (winners == y)
        acc = corrects.sum().float() / float(logits.size(0))
        self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        torch.cuda.empty_cache()
        return {
            'loss': loss,
            'acc': acc
        }
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y.view(-1)
        logits = self(x)
        loss = self.loss_fn(logits, y)
        winners = logits.argmax(dim=1)
        corrects = (winners == y)
        acc = corrects.sum().float() / float(logits.size(0))
        self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return {
            'val_loss': loss,
            'val_acc': acc,
        }
    def test_step(self, batch, batch_idx):
        x, y = batch
        y = y.view(-1)
        logits = self(x)
        loss = self.loss_fn(logits, y)
        winners = logits.argmax(dim=1)
        corrects = (winners == y)
        acc = corrects.sum().float() / float(logits.size(0))
        # log under test/* so test metrics are not mixed up with validation metrics
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return {
            'test_loss': loss,
            'test_acc': acc,
        }
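
    # Inference helper used by the Gradio callback: runs a forward pass without gradients
    # and returns the index of the highest-scoring intent class.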
    def predict(self, wav):
        self.eval()
        with torch.no_grad():
            output = self.forward(wav)
            predicted_class = torch.argmax(output, dim=1)
        return predicted_class
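
# Inference setup: build the model and restore the trained checkpoint.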
if torch.cuda.is_available():
    print(torch.cuda.mem_get_info())

model = LightningModel()
run_name = "wav2vec"
checkpoint_path = "./wav2vec-epoch=epoch=4.ckpt.ckpt"
# map_location lets the checkpoint load on CPU-only hosts as well
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
model.load_state_dict(checkpoint['state_dict'])

trainer = Trainer(
    gpus=1
)

# trainer.fit(model, train_dataloader=trainloader, val_dataloaders=valloader)
# trainer.test(model, dataloaders=testloader, verbose=True)

# with torch.no_grad():
#     y_hat = model(wav_tensor)
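
# The intent model follows the Wav2Vec 2.0 convention of 16 kHz input, so microphone audio
# is resampled before inference. This helper and the 16 kHz target rate are assumptions for
# the demo, not part of model.py; torchaudio.functional.resample does the actual
# sample-rate conversion.
def resample(wav, orig_sr, target_sr=16000):
    if orig_sr != target_sr:
        wav = torchaudio.functional.resample(wav, orig_freq=orig_sr, new_freq=target_sr)
    return wav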

def transcribe(audio):
    # gr.Audio(source="microphone") passes (sample_rate, numpy waveform) to the callback
    sample_rate, wav = audio
    wav_tensor = torch.tensor(wav, dtype=torch.float32)
    if wav_tensor.dim() > 1:
        wav_tensor = wav_tensor.mean(dim=1)   # stereo -> mono
    if wav_tensor.abs().max() > 1.0:
        wav_tensor = wav_tensor / 32768.0     # int16 samples -> [-1, 1] floats
    wav_tensor = resample(wav_tensor, sample_rate)
    wav_tensor = wav_tensor.unsqueeze(0)      # assumes the model expects a (batch, time) tensor
    # model = model.to('cuda')
    y_hat = model.predict(wav_tensor)
    print(y_hat)
    labels = {0: "branch_address : enquiry about bank branch location",
              1: "activate_card : enquiry about activating card products",
              2: "past_transactions : enquiry about past transactions in a specific time period",
              3: "dispatch_status : enquiry about the dispatch status of card products",
              4: "outstanding_balance : enquiry about outstanding balance on card products",
              5: "card_issue : report about an issue with using card products",
              6: "ifsc_code : enquiry about IFSC code of bank branch",
              7: "generate_pin : enquiry about changing or generating a new pin for their card product",
              8: "unauthorised_transaction : report about an unauthorised or fraudulent transaction",
              9: "loan_query : enquiry about different kinds of loans",
              10: "balance_enquiry : enquiry about bank account balance",
              11: "change_limit : enquiry about changing the limit for card products",
              12: "block : enquiry about blocking card or banking product",
              13: "lost : report about losing a card product"}
    return labels[y_hat.item()]

get_intent = gr.Interface(fn=transcribe,
                          inputs=gr.Audio(source="microphone"),
                          outputs="text")
get_intent.launch()