|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import os |
|
|
|
from transformers import GenerationConfig |
|
from dataset import process_idefics_listener_generation_input |
|
import pdb |
|
|
|
def filter_targets(logits, index_to_token):
    """Keep only the logit columns belonging to the candidate target tokens.

    Args:
        logits: Tensor whose second axis indexes the vocabulary.
        index_to_token: Vocabulary ids (list or tensor) to keep, in order.

    Returns:
        ``logits[:, index_to_token]``; column ``j`` holds the logit of token
        ``index_to_token[j]``.
    """
    return logits[:, index_to_token]
|
|
|
class IdeficsJointInferenceModel(nn.Module):
    """Joint listener/speaker inference wrapper around IDEFICS-2 style models.

    The listener produces a distribution over the candidate-index tokens given
    in ``index_to_token``; the speaker produces token log-likelihoods of
    utterances. The two are combined log-linearly: ``listener_lambda`` weights
    the listener in comprehension reranking, ``speaker_lambda`` weights the
    speaker in utterance reranking and generation.

    Either one ``model`` shared by both roles is supplied, or separate
    ``listener`` and ``speaker`` models.
    """

    def __init__(self, listener_lambda, speaker_lambda,
                 model=None, listener=None, speaker=None):
        super().__init__()
        self.l_lambda = listener_lambda
        self.s_lambda = speaker_lambda

        # With a single shared model, image embeddings can be computed once and
        # reused for both the listener and the speaker.
        self.has_shared_parameters = model is not None
        if self.has_shared_parameters:
            self.model = model
        else:
            self.listener = listener
            self.speaker = speaker

    def forward(self, inf_mode, arguments):
        """Dispatch to one of the inference routines.

        Args:
            inf_mode: One of ``"joint_comprehension"``, ``"joint_reranking"``,
                ``"comprehension"``, ``"split_reranking"`` or ``"generation"``.
            arguments: Tuple of positional inputs unpacked by the selected routine.

        Raises:
            ValueError: If ``inf_mode`` is not one of the modes above.
        """
        handlers = {
            "joint_comprehension": self.comprehension_side,
            "joint_reranking": self.reranking_side,
            "comprehension": self.split_comprehension_forward,
            "split_reranking": self.split_reranking_forward,
            "generation": self.split_generation_forward,
        }
        handler = handlers.get(inf_mode)
        if handler is None:
            raise ValueError(
                f"Unknown inf_mode {inf_mode!r}; expected one of {sorted(handlers)}"
            )
        return handler(arguments)

    def get_listener(self):
        """Return the module playing the listener role."""
        if self.has_shared_parameters:
            return self.model
        else:
            return self.listener

    def get_speaker(self):
        """Return the module playing the speaker role."""
        if self.has_shared_parameters:
            return self.model
        else:
            return self.speaker

    def get_image_embeddings(self, pixel_values, pixel_attention_mask, model):
        '''
        Get image embeddings to avoid repeated computation for images during joint inference.
        Adapted from the IDEFICS-2 source code.

        Args:
            pixel_values: Tensor of shape (batch, num_images, channels, height, width).
                Images whose pixels are all zero are treated as padding and dropped.
            pixel_attention_mask: Pixel mask of shape (batch, num_images, height, width),
                or a 5-D mask, in which case only index 0 of its second axis is used.
                If None, every pixel of every real image is treated as valid.
            model: ``"listener"`` selects the listener's vision tower; any other
                value selects the speaker's.

        Returns:
            Connector outputs for the non-padding images, flattened across the
            batch (presumably (num_real_images, image_tokens, hidden) -- confirm).
        '''
        vision_owner = self.get_listener() if model == "listener" else self.get_speaker()

        batch_size, num_images, _, _, _ = pixel_values.shape
        pixel_values = pixel_values.to(dtype=vision_owner.dtype)
        pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

        # Drop padding images (all-zero pixel tensors), as the upstream model does.
        nb_values_per_image = pixel_values.shape[1:].numel()
        real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
        pixel_values = pixel_values[real_images_inds].contiguous()

        if pixel_attention_mask is None:
            pixel_attention_mask = torch.ones(
                size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
                dtype=torch.bool,
                device=pixel_values.device,
            )
        else:
            if pixel_attention_mask.dim() == 5:
                pixel_attention_mask = pixel_attention_mask[:, 0].contiguous()
            pixel_attention_mask = pixel_attention_mask.view(
                batch_size * num_images, *pixel_attention_mask.shape[2:]
            )
            pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

        # A patch is attended to when any of its pixels is valid.
        patch_size = vision_owner.model.config.vision_config.patch_size
        patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
        patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
        patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

        # NOTE(review): the `.model.model` hops presumably unwrap an adapter/PEFT
        # wrapper around the IDEFICS-2 model -- confirm against how models are built.
        image_hidden_states = vision_owner.model.model.vision_model(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        ).last_hidden_state

        image_hidden_states = vision_owner.model.model.connector(
            image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
        )
        return image_hidden_states

    @staticmethod
    def _expand_image_embeddings(image_embeddings, batch_size, num_images, repeats):
        """Tile each example's image embeddings once per candidate input.

        ``image_embeddings`` holds ``num_images`` consecutive entries per batch
        example along dim 0 (``num_images`` may be -1 to infer it). Each
        example's group is repeated ``repeats`` times, giving a tensor of shape
        (batch_size * repeats * num_images, *image_embeddings.shape[1:]) that
        lines up with inputs flattened from (batch_size, repeats, ...).
        """
        trailing = image_embeddings.shape[1:]
        grouped = image_embeddings.view(batch_size, num_images, *trailing)
        tile = [1, repeats] + [1] * (grouped.dim() - 1)
        return grouped.unsqueeze(1).repeat(*tile).view(-1, *trailing)

    @staticmethod
    def _pixel_forward(module, input_tokens, attn_mask, images, image_attn_mask):
        """Run ``module`` directly on raw pixel inputs and return its logits."""
        return module(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            pixel_values=images,
            pixel_attention_mask=image_attn_mask
        )['logits']

    def split_comprehension_side(self, input_tokens, attn_mask, images, image_attn_mask, index_to_token):
        '''
        Redundant with split_comprehension_forward except for the final computation.
        Used during deployment in ray_models.py.

        Returns log-probabilities over the ``index_to_token`` candidates,
        taken from the listener's logits at the last input position.
        '''
        all_logits = self.split_comprehension_forward(
            (input_tokens, attn_mask, images, image_attn_mask)
        )
        target_logits = filter_targets(all_logits[:, -1], index_to_token)
        return F.log_softmax(target_logits, dim=1)

    def split_comprehension_forward(self, arguments):
        """Listener forward pass on raw pixels; returns the full logits.

        ``arguments`` is ``(input_tokens, attn_mask, images, image_attn_mask)``.
        """
        input_tokens, attn_mask, images, image_attn_mask = arguments
        return self._pixel_forward(self.get_listener(), input_tokens, attn_mask,
                                   images, image_attn_mask)

    def split_generation_forward(self, arguments):
        """Speaker forward pass on raw pixels; returns the full logits.

        ``arguments`` is ``(input_tokens, attn_mask, images, image_attn_mask)``.
        """
        input_tokens, attn_mask, images, image_attn_mask = arguments
        return self._pixel_forward(self.get_speaker(), input_tokens, attn_mask,
                                   images, image_attn_mask)

    def split_reranking_forward(self, arguments):
        """Score candidate utterances with the speaker only.

        ``arguments`` is ``(images, input_tokens, attn_mask, image_attn_mask,
        target_tokens, target_mask)``; ``input_tokens`` is indexed as
        (batch, candidate, ...). Returns speaker utterance log-probabilities of
        shape (batch, candidate).
        """
        images, input_tokens, attn_mask, image_attn_mask, target_tokens, target_mask = arguments

        B, mult = input_tokens.shape[:2]
        image_embeddings = self._expand_image_embeddings(
            self.get_image_embeddings(images, image_attn_mask, "speaker"),
            B, images.shape[1], mult,
        )

        # No candidate is masked out in this mode.
        annotation_mask = torch.zeros(B, mult, device=image_embeddings.device).bool()
        _, speaker_log_probs = self.reranking_speaker_side(image_embeddings, input_tokens, attn_mask,
                                                           image_attn_mask, target_tokens, target_mask,
                                                           annotation_mask)
        return speaker_log_probs

    def comprehension_side(self, arguments):
        """Joint comprehension: combine listener and speaker scores over candidates.

        Returns ``(listener_log_probs, speaker_log_probs, joint_log_probs)``;
        the joint distribution comes from ``comprehension_reranking``.
        """
        images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token, \
            s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label = arguments

        listener_embeddings = self.get_image_embeddings(images, l_image_attn_mask, "listener")
        if self.has_shared_parameters:
            # Same vision tower for both roles: reuse the listener's embeddings.
            speaker_embeddings = listener_embeddings
        else:
            speaker_embeddings = self.get_image_embeddings(images, s_image_attn_mask, "speaker")

        listener_log_probs = self.comprehension_listener_side(
            listener_embeddings, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token
        )
        speaker_log_probs = self.comprehension_speaker_side(
            speaker_embeddings, s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label
        )

        joint_log_probs = self.comprehension_reranking(listener_log_probs, speaker_log_probs)
        return listener_log_probs, speaker_log_probs, joint_log_probs

    def comprehension_listener_side(self, image_encoder_embeddings, input_tokens, attn_mask, image_attn_mask,
                                    index_to_token):
        """Listener log-probabilities over the ``index_to_token`` candidates.

        Uses precomputed image embeddings and the logits at the last position.
        """
        listener = self.get_listener()
        all_logits = listener(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            image_hidden_states=image_encoder_embeddings,
            pixel_attention_mask=image_attn_mask
        )['logits']

        target_logits = filter_targets(all_logits[:, -1], index_to_token)
        return F.log_softmax(target_logits, dim=1)

    def comprehension_speaker_side(self, image_encoder_embeddings, input_tokens, attn_mask, image_attn_mask,
                                   target_mask, target_label):
        """Speaker log-likelihood of the utterance for each candidate image.

        ``input_tokens`` is indexed as (B, C, ...), one speaker input per
        candidate; the image embeddings (C per example) are tiled C times to
        match. ``image_attn_mask`` is accepted for interface symmetry but is not
        used. Returns a (B, C) tensor of summed, masked token log-probs.
        """
        B, C = input_tokens.shape[:2]
        image_encoder_embeddings = self._expand_image_embeddings(image_encoder_embeddings, B, C, C)
        input_tokens = input_tokens.view(B * C, -1)
        attn_mask = attn_mask.view(B * C, -1)

        speaker = self.get_speaker()
        all_logits = speaker(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            image_hidden_states=image_encoder_embeddings,
        )['logits']

        all_log_probs = F.log_softmax(all_logits, dim=2)
        target_label = target_label.view(B * C, -1).unsqueeze(2)
        target_mask = target_mask.view(B * C, -1)
        token_log_probs = torch.gather(all_log_probs, 2, target_label).squeeze(2)

        # Only positions selected by target_mask contribute to the utterance score.
        token_log_probs = token_log_probs * target_mask
        return torch.sum(token_log_probs, dim=1).view(B, C)

    def comprehension_reranking(self, listener_log_probs, speaker_log_probs):
        """Log-linear mix ``l_lambda * L + (1 - l_lambda) * S``, renormalised over dim 1."""
        rerank_weights = self.l_lambda * listener_log_probs + (1 - self.l_lambda) * speaker_log_probs
        rerank_denominator = torch.logsumexp(rerank_weights, dim=1).unsqueeze(1)
        return rerank_weights - rerank_denominator

    def reranking_side(self, arguments):
        """Joint reranking of candidate utterances by speaker and listener.

        Returns ``(speaker_logits, speaker_log_probs, listener_log_probs,
        utterance_distribution)``; the last comes from ``reranking_combination``.
        """
        images, label, s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_tokens, s_target_mask, \
            l_input_tokens, l_attn_mask, l_image_attn_mask, \
            index_to_token, annotation_mask = arguments

        B, mult = s_input_tokens.shape[:2]
        C = images.shape[1]

        speaker_embeddings = self._expand_image_embeddings(
            self.get_image_embeddings(images, s_image_attn_mask, "speaker"), B, C, mult
        )
        speaker_logits, speaker_log_probs = self.reranking_speaker_side(speaker_embeddings, s_input_tokens,
                                                                        s_attn_mask, s_image_attn_mask,
                                                                        s_target_tokens, s_target_mask,
                                                                        annotation_mask)

        if self.has_shared_parameters:
            # Same vision tower for both roles: reuse the speaker's embeddings.
            listener_embeddings = speaker_embeddings
        else:
            listener_embeddings = self._expand_image_embeddings(
                self.get_image_embeddings(images, l_image_attn_mask, "listener"), B, C, mult
            )
        listener_log_probs = self.reranking_listener_side(listener_embeddings, l_input_tokens, l_attn_mask,
                                                          l_image_attn_mask, label, index_to_token,
                                                          annotation_mask)

        utterance_distribution = self.reranking_combination(speaker_log_probs, listener_log_probs)
        return speaker_logits, speaker_log_probs, listener_log_probs, utterance_distribution

    def reranking_speaker_side(self, image_embeddings, input_tokens, attn_mask, image_attn_mask,
                               target_tokens, target_mask, annotation_mask):
        """Speaker log-likelihood of each candidate utterance.

        Inputs are indexed as (B, mult, ...). Candidates where ``annotation_mask``
        is True get a score of -inf. ``image_attn_mask`` is not used.
        Returns ``(all_logits, utterance_log_probs)`` with the latter shaped (B, mult).
        """
        B, mult = input_tokens.shape[:2]
        input_tokens = input_tokens.view(B * mult, -1)
        attn_mask = attn_mask.view(B * mult, -1)
        target_tokens = target_tokens.view(B * mult, -1).unsqueeze(-1)
        target_mask = target_mask.view(B * mult, -1)

        speaker = self.get_speaker()
        all_logits = speaker(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            image_hidden_states=image_embeddings,
        )['logits']

        all_log_probs = F.log_softmax(all_logits, dim=2)
        token_log_probs = torch.gather(all_log_probs, 2, target_tokens).squeeze(2)
        token_log_probs = token_log_probs * target_mask
        utterance_log_probs = torch.sum(token_log_probs, dim=1).view(B, mult)
        utterance_log_probs[annotation_mask] = float('-inf')

        return all_logits, utterance_log_probs

    def reranking_listener_side(self, image_embeddings, input_tokens, attn_mask, image_attn_mask,
                                label, index_to_token, annotation_mask):
        """Listener log-probability of the gold ``label`` for each candidate utterance.

        Inputs are indexed as (B, mult, ...). Candidates where ``annotation_mask``
        is True get -inf. ``image_attn_mask`` is not used. Returns a (B, mult) tensor.
        """
        B, mult = input_tokens.shape[:2]
        input_tokens = input_tokens.view(B * mult, -1)
        attn_mask = attn_mask.view(B * mult, -1)
        label = label.unsqueeze(1).repeat(1, mult).view(-1).unsqueeze(1)

        listener = self.get_listener()
        all_logits = listener(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            image_hidden_states=image_embeddings,
        )['logits']

        target_logits = filter_targets(all_logits[:, -1], index_to_token)
        listener_log_probs = F.log_softmax(target_logits, dim=1)
        utterance_log_probs = torch.gather(listener_log_probs, 1, label).squeeze(1).view(B, mult)

        utterance_log_probs[annotation_mask] = float('-inf')
        return utterance_log_probs

    def reranking_combination(self, speaker_utterance_log_probs, listener_utterance_log_probs):
        """Log-linear mix ``s_lambda * S + (1 - s_lambda) * L``, renormalised over dim 1."""
        weights = self.s_lambda * speaker_utterance_log_probs + (1 - self.s_lambda) * listener_utterance_log_probs
        rerank_denominator = torch.logsumexp(weights, dim=1).unsqueeze(1)
        return weights - rerank_denominator

    def _score_generations(self, speaker, outputs, processor):
        """Decode sampled continuations and score them under ``speaker``.

        ``outputs`` is a ``generate`` result with ``return_dict_in_generate`` and
        ``output_hidden_states`` enabled. Step ``i`` of ``hidden_states`` gives the
        last layer's final-position state that predicted generated token ``i``.

        Returns ``(sequence_log_probs, decoded)``: one summed log-prob per
        returned sequence (flat), and the decoded strings in the same order.
        """
        step_states = outputs['hidden_states']
        observed_steps = len(step_states)
        generated = [seq[-observed_steps:] for seq in outputs['sequences']]
        decoded = processor.batch_decode(generated, skip_special_tokens=True)

        target_outputs = torch.stack(generated, dim=0)
        # NOTE(review): assumes token id 0 is the padding emitted after a sequence
        # finishes -- confirm against the tokenizer's pad_token_id.
        target_mask = target_outputs != 0
        final_states = torch.stack([step[-1][:, -1] for step in step_states], dim=1)
        token_log_probs = F.log_softmax(speaker.lm_head(final_states), dim=2)
        token_log_probs = torch.gather(token_log_probs, 2, target_outputs.unsqueeze(2)).squeeze(2)
        return torch.sum(token_log_probs * target_mask, dim=1), decoded

    def split_generate(self, input_tokens, attn_mask, images, image_attn_mask, processor,
                       max_steps=25, sampling_type="nucleus", temperature=1.0,
                       top_k=40, top_p=0.9, repetition_penalty=1, num_samples=1):
        """Sample ``num_samples`` utterances per example and keep the speaker's best.

        ``sampling_type`` is accepted for interface compatibility and is not used.
        Returns a list with one decoded string per batch example.
        """
        speaker = self.get_speaker()
        generation_config = GenerationConfig(
            max_new_tokens=max_steps,
            do_sample=True,
            temperature=temperature,
            top_k=top_k, top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_return_sequences=num_samples,
            output_hidden_states=True,
            return_dict_in_generate=True
        )
        outputs = speaker.generate(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            pixel_values=images,
            pixel_attention_mask=image_attn_mask,
            generation_config=generation_config,
            use_cache=True
        )

        B = input_tokens.shape[0]
        sequence_log_probs, speaker_outputs = self._score_generations(speaker, outputs, processor)
        best_indices = torch.argmax(sequence_log_probs.view(B, num_samples), dim=1)
        # Returned sequences are grouped per example, num_samples at a time.
        return [speaker_outputs[num_samples * i + best.item()] for i, best in enumerate(best_indices)]

    def generate(self, images, s_input_tokens, s_attn_mask, s_image_attn_mask, label,
                 image_paths, processor, image_dir, index_to_token,
                 max_steps=25, sampling_type="nucleus", temperature=1.0, top_k=40,
                 top_p=0.9, repetition_penalty=1, num_samples=10):
        """Sample utterances with the speaker and rerank them with the listener.

        Returns ``(choices, speaker_utterances, listener_log_probs,
        speaker_utterance_log_probs, utterance_weights)`` where ``choices`` holds
        the best-weighted utterance per batch example.
        """
        # The listener is scored with embeddings from its own vision tower, as in
        # reranking_side; with shared parameters this is the same model.
        listener_embeddings = self.get_image_embeddings(images, s_image_attn_mask, "listener")

        speaker_utterance_log_probs, speaker_utterances = self.generate_speaker_side(processor, images, s_input_tokens,
                                                                                     s_attn_mask, s_image_attn_mask, max_steps,
                                                                                     sampling_type, temperature,
                                                                                     top_k, top_p, repetition_penalty,
                                                                                     num_samples)

        listener_log_probs = self.generate_listener_side(listener_embeddings, speaker_utterances, label, image_paths,
                                                         processor, image_dir, index_to_token, num_samples)

        utterance_weights = self.s_lambda * speaker_utterance_log_probs + (1 - self.s_lambda) * listener_log_probs
        chosen_indices = torch.argmax(utterance_weights, dim=1)
        choices = [speaker_utterances[num_samples * i + idx.item()] for i, idx in enumerate(chosen_indices)]

        return choices, speaker_utterances, listener_log_probs, speaker_utterance_log_probs, utterance_weights

    def generate_speaker_side(self, processor, images, s_input_tokens, s_attn_mask, s_image_attn_mask, max_steps,
                              sampling_type, temperature, top_k, top_p, repetition_penalty, num_samples):
        """Sample utterances from the speaker and score them.

        ``sampling_type`` is accepted for interface compatibility and is not used.
        Returns ``(utterance_log_probs, speaker_outputs)``: a (B, num_samples)
        tensor and the flat list of decoded strings.
        """
        speaker = self.get_speaker()
        generation_config = GenerationConfig(
            max_new_tokens=max_steps,
            min_new_tokens=1,
            do_sample=True,
            temperature=temperature,
            top_k=top_k, top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_return_sequences=num_samples,
            output_hidden_states=True,
            return_dict_in_generate=True
        )
        outputs = speaker.generate(
            input_ids=s_input_tokens,
            attention_mask=s_attn_mask,
            pixel_values=images,
            pixel_attention_mask=s_image_attn_mask,
            generation_config=generation_config,
            use_cache=True
        )

        B = s_input_tokens.shape[0]
        sequence_log_probs, speaker_outputs = self._score_generations(speaker, outputs, processor)
        return sequence_log_probs.view(B, num_samples), speaker_outputs

    def generate_listener_side(self, image_embeddings, speaker_utterances, label, image_paths, processor,
                               image_dir, index_to_token, num_samples):
        """Listener log-probability of the gold ``label`` for each sampled utterance.

        Returns a (B, num_samples) tensor, where B is ``label.shape[0]``.
        """
        B = label.shape[0]
        image_embeddings = self._expand_image_embeddings(image_embeddings, B, -1, num_samples)

        l_batch = process_idefics_listener_generation_input(image_paths, speaker_utterances, processor,
                                                            image_dir, num_samples, image_embeddings.device)
        l_input_tokens, l_attn_mask, _, l_image_attn_mask = l_batch
        label = label.unsqueeze(1).repeat(1, num_samples).view(-1).unsqueeze(1)

        listener = self.get_listener()
        all_logits = listener(
            input_ids=l_input_tokens,
            attention_mask=l_attn_mask,
            image_hidden_states=image_embeddings,
            pixel_attention_mask=l_image_attn_mask
        )['logits']

        target_logits = filter_targets(all_logits[:, -1], index_to_token)
        listener_log_probs = F.log_softmax(target_logits, dim=1)
        return torch.gather(listener_log_probs, 1, label).squeeze(1).view(B, num_samples)
|
|
|
|
|
|