jcho02 committed
Commit 73b065a · verified · 1 Parent(s): a6d5ae5

Update app.py

Files changed (1): app.py +46 -2
app.py CHANGED
@@ -8,14 +8,58 @@ from datasets import load_dataset, DatasetDict, Audio
 from huggingface_hub import PyTorchModelHubMixin
 import numpy as np
 
-# [Your existing code for device setup, config, SpeechInferenceDataset, SpeechClassifier]
+# Ensure you have the device setup (cuda or cpu)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Define the config for your model
+config = {"encoder": "openai/whisper-base", "num_labels": 2}
+
+# Define data class
+class SpeechInferenceDataset(Dataset):
+    def __init__(self, audio_data, text_processor):
+        self.audio_data = audio_data
+        self.text_processor = text_processor
+
+    def __len__(self):
+        return len(self.audio_data)
+
+    def __getitem__(self, index):
+        inputs = self.text_processor(self.audio_data[index]["audio"]["array"],
+                                     return_tensors="pt",
+                                     sampling_rate=self.audio_data[index]["audio"]["sampling_rate"])
+        input_features = inputs.input_features
+        decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
+        return input_features, decoder_input_ids
+
+# Define model class
+class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
+    def __init__(self, config):
+        super(SpeechClassifier, self).__init__()
+        self.encoder = WhisperModel.from_pretrained(config["encoder"])
+        self.classifier = nn.Sequential(
+            nn.Linear(self.encoder.config.hidden_size, 4096),
+            nn.ReLU(),
+            nn.Linear(4096, 2048),
+            nn.ReLU(),
+            nn.Linear(2048, 1024),
+            nn.ReLU(),
+            nn.Linear(1024, 512),
+            nn.ReLU(),
+            nn.Linear(512, config["num_labels"])
+        )
+
+    def forward(self, input_features, decoder_input_ids):
+        outputs = self.encoder(input_features, decoder_input_ids=decoder_input_ids)
+        pooled_output = outputs['last_hidden_state'][:, 0, :]
+        logits = self.classifier(pooled_output)
+        return logits
 
 # Prepare data function
 def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
     feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
     inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
     input_features = inputs.input_features
-    decoder_input_ids = torch.tensor([[1, 1]])
+    decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
     return input_features.to(device), decoder_input_ids.to(device)
 
 # Prediction function
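
For orientation, here is a minimal sketch of how the pieces added in this commit might be exercised end to end. It assumes the imports near the top of app.py (torch, torch.nn as nn, WhisperModel, WhisperFeatureExtractor, and torch.utils.data.Dataset) are already in place; the zero-filled waveform and the record layout below are illustrative stand-ins, not part of the commit.

    import numpy as np
    import torch
    from transformers import WhisperFeatureExtractor

    # Illustrative 1-second mono waveform at 16 kHz (stand-in for real audio).
    waveform = np.zeros(16000, dtype=np.float32)

    # Path A: the standalone prepare_data helper added in this commit.
    input_features, decoder_input_ids = prepare_data(waveform, sampling_rate=16000)

    # Path B: the SpeechInferenceDataset class, fed HF-datasets-style records.
    feature_extractor = WhisperFeatureExtractor.from_pretrained(config["encoder"])
    records = [{"audio": {"array": waveform, "sampling_rate": 16000}}]
    dataset = SpeechInferenceDataset(records, feature_extractor)
    input_features, decoder_input_ids = dataset[0]

    # Run the classifier head on the pooled encoder output.
    model = SpeechClassifier(config).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(input_features.to(device), decoder_input_ids.to(device))
        predicted_label = logits.argmax(dim=-1).item()
    print(predicted_label)

Both paths yield the same (input_features, decoder_input_ids) pair; the fixed torch.tensor([[1, 1]]) decoder input mirrors what the commit itself uses, with its inline comment inviting adjustment per model.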