# OneEncoder-retriever / text_image_audio.py
# Last commit by bilalfaye: "Update text_image_audio.py" (e925821, verified)
# PyTorch for deep learning operations
import torch
import torch.nn as nn
# PyTorch data loading and utilities
import torch.multiprocessing
import torchaudio
from transformers import AutoProcessor, Wav2Vec2Model
import torchaudio.transforms as transforms
from huggingface_hub import PyTorchModelHubMixin
from configs import CFG
from text_image import OneEncoder as TextImageEncoder
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Audio
class AlignmentLayer(nn.Module):
    """Project encoder features to ``projection_dim``.

    The input goes through ``Linear -> GELU -> Linear -> Dropout -> LayerNorm``,
    all applied over the last dimension.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        projection_dim: Size of the last dimension of the output.
        dropout_rate: Dropout probability applied before the layer norm.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, input_dim=768, projection_dim=CFG.projection_dim, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attributes
        self.input_dim = input_dim
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        # Layers (attribute names are part of the state_dict keys; keep them stable)
        self.linear_layer1 = nn.Linear(self.input_dim, self.projection_dim)
        self.gelu = nn.GELU()
        self.linear_layer2 = nn.Linear(self.projection_dim, self.projection_dim)
        self.dropout = nn.Dropout(self.dropout_rate)
        self.normalization_layer = nn.LayerNorm(self.projection_dim)

    def forward(self, inputs):
        """Return the projected, normalized features for ``inputs``.

        Args:
            inputs: Tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``projection_dim``.
        """
        x = self.linear_layer1(inputs)
        x = self.gelu(x)
        x = self.linear_layer2(x)
        x = self.dropout(x)
        return self.normalization_layer(x)

    # NOTE: ``__call__`` is intentionally NOT overridden. ``nn.Module.__call__``
    # already dispatches to ``forward`` and additionally runs registered
    # forward hooks / pre-hooks, which a direct ``self.forward(...)`` call skips.
class AudioEncoder(nn.Module):
    """Wav2Vec2 backbone followed by an ``AlignmentLayer``.

    Args:
        model_name: Identifier passed to ``Wav2Vec2Model.from_pretrained``.
        projection_dim: Output size of the alignment layer.
        trainable: Value assigned to ``requires_grad`` on every backbone
            parameter (``False`` freezes the backbone).
        dropout_rate: Dropout probability used by the alignment layer.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, model_name=CFG.audio_name, projection_dim=CFG.projection_dim,
                 trainable=False, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attributes
        self.model_name = model_name
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        self.trainable = trainable
        # Models
        self.pretrained_encoder = Wav2Vec2Model.from_pretrained(self.model_name)
        self.alignment_layer = AlignmentLayer(
            input_dim=self.pretrained_encoder.config.hidden_size,
            projection_dim=self.projection_dim,
            dropout_rate=self.dropout_rate)
        # Freeze (or unfreeze) the whole Wav2Vec2 backbone in one pass.
        for parameter in self.pretrained_encoder.parameters():
            parameter.requires_grad = self.trainable
        # These parameters are always left trainable, regardless of
        # ``trainable``; the original author lists them as layers that are not
        # initialized from the checkpoint. Matching is by substring of the
        # parameter name.
        newly_initialized_layers = (
            'encoder.pos_conv_embed.conv.parametrizations.weight.original0',
            'encoder.pos_conv_embed.conv.parametrizations.weight.original1',
            'masked_spec_embed',
        )
        for name, param in self.pretrained_encoder.named_parameters():
            if any(layer_name in name for layer_name in newly_initialized_layers):
                param.requires_grad = True

    def forward(self, inputs):
        """Encode a batch of processed audio.

        Args:
            inputs: Mapping with an ``'input_values'`` tensor (cast to float
                before being fed to the backbone).

        Returns:
            The backbone's ``last_hidden_state`` passed through the alignment
            layer.
        """
        hidden_states = self.pretrained_encoder(inputs['input_values'].float()).last_hidden_state
        return self.alignment_layer(hidden_states)

    # NOTE: ``__call__`` is intentionally NOT overridden. ``nn.Module.__call__``
    # already dispatches to ``forward`` and additionally runs registered
    # forward hooks / pre-hooks, which a direct ``self.forward(...)`` call skips.
