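"""Gradio demo for age and gender recognition from speech.

Hugging Face Space built around the
audeering/wav2vec2-large-robust-24-ft-age-gender checkpoint.
"""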
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)
import audiofile
import audresample

model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
duration = 1  # seconds; limit how much of the audio is processed


class ModelHead(nn.Module):
    r"""Classification head."""

    def __init__(self, config, num_labels):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Age and gender classifier on top of wav2vec2."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = ModelHead(config, 1)  # single regression output
        self.gender = ModelHead(config, 3)  # female / male / child
        self.init_weights()

    def forward(self, input_values):
        outputs = self.wav2vec2(input_values)
        hidden_states = outputs[0]
        hidden_states = torch.mean(hidden_states, dim=1)  # pool over time
        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
        return hidden_states, logits_age, logits_gender
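
# Shape sketch (illustrative, not taken from the source): for a single
# 16 kHz signal of N samples, wav2vec2 returns hidden states of shape
# (1, T, hidden_size); mean pooling over the time axis gives
# (1, hidden_size), so logits_age has shape (1, 1) and logits_gender
# has shape (1, 3) (female, male, child).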

# load model from hub
device = 0 if torch.cuda.is_available() else "cpu"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = AgeGenderModel.from_pretrained(model_name)
model.to(device)  # input tensors are moved to the same device in process_func
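
# Assumption, not in the source: `spaces` is imported for Hugging Face
# ZeroGPU Spaces, where the inference call would typically be wrapped with
# the @spaces.GPU decorator; this script never applies it.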


def process_func(x: np.ndarray, sampling_rate: int) -> dict:
    r"""Predict age and gender from a raw audio signal."""
    # run through the processor to normalize the signal;
    # it always returns a batch, so take the first entry
    # and move it to the device
    y = processor(x, sampling_rate=sampling_rate)
    y = y["input_values"][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)

    # run through the model
    with torch.no_grad():
        y = model(y)
        y = torch.hstack([y[1], y[2]])

    # convert to numpy
    y = y.detach().cpu().numpy()

    # convert to dict
    y = {
        "age": 100 * y[0][0],  # the age head predicts age scaled to [0, 1]
        "female": y[0][1],
        "male": y[0][2],
        "child": y[0][3],
    }

    return y
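
# Usage sketch (hypothetical file name and output values, for illustration):
#   signal, fs = audiofile.read("speech.wav", duration=1)
#   signal = audresample.resample(signal, fs, 16000)
#   process_func(signal, 16000)
#   -> {'age': 31.0, 'female': 0.95, 'male': 0.03, 'child': 0.02}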


def recognize(input_file):
    # sampling_rate, signal = input_microphone
    # signal = signal.astype(np.float32, order="C") / 32768.0
    if input_file:
        signal, sampling_rate = audiofile.read(input_file, duration=duration)
    else:
        raise gr.Error(
            "No audio file submitted! "
            "Please upload or record an audio file "
            "before submitting your request."
        )
    # resample to the sampling rate supported by the model
    target_rate = 16000
    signal = audresample.resample(signal, sampling_rate, target_rate)
    age_gender = process_func(signal, target_rate)
    age = f"{round(age_gender['age'])} years"
    gender = {k: v for k, v in age_gender.items() if k != "age"}
    return age, gender


outputs = gr.Label()
title = "audEERING age and gender recognition"
description = (
    "Recognize age and gender of a microphone recording or audio file. "
    f"Demo uses the checkpoint [{model_name}](https://huggingface.co/{model_name})."
)
allow_flagging = "never"

# microphone = gr.Interface(
#     fn=recognize,
#     inputs=gr.Audio(sources="microphone", type="filepath"),
#     outputs=outputs,
#     title=title,
#     description=description,
#     allow_flagging=allow_flagging,
# )
# file = gr.Interface(
#     fn=recognize,
#     inputs=gr.Audio(sources="upload", type="filepath", label="Audio file"),
#     outputs=outputs,
#     title=title,
#     description=description,
#     allow_flagging=allow_flagging,
# )
#
# # demo = gr.TabbedInterface([microphone, file], ["Microphone", "Audio file"])
# # demo.queue().launch()
# # demo.launch()
# file.launch()


def toggle_input(choice):
    """Show either the microphone or the file widget (used only by the
    commented-out input_selection radio below)."""
    if choice == "microphone":
        return gr.update(visible=True), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=True)


with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Tab(label="Input"):
        with gr.Row():
            with gr.Column():
                # input_selection = gr.Radio(
                #     ["microphone", "file"],
                #     value="file",
                #     label="How would you like to upload your audio?",
                # )
                input_file = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Audio file",
                )
                # input_microphone = gr.Audio(
                #     sources="microphone",
                #     type="filepath",
                #     label="Microphone",
                # )
                # output_selector = gr.Dropdown(
                #     choices=["age", "gender"],
                #     label="Output",
                #     value="age",
                # )
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_age = gr.Textbox(label="Age")
                output_gender = gr.Label(label="Gender")
                # def update_output(output_selector):
                #     """Set different output types for different model outputs."""
                #     if output_selector == "gender":
                #         output = gr.Label(label="gender")
                #         return output
                # output_selector.input(update_output, output_selector, output)

    outputs = [output_age, output_gender]
    # input_selection.change(toggle_input, input_selection, inputs)
    # input_microphone.change(lambda x: x, input_microphone, outputs)
    # input_file.change(lambda x: x, input_file, outputs)
    submit_btn.click(recognize, input_file, outputs)

demo.launch(debug=True)