"""Gradio demo recognizing age, gender and expression from speech.

Two audeering wav2vec2 models from the Hugging Face hub are applied to the
first ``duration`` seconds of an uploaded or recorded audio file.
"""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel

import audiofile
import audresample


device = "cuda" if torch.cuda.is_available() else "cpu"
duration = 1  # limit processing of audio to the first `duration` seconds
age_gender_model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
expression_model_name = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"


class AgeGenderHead(nn.Module):
    r"""Age-gender model head.

    Maps pooled wav2vec2 hidden states to ``num_labels`` outputs via
    dropout -> dense -> tanh -> dropout -> linear projection.
    The attribute names ``dense``/``dropout``/``out_proj`` are presumably
    the parameter names stored in the hub checkpoints; renaming them would
    break weight loading.
    """

    def __init__(self, config, num_labels):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):
        x = self.dropout(features)
        x = torch.tanh(self.dense(x))
        x = self.dropout(x)
        return self.out_proj(x)


class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Age-gender recognition model.

    ``forward`` returns ``(pooled_states, age, gender)``: the wav2vec2
    hidden states mean-pooled over dimension 1, a single age output per
    example, and softmax probabilities over three gender classes
    (reported by this demo as female, male, child).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = AgeGenderHead(config, 1)
        self.gender = AgeGenderHead(config, 3)
        # post_init() runs init_weights() plus the remaining setup that
        # transformers expects at the end of a model constructor
        self.post_init()

    def forward(
        self,
        input_values,
    ):
        outputs = self.wav2vec2(input_values)
        hidden_states = torch.mean(outputs[0], dim=1)
        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
        return hidden_states, logits_age, logits_gender


class ExpressionHead(AgeGenderHead):
    r"""Expression model head with ``config.num_labels`` outputs.

    Identical architecture (and parameter names) to :class:`AgeGenderHead`.
    """

    def __init__(self, config):
        super().__init__(config, config.num_labels)


