unknown committed on
Commit 47ce5f0 · 1 Parent(s): bca6517

ADD application file

Files changed (2)
  1. app.py +91 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,91 @@
+ import gradio as gr
+ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
+ import torch
+ import torchaudio
+
+ # Load the model from the Hugging Face Model Hub
+ # (repo ID plus the subfolder that holds the fine-tuned checkpoint)
+ model_name = "Mrkomiljon/voiceGUARD"
+ subfolder = "wav2vec2_finetuned_model"
+ model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name, subfolder=subfolder)
+ processor = Wav2Vec2Processor.from_pretrained(model_name, subfolder=subfolder)
+ model.eval()
+
+ # Device setup
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ # Define label mapping
+ id2label = {
+     0: "diffwave",
+     1: "melgan",
+     2: "parallel_wave_gan",
+     3: "Real",
+     4: "wavegrad",
+     5: "wavnet",
+     6: "wavernn"
+ }
+
+ # Define the prediction function
+ def predict_audio(file_path):
+     target_sample_rate = 16000  # Model's expected sample rate
+     max_length = target_sample_rate * 10  # 10 seconds in samples
+
+     try:
+         # Load the audio file from the path Gradio passes in
+         waveform, sample_rate = torchaudio.load(file_path)
+
+         # Resample if the sample rate doesn't match the model's expected rate
+         if sample_rate != target_sample_rate:
+             resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
+             waveform = resampler(waveform)
+
+         # Truncate or pad the waveform to ensure consistent input length
+         if waveform.size(1) > max_length:
+             waveform = waveform[:, :max_length]  # Truncate
+         elif waveform.size(1) < max_length:
+             waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.size(1)))  # Pad
+
+         # Keep only the first channel of multi-channel audio
+         if waveform.ndim > 1:
+             waveform = waveform[0]
+
+         # Process the audio file
+         inputs = processor(
+             waveform.numpy(),
+             sampling_rate=target_sample_rate,
+             return_tensors="pt",
+             padding=True
+         )
+         input_values = inputs["input_values"].to(device)
+
+         # Perform inference
+         with torch.no_grad():
+             logits = model(input_values).logits
+         probabilities = torch.nn.functional.softmax(logits, dim=-1)
+         # Take the top class and its confidence
+         predicted_label = torch.argmax(probabilities, dim=-1).item()
+         confidence = probabilities[0, predicted_label].item()
+
+         # Map label to class name
+         class_name = id2label.get(predicted_label, "Unknown Class")
+
+         return class_name, f"{round(confidence * 100, 2)}%"
+     except Exception as e:
+         return f"Error processing the audio file: {e}", ""
+
+ # Create the Gradio interface
+ iface = gr.Interface(
+     fn=predict_audio,
+     inputs=gr.Audio(type="filepath"),
+     outputs=[
+         gr.Label(label="Predicted Class"),
+         gr.Label(label="Confidence")
+     ],
+     title="Audio Classification with Wav2Vec2",
+     description="Upload an audio file to classify it into one of the predefined categories."
+ )
+
+ # Launch the Gradio app
+ if __name__ == "__main__":
+     iface.launch()
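
For a quick sanity check of the handler outside the UI, a minimal sketch is shown below; it assumes the file above is saved as app.py and that a local audio file named sample.wav (a hypothetical filename, not part of this commit) sits next to it:

```python
# Minimal smoke test for predict_audio, run from the repo root.
# "sample.wav" is a hypothetical local file, not part of this commit;
# importing app downloads the model but does not launch the UI.
from app import predict_audio

label, confidence = predict_audio("sample.wav")
print(f"Predicted: {label} ({confidence})")
```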
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gradio
+ torchaudio
+ transformers
+ torch
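
Once the dependencies are installed and the app is running (python app.py), the endpoint can also be queried programmatically. The sketch below is an assumption-laden example: gradio_client is an extra dependency not listed in requirements.txt, and the default local address and the auto-generated /predict route of gr.Interface are assumed.

```python
# Sketch of a programmatic request against the locally running app.
# Assumes a recent gradio_client (which provides handle_file) and the
# default server address; "sample.wav" is a hypothetical local file.
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860")
# The interface returns two values: predicted class and confidence.
label, confidence = client.predict(handle_file("sample.wav"), api_name="/predict")
print(label, confidence)
```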