Spaces:
Runtime error
Runtime error
Pavankalyan
commited on
Commit
·
fce051b
1
Parent(s):
c979e2a
Update app.py
Browse files
app.py
CHANGED
@@ -102,49 +102,48 @@ class LightningModel(pl.LightningModule):
|
|
102 |
predicted_class = torch.argmax(output, dim=1)
|
103 |
return predicted_class
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
102 |
predicted_class = torch.argmax(output, dim=1)
|
103 |
return predicted_class
|
104 |
|
105 |
+
|
106 |
+
print(torch.cuda.mem_get_info())
|
107 |
+
|
108 |
+
model = LightningModel()
|
109 |
+
|
110 |
+
run_name = "wav2vec"
|
111 |
+
|
112 |
+
checkpoint_path = "./wav2vec-epoch=epoch=4.ckpt.ckpt"
|
113 |
+
checkpoint = torch.load(checkpoint_path)
|
114 |
+
model.load_state_dict(checkpoint['state_dict'])
|
115 |
+
trainer = Trainer(
|
116 |
+
gpus=1
|
117 |
+
)
|
118 |
+
|
119 |
+
#trainer.fit(model, train_dataloader=trainloader, val_dataloaders=valloader)
|
120 |
+
#trainer.test(model,dataloaders=testloader,verbose=True)
|
121 |
+
|
122 |
+
#with torch.no_grad():
|
123 |
+
# y_hat = model(wav_tensor)
|
124 |
+
|
125 |
+
def trabscribe(audio):
|
126 |
+
wav_tensor,_ = audio
|
127 |
+
wav_tensor = resmaple(wav_tensor)
|
128 |
+
#model = model.to('cuda')
|
129 |
+
y_hat = model.predict(wav_tensor)
|
130 |
+
labels = {0:"branch_address : enquiry about bank branch location",
|
131 |
+
1:"activate_card : enquiry about activating card products",
|
132 |
+
2:"past_transactions : enquiry about past transactions in a specific time period",
|
133 |
+
3:"dispatch_status : enquiry about the dispatch status of card products",
|
134 |
+
4:"outstanding_balance : enquiry about outstanding balance on card products",
|
135 |
+
5:"card_issue : report about an issue with using card products",
|
136 |
+
6:"ifsc_code : enquiry about IFSC code of bank branch",
|
137 |
+
7:"generate_pin : enquiry about changing or generating a new pin for their card product",
|
138 |
+
8:"unauthorised_transaction : report about an unauthorised or fraudulent transaction",
|
139 |
+
9:"loan_query : enquiry about different kinds of loans",
|
140 |
+
10:"balance_enquiry : enquiry about bank account balance",
|
141 |
+
11:"change_limit : enquiry about changing the limit for card products",
|
142 |
+
12:"block : enquiry about blocking card or banking product",
|
143 |
+
13:"lost : report about losing a card product}
|
144 |
+
return labels[y_hat]
|
145 |
+
|
146 |
+
print(y_hat)
|
147 |
+
get_intent = gr.Interface(fn = transcribe,
|
148 |
+
gr.Audio(source="microphone"), outputs="text").launch()
|
149 |
+
|
|