class ExpressionModel(Wav2Vec2PreTrainedModel):
    r"""Speech expression model.

    ``forward`` returns ``(pooled_states, logits)``; this demo reads the
    three logits as arousal, dominance and valence.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ExpressionHead(config)
        self.post_init()

    def forward(self, input_values):
        outputs = self.wav2vec2(input_values)
        hidden_states = torch.mean(outputs[0], dim=1)
        logits = self.classifier(hidden_states)
        return hidden_states, logits


# Load models from hub and put them on the same device the inputs are
# sent to; otherwise weights (CPU) and inputs (GPU) would not match
age_gender_processor = Wav2Vec2Processor.from_pretrained(age_gender_model_name)
age_gender_model = (
    AgeGenderModel.from_pretrained(age_gender_model_name).to(device).eval()
)
expression_processor = Wav2Vec2Processor.from_pretrained(expression_model_name)
expression_model = (
    ExpressionModel.from_pretrained(expression_model_name).to(device).eval()
)


def _to_model_input(processor, x, sampling_rate):
    r"""Normalize ``x`` with ``processor``; return a (1, samples) tensor."""
    y = processor(x, sampling_rate=sampling_rate)
    # the processor always returns a batch, so we just get the first entry
    y = y["input_values"][0].reshape(1, -1)
    return torch.from_numpy(y).to(device)


def process_func(x: np.ndarray, sampling_rate: int) -> tuple:
    r"""Predict age, gender and expression from a raw audio signal.

    Args:
        x: audio signal sampled at ``sampling_rate``
        sampling_rate: sampling rate of ``x`` in Hz

    Returns:
        tuple of the age as text (e.g. ``"35 years"``), a dict with the
        probabilities for ``"female"``, ``"male"`` and ``"child"``, and the
        path of a PNG visualizing arousal, dominance and valence

    """
    with torch.no_grad():
        _, age, gender = age_gender_model(
            _to_model_input(age_gender_processor, x, sampling_rate)
        )
        _, expression = expression_model(
            _to_model_input(expression_processor, x, sampling_rate)
        )

    # keep the single batch entry as plain Python floats
    age = float(age[0, 0].cpu())
    female, male, child = gender[0].cpu().numpy().tolist()
    arousal, dominance, valence = expression[0].cpu().numpy().tolist()

    expression_file = "expression.png"
    fig = plot_expression(arousal, dominance, valence)
    try:
        fig.savefig(expression_file)
    finally:
        # pyplot keeps every open figure alive; close it to avoid a leak
        # in the long-running server
        plt.close(fig)

    return (
        f"{round(100 * age)} years",  # age output is scaled by 100
        {
            "female": female,
            "male": male,
            "child": child,
        },
        expression_file,
    )


@spaces.GPU
def recognize(input_file):
    r"""Gradio callback: analyze the audio file at ``input_file``.

    Raises:
        gr.Error: if no audio file was submitted

    """
    if input_file is None:
        raise gr.Error(
            "No audio file submitted! "
            "Please upload or record an audio file "
            "before submitting your request."
        )

    signal, sampling_rate = audiofile.read(input_file, duration=duration)
    if signal.ndim > 1:
        # multi-channel audio: mix down to mono so every channel
        # contributes (a 2-D array would be treated as a batch and only
        # its first row analysed)
        signal = signal.mean(axis=0)
    # Resample to sampling rate supported by the models
    target_rate = 16000
    signal = audresample.resample(signal, sampling_rate, target_rate)
    return process_func(signal, target_rate)


def plot_expression(arousal, dominance, valence, voxels=7):
    r"""3D pixel plot of arousal, dominance, valence.

    Highlights one voxel in a ``voxels`` x ``voxels`` x ``voxels`` grid
    whose position encodes the three values, which are assumed to lie
    roughly in [0, 1] ("low" to "high").

    Args:
        arousal: predicted arousal
        dominance: predicted dominance
        valence: predicted valence
        voxels: number of voxels per dimension

    Returns:
        the created matplotlib figure; the caller must close it

    """

    def to_cell(value):
        # Map [0, 1] linearly onto cells 0 .. voxels - 1 and clamp, so the
        # voxel stays inside the [0, voxels] axes and a value slightly
        # outside [0, 1] is drawn at the border instead of vanishing
        return min(max(round(float(value) * (voxels - 1)), 0), voxels - 1)

    x, y, z = np.indices((voxels, voxels, voxels))
    filled = (
        (x == to_cell(arousal))
        & (y == to_cell(dominance))
        & (z == to_cell(valence))
    )

    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.voxels(filled, facecolors="#fcb06c", edgecolor="k")
    ax.set_aspect("equal")

    ticks = list(range(voxels + 1))
    labels = ["low"] + [""] * (voxels - 1) + ["high"]
    axes_setup = (
        (ax.set_xlim, ax.set_xlabel, ax.set_xticks, "arousal", 45, "anchor"),
        (ax.set_ylim, ax.set_ylabel, ax.set_yticks, "dominance", -25, "anchor"),
        (ax.set_zlim, ax.set_zlabel, ax.set_zticks, "valence", 25, "default"),
    )
    for set_lim, set_label, set_ticks, name, rotation, mode in axes_setup:
        set_lim([0, voxels])
        set_label(name, fontsize="large", labelpad=0)
        set_ticks(
            ticks,
            labels=labels,
            rotation=rotation,
            rotation_mode=mode,
            verticalalignment="bottom",
        )
    return fig


description = (
    "Recognize "
    f"[age](https://huggingface.co/{age_gender_model_name}), "
    f"[gender](https://huggingface.co/{age_gender_model_name}), "
    f"and [expression](https://huggingface.co/{expression_model_name}) "
    "of an audio file or microphone recording."
)

with gr.Blocks() as demo:
    with gr.Tab(label="Speech analysis"):
        with gr.Row():
            with gr.Column():
                gr.Markdown(description)
                input_audio = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Audio input",
                )
                gr.Markdown("Only the first second of the audio is processed.")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_age = gr.Textbox(label="Age")
                output_gender = gr.Label(label="Gender")
                output_expression = gr.Image(label="Expression")

        outputs = [output_age, output_gender, output_expression]
        submit_btn.click(recognize, input_audio, outputs)

demo.launch(debug=True)