class ModalityTokenEncoder(nn.Module):
    """Learnable audio modality token and its projection MLP.

    Args:
        projection_dim: Output size of the token projection.
        token_size: Number of rows of the learnable token.
        device: Device the token tensor is moved to at construction time.
        token_dim: Number of columns of the learnable token (input size of the
            projection MLP).
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', token_dim=CFG.token_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attributes
        self.projection_dim = projection_dim
        self.device = device
        self.token_size = token_size
        self.token_dim = token_dim
        # Models
        # The value is used as a standard deviation (not a variance); it is
        # drawn uniformly from [0.1, 0.6).
        audio_std = torch.rand(1) * 0.5 + 0.1
        self.audio_token = nn.Parameter(
            torch.normal(mean=0, std=audio_std.item(),
                         size=(self.token_size, self.token_dim)).to(self.device))
        self.token_projection = nn.Sequential(
            nn.Linear(self.token_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, self.projection_dim),
            nn.LayerNorm(self.projection_dim)
        )

    def forward(self):
        """Return the audio token projected to ``projection_dim``.

        Returns:
            Tensor of shape ``(token_size, projection_dim)``.
        """
        return self.token_projection(self.audio_token)

    # NOTE: ``__call__`` is intentionally NOT overridden. ``nn.Module.__call__``
    # already dispatches to ``forward`` (with no arguments here) and
    # additionally runs registered forward hooks / pre-hooks, which a direct
    # ``self.forward()`` call skips.
class OneEncoder(nn.Module, PyTorchModelHubMixin):
    """Audio extension of the text-image OneEncoder.

    Audio is encoded by ``audio_encoder``, combined with a learned modality
    token and passed through the (frozen) text-image model's
    ``universal_projection_encoder`` so that audio, image and text features can
    be compared with dot products.

    Args:
        device: Device used for inputs and for tensors created by this class.
        modality_token_encoder: Module producing the audio modality token.
            ``None`` (default) builds a fresh ``ModalityTokenEncoder()``.
        checkpoint: Identifier passed to the text-image encoder's
            ``from_pretrained``.
        audio_processor: Callable turning raw waveforms into model inputs.
            ``None`` (default) loads
            ``AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")``.
        sample_rate: Sampling rate audio is resampled to / processed at.
        audio_encoder: Module encoding processed audio. ``None`` (default)
            builds a fresh ``AudioEncoder()``.
        *args, **kwargs: Forwarded unchanged to the parent initializers.
    """

    def __init__(self, device='cpu', modality_token_encoder=None, checkpoint="bilalfaye/OneEncoder-text-image",
                 audio_processor=None,
                 sample_rate=CFG.sample_rate, audio_encoder=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = device
        self.checkpoint = checkpoint
        # Heavy defaults are built here rather than in the signature: default
        # argument values are evaluated once at import time, which would
        # trigger model downloads on import and share one module object
        # between every OneEncoder instance.
        if modality_token_encoder is None:
            modality_token_encoder = ModalityTokenEncoder()
        self.modality_token_encoder = modality_token_encoder
        self.modality_token_encoder.device = self.device
        self.text_image_encoder = TextImageEncoder(device=self.device)
        # NOTE(review): the return value is discarded. If ``from_pretrained``
        # is the hub-mixin classmethod (which returns a NEW instance), this call
        # does not load weights into ``self.text_image_encoder`` — confirm
        # against the TextImageEncoder implementation.
        self.text_image_encoder.from_pretrained(self.checkpoint)
        if audio_processor is None:
            audio_processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
        self.audio_processor = audio_processor
        self.sample_rate = sample_rate
        if audio_encoder is None:
            audio_encoder = AudioEncoder()
        self.audio_encoder = audio_encoder
        # Registered parameter (kept for checkpoint compatibility); the
        # matching/retrieval methods below use their own ``temperature``
        # argument instead.
        self.temperature = nn.Parameter(torch.tensor(0.07).to(self.device))
        # Freeze the text-image model.
        for parameter in self.text_image_encoder.parameters():
            parameter.requires_grad = False

    def load_audio(self, audio_path):
        """Load an audio file and resample it to ``self.sample_rate`` if needed.

        Args:
            audio_path: Path (or file-like object) accepted by
                ``torchaudio.load``.

        Returns:
            Waveform tensor of shape ``(channels, samples)``: 1 channel for
            mono, 2 for stereo, n for surround.
        """
        waveform, original_sample_rate = torchaudio.load(audio_path)
        if original_sample_rate != self.sample_rate:
            resampler = transforms.Resample(orig_freq=original_sample_rate, new_freq=self.sample_rate)
            waveform = resampler(waveform)
        return waveform

    def process_audio(self, audios, max_seconds=15):
        """Run the audio processor on raw waveforms.

        Args:
            audios: List of numpy arrays (one waveform per item).
            max_seconds: Clips longer than this many seconds are truncated.
                Defaults to 15, the previously hard-coded limit.

        Returns:
            The processor output (PyTorch tensors, padded to the longest item
            in the batch).
        """
        return self.audio_processor(
            audios,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding=True,
            max_length=int(max_seconds * self.sample_rate),
            truncation=True,
        )

    def encode_audio(self, audios):
        """Encode processed audio into the shared feature space.

        Args:
            audios: Object supporting ``.to(device)`` and exposing
                ``['input_values']`` (e.g. the output of ``process_audio``).

        Returns:
            ``last_hidden_state`` of the universal projection encoder, cast to
            float.
        """
        audio_embeddings = self.audio_encoder(audios.to(self.device))
        modality_token = self.modality_token_encoder()
        audio_features = self.text_image_encoder.universal_projection_encoder(
            [audio_embeddings, modality_token]).last_hidden_state
        return audio_features.float()

    def _pooled_audio_features(self, audios):
        """Return audio features averaged over dimension 1 (one vector per clip)."""
        return torch.mean(self.encode_audio(audios), dim=1)

    def _temperature_scale(self, temperature):
        """Return ``exp(temperature)`` as a tensor on ``self.device``."""
        return torch.exp(torch.tensor(temperature).to(self.device))

    def _score_against_audio(self, other_features, wav_features, num_audios,
                             normalize, top_k, strategy, temperature):
        """Score ``other_features`` (rows) against ``wav_features`` (columns).

        Shared implementation of ``matching_image_audio`` and
        ``matching_text_audio``.
        """
        if normalize:
            other_features = F.normalize(other_features, p=2, dim=-1)
            wav_features = F.normalize(wav_features, p=2, dim=-1)
        dot_similarities = (other_features @ wav_features.T) * self._temperature_scale(temperature)
        if strategy == 'softmax':
            # Logits are multiplied by the number of audio clips before the softmax.
            dot_similarities = (float(num_audios) * dot_similarities).softmax(dim=-1)
        if top_k is not None:
            top_probs, top_labels = dot_similarities.cpu().topk(top_k, dim=-1)
            return top_probs, top_labels
        return dot_similarities, None

    def matching_image_audio(self, audios, image_paths=None, image_tensors=None,
                             normalize=True, top_k=None, strategy="similarity", temperature=0.0):
        """Score every image against every audio clip.

        Args:
            audios: Processed audio, ``{"input_values": tensor of N clips}``.
            image_paths: Image paths forwarded to the text-image encoder.
            image_tensors: Pre-loaded image tensors forwarded to the text-image
                encoder.
            normalize: L2-normalize both feature sets before the dot product.
            top_k: If given, return only the ``top_k`` best audio clips per
                image.
            strategy: ``'softmax'`` turns each row into a distribution over the
                audio clips; any other value returns the scaled dot products.
            temperature: Scores are multiplied by ``exp(temperature)``.

        Returns:
            ``(top_scores, top_indices)`` (on CPU) when ``top_k`` is set,
            otherwise ``(scores, None)`` with ``scores`` of shape
            ``(num_images, num_audios)``.
        """
        wav_features = self._pooled_audio_features(audios)
        image_features = self.text_image_encoder.encode_image(image_paths=image_paths, image_tensors=image_tensors)
        return self._score_against_audio(
            image_features, wav_features, audios["input_values"].shape[0],
            normalize, top_k, strategy, temperature)

    def matching_text_audio(self, audios, texts, normalize=True, top_k=None, strategy="similarity", temperature=0.0):
        """Score every text against every audio clip.

        Args:
            audios: Processed audio, ``{"input_values": tensor of N clips}``.
            texts: Texts forwarded to the text-image encoder.
            normalize: L2-normalize both feature sets before the dot product.
            top_k: If given, return only the ``top_k`` best audio clips per
                text.
            strategy: ``'softmax'`` turns each row into a distribution over the
                audio clips; any other value returns the scaled dot products.
            temperature: Scores are multiplied by ``exp(temperature)``.

        Returns:
            ``(top_scores, top_indices)`` (on CPU) when ``top_k`` is set,
            otherwise ``(scores, None)`` with ``scores`` of shape
            ``(num_texts, num_audios)``.
        """
        wav_features = self._pooled_audio_features(audios)
        text_features = self.text_image_encoder.encode_text(texts=texts)
        return self._score_against_audio(
            text_features, wav_features, audios["input_values"].shape[0],
            normalize, top_k, strategy, temperature)

    def image_retrieval(self, query, image_paths, image_embeddings=None, temperature=0.0, n=9, plot=False, display_audio=False):
        """Retrieve the ``n`` images most similar to a single audio query.

        Args:
            query: Processed audio for one clip,
                ``{"input_values": tensor of shape (1, dim)}``.
            image_paths: Paths of the candidate images.
            image_embeddings: Optional precomputed image embeddings; computed
                from ``image_paths`` when ``None``.
            temperature: Scores are multiplied by ``exp(temperature)``.
            n: Number of images to retrieve (capped at ``len(image_paths)``).
            plot: When true, the retrieved images are drawn on a grid and saved
                to ``img.png``.
            display_audio: Currently unused; kept for interface compatibility.

        Returns:
            ``(values, indices)``: the top-``n`` similarity scores and the
            corresponding positions in ``image_paths``, both on CPU.
        """
        wav_embeddings = self._pooled_audio_features(query)
        if image_embeddings is None:
            image_embeddings = self.text_image_encoder.encode_image(image_paths=image_paths)
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        wav_embeddings_n = F.normalize(wav_embeddings, p=2, dim=-1)
        dot_similarity = (wav_embeddings_n @ image_embeddings_n.T) * self._temperature_scale(temperature)
        n = min(n, len(image_paths))
        values, indices = torch.topk(dot_similarity.cpu().squeeze(0), n)
        # n == 0 would make the grid computation divide by zero; nothing to draw.
        if plot and n > 0:
            nrows = int(np.sqrt(n))
            ncols = int(np.ceil(n / nrows))
            matches = [image_paths[idx] for idx in indices.tolist()]
            # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 grid
            # (otherwise a bare Axes is returned and ``.flatten()`` fails).
            fig, axes = plt.subplots(nrows, ncols, figsize=(20, 20), squeeze=False)
            for match, ax in zip(matches, axes.flatten()):
                image = self.text_image_encoder.load_image(f"{match}")
                ax.imshow(image)
                ax.axis("off")
            fig.savefig("img.png")
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)
        return values, indices