hagenw committed
Commit ba45a7b · 1 Parent(s): fa18e4b
Files changed (1): app.py (+9 -9)
app.py CHANGED
@@ -11,12 +11,13 @@ import audiofile
 import audresample
 
 
+device = 0 if torch.cuda.is_available() else "cpu"
 model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
 duration = 1 # limit processing of audio
 
 
-class ModelHead(nn.Module):
-    r"""Classification head."""
+class AgeGenderHead(nn.Module):
+    r"""Age-gender model head."""
 
     def __init__(self, config, num_labels):
 
@@ -39,7 +40,7 @@ class ModelHead(nn.Module):
 
 
 class AgeGenderModel(Wav2Vec2PreTrainedModel):
-    r"""Speech emotion classifier."""
+    r"""Age-gender recognition model."""
 
     def __init__(self, config):
 
@@ -47,8 +48,8 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
 
         self.config = config
         self.wav2vec2 = Wav2Vec2Model(config)
-        self.age = ModelHead(config, 1)
-        self.gender = ModelHead(config, 3)
+        self.age = AgeGenderHead(config, 1)
+        self.gender = AgeGenderHead(config, 3)
         self.init_weights()
 
     def forward(
@@ -67,7 +68,6 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
 
 
 # load model from hub
-device = 0 if torch.cuda.is_available() else "cpu"
 processor = Wav2Vec2Processor.from_pretrained(model_name)
 model = AgeGenderModel.from_pretrained(model_name)
 
@@ -105,14 +105,14 @@ def process_func(x: np.ndarray, sampling_rate: int) -> dict:
 def recognize(input_file):
     # sampling_rate, signal = input_microphone
     # signal = signal.astype(np.float32, order="C") / 32768.0
-    if input_file is not None:
-        signal, sampling_rate = audiofile.read(input_file, duration=duration)
-    else:
+    if input_file is None:
         raise gr.Error(
             "No audio file submitted! "
             "Please upload or record an audio file "
             "before submitting your request."
         )
+
+    signal, sampling_rate = audiofile.read(input_file, duration=duration)
     # Resample to sampling rate supported by the models
     target_rate = 16000
     signal = audresample.resample(signal, sampling_rate, target_rate)
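
For readers without the full file: the hunks above elide the body of the renamed AgeGenderHead and of AgeGenderModel.forward. The following is a minimal sketch of how those pieces typically fit together for this model family; the dense/dropout/out_proj head layout, the mean pooling, and the values returned by forward are assumptions based on the upstream model card, not shown in this commit.

import torch
import torch.nn as nn
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)


class AgeGenderHead(nn.Module):
    r"""Age-gender model head (layer layout assumed, elided in the diff)."""

    def __init__(self, config, num_labels):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features):
        x = self.dropout(features)
        x = torch.tanh(self.dense(x))
        x = self.dropout(x)
        return self.out_proj(x)


class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Age-gender recognition model."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = AgeGenderHead(config, 1)     # single age value
        self.gender = AgeGenderHead(config, 3)  # three gender classes
        self.init_weights()

    def forward(self, input_values):
        # Average the frame-wise hidden states, then feed both heads
        # (pooling and return values are assumptions, elided in the diff).
        hidden_states = self.wav2vec2(input_values).last_hidden_state
        pooled = torch.mean(hidden_states, dim=1)
        return self.age(pooled), self.gender(pooled)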
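
A second sketch shows the updated recognize flow end to end: read at most one second of audio with audiofile, resample to 16 kHz with audresample, then run the processor and model. The file name speech.wav, the mono handling, and the interpretation of the two outputs (age on a 0-1 scale, softmax over three gender classes) are illustrative assumptions and not part of this commit.

import audiofile
import audresample
import numpy as np
import torch
from transformers import Wav2Vec2Processor

model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
duration = 1        # limit processing of audio, as in app.py
target_rate = 16000

processor = Wav2Vec2Processor.from_pretrained(model_name)
model = AgeGenderModel.from_pretrained(model_name)  # class sketched above
model.eval()

# Read at most 1 s of audio; "speech.wav" is a placeholder path
signal, sampling_rate = audiofile.read("speech.wav", duration=duration)
signal = audresample.resample(signal, sampling_rate, target_rate)
signal = np.atleast_2d(signal)  # shape (channels, samples); keep first channel

inputs = processor(signal[0], sampling_rate=target_rate, return_tensors="pt")
with torch.no_grad():
    age_pred, gender_logits = model(inputs.input_values)

age = float(age_pred.squeeze())  # assumed to be predicted on a 0-1 scale
gender_probs = torch.softmax(gender_logits, dim=-1).squeeze().tolist()
print(age, gender_probs)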