# PyTorch for deep learning operations
import torch
import torch.nn as nn
# PyTorch data loading and utilities
import torch.multiprocessing
import torchaudio
from transformers import AutoProcessor, Wav2Vec2Model
import torchaudio.transforms as transforms
from huggingface_hub import PyTorchModelHubMixin
from configs import CFG
from text_image import OneEncoder as TextImageEncoder
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Audio
class AlignmentLayer(nn.Module):
    def __init__(self, input_dim=768, projection_dim=CFG.projection_dim, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        super(AlignmentLayer, self).__init__(*args, **kwargs)
        # Attributes
        self.input_dim = input_dim
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        # Layers
        self.linear_layer1 = nn.Linear(self.input_dim, self.projection_dim)
        self.gelu = nn.GELU()
        self.linear_layer2 = nn.Linear(self.projection_dim, self.projection_dim)
        self.dropout = nn.Dropout(self.dropout_rate)
        self.normalization_layer = nn.LayerNorm(self.projection_dim)

    def forward(self, inputs):
        x = inputs
        x = self.linear_layer1(x)
        x = self.gelu(x)
        x = self.linear_layer2(x)
        x = self.dropout(x)
        x = self.normalization_layer(x)
        return x

    def __call__(self, inputs):
        return self.forward(inputs)
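
# Quick shape check for AlignmentLayer (a sketch; assumes CFG.projection_dim and
# CFG.dropout_rate are defined in configs.py, e.g. 256 and 0.1):
#   layer = AlignmentLayer(input_dim=768)
#   out = layer(torch.randn(2, 49, 768))
#   print(out.shape)  # torch.Size([2, 49, CFG.projection_dim])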
class AudioEncoder(nn.Module):
    def __init__(self, model_name=CFG.audio_name, projection_dim=CFG.projection_dim,
                 trainable=False, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        super(AudioEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.model_name = model_name
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        self.trainable = trainable
        # Models
        self.pretrained_encoder = Wav2Vec2Model.from_pretrained(self.model_name)
        self.alignment_layer = AlignmentLayer(
            input_dim=self.pretrained_encoder.config.hidden_size,
            projection_dim=self.projection_dim,
            dropout_rate=self.dropout_rate)
        # Freeze the Wav2Vec2 backbone unless trainable=True
        for parameter in self.pretrained_encoder.parameters():
            parameter.requires_grad = self.trainable
        # Unfreeze layers that are newly initialized rather than loaded from the pretrained checkpoint
        newly_initialized_layers = [
            'encoder.pos_conv_embed.conv.parametrizations.weight.original0',
            'encoder.pos_conv_embed.conv.parametrizations.weight.original1',
            'masked_spec_embed'
        ]
        for name, param in self.pretrained_encoder.named_parameters():
            if any(layer_name in name for layer_name in newly_initialized_layers):
                param.requires_grad = True

    def forward(self, inputs):
        x = self.pretrained_encoder(inputs['input_values'].float()).last_hidden_state
        x = self.alignment_layer(x)
        return x

    def __call__(self, inputs):
        return self.forward(inputs)
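
# Usage sketch for AudioEncoder (assumes CFG.audio_name names a Wav2Vec2 checkpoint
# such as "facebook/wav2vec2-base-960h"; the 16000-sample batch is illustrative):
#   encoder = AudioEncoder()
#   batch = {"input_values": torch.randn(2, 16000)}
#   features = encoder(batch)  # torch.Size([2, num_frames, CFG.projection_dim])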
class ModalityTokenEncoder(nn.Module):
    def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', token_dim=CFG.token_dim, *args, **kwargs):
        super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.projection_dim = projection_dim
        self.device = device
        self.token_size = token_size
        self.token_dim = token_dim
        # Models
        # Random standard deviation in [0.1, 0.6) used to initialize the learnable audio modality token
        audio_variance = torch.rand(1) * 0.5 + 0.1
        self.audio_token = nn.Parameter(torch.normal(mean=0, std=audio_variance.item(),
                                                     size=(self.token_size, self.token_dim)).to(self.device))
        self.token_projection = nn.Sequential(
            nn.Linear(self.token_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, self.projection_dim),
            nn.LayerNorm(self.projection_dim)
        )

    def forward(self):
        return self.token_projection(self.audio_token)

    def __call__(self):
        return self.forward()
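
# Usage sketch for ModalityTokenEncoder (CFG.token_size and CFG.token_dim come from configs.py):
#   token_encoder = ModalityTokenEncoder()
#   modality_token = token_encoder()  # torch.Size([CFG.token_size, CFG.projection_dim])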
class OneEncoder(nn.Module, PyTorchModelHubMixin):
    def __init__(self, device='cpu', modality_token_encoder=ModalityTokenEncoder(), checkpoint="bilalfaye/OneEncoder-text-image",
                 audio_processor=AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h"),
                 sample_rate=CFG.sample_rate, audio_encoder=AudioEncoder(), *args, **kwargs):
        super(OneEncoder, self).__init__(*args, **kwargs)
        self.device = device
        self.checkpoint = checkpoint
        self.modality_token_encoder = modality_token_encoder
        self.modality_token_encoder.device = self.device
        self.text_image_encoder = TextImageEncoder(device=self.device)
        self.text_image_encoder.from_pretrained(self.checkpoint)
        self.audio_processor = audio_processor
        self.sample_rate = sample_rate
        self.audio_encoder = audio_encoder
        self.temperature = nn.Parameter(torch.tensor(0.07).to(self.device))
        # Freeze the pretrained text-image encoder
        for parameter in self.text_image_encoder.parameters():
            parameter.requires_grad = False
    def load_audio(self, audio_path):
        waveform, original_sample_rate = torchaudio.load(audio_path)
        # Resample if the audio is not already at the target sample rate
        if original_sample_rate != self.sample_rate:
            resampler = transforms.Resample(orig_freq=original_sample_rate, new_freq=self.sample_rate)
            waveform = resampler(waveform)
        # Mono sound -> output shape: torch.Size([1, num_samples])
        # Stereo sound -> output shape: torch.Size([2, num_samples])
        # Surround sound -> output shape: torch.Size([n, num_samples])
        return waveform
    def process_audio(self, audios):
        # audios: list of 1-D numpy arrays (one waveform per example)
        # Clips longer than 15 seconds are truncated; padding is to the longest clip in the batch
        x = self.audio_processor(audios, sampling_rate=self.sample_rate, return_tensors="pt", padding=True, max_length=15*self.sample_rate, truncation=True)
        #x = self.audio_processor(audios, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
        return x
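
    # Usage sketch (the .wav path is a placeholder; load_audio returns (channels, samples),
    # so average to mono before passing a list of 1-D numpy arrays to process_audio):
    #   model = OneEncoder(device="cpu")
    #   wav = model.load_audio("example.wav").mean(dim=0)
    #   audios = model.process_audio([wav.numpy()])  # dict with "input_values" of shape (1, num_samples)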
    def encode_audio(self, audios):
        # audios: processor output with "input_values" of shape (batch, num_samples)
        audio_embeddings = self.audio_encoder(audios.to(self.device))
        modality_token = self.modality_token_encoder()
        audio_features = self.text_image_encoder.universal_projection_encoder([audio_embeddings, modality_token]).last_hidden_state
        return audio_features.float()
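
    # Usage sketch (continues from the processed batch above):
    #   audio_features = model.encode_audio(audios)  # (batch, seq_len, projection_dim)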
    def matching_image_audio(self, audios, image_paths=None, image_tensors=None,
                             normalize=True, top_k=None, strategy="similarity", temperature=0.0):
        # audios: processor output with "input_values" of shape (N, num_samples)
        wav_features = torch.mean(self.encode_audio(audios), dim=1)
        image_features = self.text_image_encoder.encode_image(image_paths=image_paths, image_tensors=image_tensors)
        if normalize:
            image_features = F.normalize(image_features, p=2, dim=-1)
            wav_features = F.normalize(wav_features, p=2, dim=-1)
        dot_similarities = (image_features @ wav_features.T) * torch.exp(torch.tensor(temperature).to(self.device))
        if strategy == 'softmax':
            dot_similarities = (float(audios["input_values"].shape[0]) * dot_similarities).softmax(dim=-1)
        if top_k is not None:
            top_probs, top_labels = dot_similarities.cpu().topk(top_k, dim=-1)
            return top_probs, top_labels
        else:
            return dot_similarities, None
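
    # Usage sketch (image paths are placeholders; with strategy="softmax" each row of the
    # returned matrix is a probability distribution over the audio clips):
    #   probs, labels = model.matching_image_audio(audios, image_paths=["a.jpg", "b.jpg"],
    #                                              strategy="softmax", top_k=1)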
    def matching_text_audio(self, audios, texts, normalize=True, top_k=None, strategy="similarity", temperature=0.0):
        # audios: processor output with "input_values" of shape (N, num_samples)
        wav_features = torch.mean(self.encode_audio(audios), dim=1)
        text_features = self.text_image_encoder.encode_text(texts=texts)
        if normalize:
            text_features = F.normalize(text_features, p=2, dim=-1)
            wav_features = F.normalize(wav_features, p=2, dim=-1)
        dot_similarities = (text_features @ wav_features.T) * torch.exp(torch.tensor(temperature).to(self.device))
        if strategy == 'softmax':
            dot_similarities = (float(audios["input_values"].shape[0]) * dot_similarities).softmax(dim=-1)
        if top_k is not None:
            top_probs, top_labels = dot_similarities.cpu().topk(top_k, dim=-1)
            return top_probs, top_labels
        else:
            return dot_similarities, None
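
    # Usage sketch:
    #   sims, _ = model.matching_text_audio(audios, texts=["a dog barking", "a piano melody"])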
    def image_retrieval(self, query, image_paths, image_embeddings=None, temperature=0.0, n=9, plot=False, display_audio=False):
        # query: processor output with "input_values" of shape (1, num_samples)
        wav_embeddings = torch.mean(self.encode_audio(audios=query), dim=1)
        if image_embeddings is None:
            image_embeddings = self.text_image_encoder.encode_image(image_paths=image_paths)
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        wav_embeddings_n = F.normalize(wav_embeddings, p=2, dim=-1)
        dot_similarity = (wav_embeddings_n @ image_embeddings_n.T) * torch.exp(
            torch.tensor(temperature).to(self.device))
        if n > len(image_paths):
            n = len(image_paths)
        values, indices = torch.topk(dot_similarity.cpu().squeeze(0), n)
        if plot:
            nrows = int(np.sqrt(n))
            ncols = int(np.ceil(n / nrows))
            matches = [image_paths[idx] for idx in indices]
            fig, axes = plt.subplots(nrows, ncols, figsize=(20, 20))
            for match, ax in zip(matches, axes.flatten()):
                image = self.text_image_encoder.load_image(f"{match}")
                ax.imshow(image)
                ax.axis("off")
            plt.savefig("img.png")
            #if display_audio:
            #    fig.suptitle(display(Audio(query['input_values'], rate=self.sample_rate)))
            #plt.show()
        #return values, indices
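
# End-to-end retrieval sketch (paths are placeholders; the text-image checkpoint
# "bilalfaye/OneEncoder-text-image" is the default used above):
#   model = OneEncoder(device="cpu")
#   wav = model.load_audio("query.wav").mean(dim=0)
#   query = model.process_audio([wav.numpy()])
#   model.image_retrieval(query, image_paths=["img1.jpg", "img2.jpg", "img3.jpg"],
#                         n=3, plot=True)  # saves the top matches to img.png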