import os
import pickle

import torch
from PIL import Image
from transformers import AutoProcessor

BASE_SPEAKER_LEN = 787
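# Assumption (inferred from create_speaker_caption_mask below, not stated in
# the source): BASE_SPEAKER_LEN is the fixed token count of the speaker's base
# prompt (instructions, ten image placeholders, and the target assignment)
# under the processor settings in get_processor(); every non-padding token
# beyond it belongs to the caption. Re-measure it if the prompt wording or
# image resolution changes.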


def joint_listener_input(processor, context_images, description, device):
    """Build listener and speaker tensors for one (context, caption) pair."""
    img_dir = "tangram_pngs"
    raw_images = process_images(img_dir, context_images)
    target_anno = description.lower()
    # Target index 0 is a placeholder: the final index tokens are stripped
    # below so the listener model predicts the index itself.
    prompt = construct_listener_full_prompt(
        processor, target_anno, 0, "verbose_instruction"
    )

    outputs = processor(
        text=[prompt],
        images=[raw_images],
        return_tensors="pt"
    ).to(device)
    # Truncate the last two tokens (the placeholder index and the token after
    # it) and zero the attention mask wherever the input is padding (id 0).
    l_input_tokens = outputs['input_ids'][:, :-2]
    l_attn_mask = outputs['attention_mask'][:, :-2]
    l_attn_mask[l_input_tokens == 0] = 0
    images = outputs['pixel_values']
    l_image_attn_mask = outputs['pixel_attention_mask']

    # Build one speaker prompt per candidate target image (10 in total), all
    # sharing the same image context.
    prompts = []
    for i in range(10):
        prompt = construct_speaker_full_prompt(processor, description, i, "information_after")
        prompts.append(prompt)
    outputs = processor(
        text=prompts,
        images=[raw_images] * 10,
        padding='longest',
        return_tensors="pt"
    ).to(device)

    # Standard next-token shift: inputs drop the last token, targets the first.
    s_input_tokens = outputs['input_ids'][:, :-1]
    s_attn_mask = outputs['attention_mask'][:, :-1]
    s_attn_mask[s_input_tokens == 0] = 0
    s_image_attn_mask = outputs['pixel_attention_mask']
    s_target_tokens = outputs['input_ids'][:, 1:]
    s_target_mask = []
    for i in range(10):
        curr_mask = create_speaker_caption_mask(outputs['input_ids'][i], s_attn_mask[i])
        s_target_mask.append(curr_mask)
    s_target_mask = torch.stack(s_target_mask, dim=0)

    return images, l_input_tokens, l_attn_mask, l_image_attn_mask, s_input_tokens.unsqueeze(0), \
        s_attn_mask.unsqueeze(0), s_image_attn_mask.unsqueeze(0), s_target_mask.unsqueeze(0), \
        s_target_tokens.unsqueeze(0)
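
# Example call (illustrative only; the filenames and device are assumptions,
# not part of the source):
#
#   processor = get_processor()
#   context = [f"tangram_{i}.png" for i in range(10)]   # hypothetical names
#   batch = joint_listener_input(processor, context, "a dog facing left", "cpu")
#   images, l_tokens, l_mask, l_img_mask, *speaker_tensors = batch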


def joint_speaker_input(processor, image_paths, target_path, device):
    """Build the speaker's generation-time inputs for one context."""
    img_dir = "tangram_pngs"
    raw_images = process_images(img_dir, image_paths)
    target_idx = image_paths.index(target_path)
    base_prompt = construct_speaker_base_prompt(processor, target_idx, "information_after", process=True)

    outputs = processor(
        text=[base_prompt],
        images=[raw_images],
        return_tensors="pt"
    ).to(device)

    input_tokens = outputs['input_ids']
    attn_mask = outputs['attention_mask']
    attn_mask[input_tokens == 0] = 0
    images = outputs['pixel_values']
    image_attn_mask = outputs['pixel_attention_mask']

    return input_tokens, attn_mask, images, image_attn_mask, torch.LongTensor([target_idx]).to(device)
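
# Example call (illustrative only; the paths are hypothetical):
#
#   paths = [f"tangram_{i}.png" for i in range(10)]
#   tokens, mask, imgs, img_mask, target = joint_speaker_input(
#       processor, paths, paths[3], "cpu"
#   )   # target == tensor([3])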


def get_processor():
    checkpoint = "HuggingFaceM4/idefics2-8b"
    processor = AutoProcessor.from_pretrained(
        checkpoint,
        do_image_splitting=False,
        size={"longest_edge": 448, "shortest_edge": 224},
    )
    return processor


def get_index_to_token():
    # Load the pickled index -> token mapping that ships with the repo.
    index_to_token_path = "index_to_token.pkl"
    with open(index_to_token_path, 'rb') as f:
        index_to_token = pickle.load(f)
    return index_to_token


def process_images(img_dir, context_images):
    raw_images = []
    for img in context_images:
        image_path = os.path.join(img_dir, img)
        raw_image = Image.open(image_path).convert('RGB')
        raw_images.append(raw_image)
    return raw_images


def create_speaker_caption_mask(all_token_ids, text_mask):
    # Everything that is neither padding (token id 0) nor part of the fixed
    # base prompt is caption. The suffix indexing below assumes padding sits
    # at the front of the sequence, so the caption is always the suffix.
    padding_tokens = torch.sum(all_token_ids == 0).item()
    caption_tokens = all_token_ids.shape[0] - (padding_tokens + BASE_SPEAKER_LEN)

    target_mask = torch.zeros_like(text_mask)
    target_mask[-caption_tokens:] = 1

    return target_mask.bool()
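
# Worked example for create_speaker_caption_mask (illustrative numbers, not
# from the source): with 800 total tokens of which 5 are padding, the caption
# spans 800 - (5 + 787) = 8 tokens, so only the final 8 positions of the
# returned mask are True.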


def construct_listener_full_prompt(processor, target_anno, target_idx, comprehension_prompt_type="verbose_instruction"):
    target_anno = target_anno.lower().strip()
    messages = []

    if comprehension_prompt_type == "verbose_instruction":
        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "You will be presented with a sequence of 10 images and a caption describing exactly one of them. "},
                    {"type": "text", "text": "Your task is to guess which image the caption describes. "},
                ]
            }
        )

        # Interleave an "Image i:" label with each image placeholder.
        for i in range(10):
            if i == 0:
                messages[0]["content"].append({"type": "text", "text": f" Image {i}: "})
            else:
                messages[0]["content"].append({"type": "text", "text": f", Image {i}: "})
            messages[0]["content"].append({"type": "image"})

        messages[0]["content"].append({"type": "text", "text": f". Caption: {target_anno}"})
        messages[0]["content"].append({"type": "text", "text": " Does this caption describe Image 0, 1, 2, 3, 4, 5, 6, 7, 8 or 9?"})

        messages.append(
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": f"The caption describes Image {target_idx}"}
                ]
            }
        )
    else:
        raise ValueError(f"Unknown comprehension_prompt_type: {comprehension_prompt_type}")

    return processor.apply_chat_template(messages, add_generation_prompt=False).strip()
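
# Roughly (the exact rendering depends on the idefics2 chat template and is an
# assumption here), the returned string looks like:
#
#   User: You will be presented with ... Image 0: <image>, Image 1: <image>,
#   ... Caption: a dog facing left Does this caption describe Image 0, 1, 2,
#   3, 4, 5, 6, 7, 8 or 9?
#   Assistant: The caption describes Image 0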


def construct_speaker_full_prompt(processor, target_anno, target_idx,
                                  generation_prompt_type="information_after"):
    messages = construct_speaker_base_prompt(processor, target_idx, generation_prompt_type)

    # Append the gold caption as the assistant turn.
    target_anno = target_anno.lower().strip()
    messages.append(
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": target_anno}
            ]
        }
    )

    return processor.apply_chat_template(messages, add_generation_prompt=False).strip()


def construct_speaker_base_prompt(processor, target_idx, generation_prompt_type="information_after", process=False):
    messages = []

    if generation_prompt_type == "information_after":
        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "You will be presented with a sequence of 10 images and be assigned a target image. "},
                    {"type": "text", "text": "Your task is to produce a caption for your target image such that anyone could guess the image from your description. "},
                ]
            }
        )

        for i in range(10):
            if i == 0:
                messages[0]["content"].append({"type": "text", "text": f" Image {i}: "})
            else:
                messages[0]["content"].append({"type": "text", "text": f", Image {i}: "})
            messages[0]["content"].append({"type": "image"})

        messages[0]["content"].append({"type": "text", "text": f". Your target image is Image {target_idx}. Produce your caption now."})
    else:
        raise ValueError(f"Unknown generation_prompt_type: {generation_prompt_type}")

    # With process=True, render the chat template (ending with a generation
    # prompt) instead of returning the raw message list.
    if process:
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True).strip()
        return prompt
    else:
        return messages


def process_idefics_listener_generation_input(speaker_context, captions, processor, img_dir, num_samples, device):
    prompts, raw_images = get_listener_generation_prompts(speaker_context, captions, num_samples, img_dir, processor)

    listener_inputs = processor(
        text=prompts,
        images=raw_images,
        padding='longest',
        return_tensors='pt'
    )

    # As in joint_listener_input: truncate the placeholder image index so the
    # listener predicts it, and zero the mask over padding tokens.
    input_tokens = listener_inputs['input_ids'][:, :-2].to(device)
    attn_mask = listener_inputs['attention_mask'][:, :-2].to(device)
    attn_mask[input_tokens == 0] = 0
    images = listener_inputs['pixel_values'].to(device)
    image_attn_mask = listener_inputs['pixel_attention_mask'].to(device)

    return input_tokens, attn_mask, images, image_attn_mask


def get_listener_generation_prompts(speaker_contexts, captions, num_samples, img_dir, processor):
    # captions holds num_samples candidate captions per context, flattened in
    # context-major order; build one listener prompt per caption.
    prompts = []
    all_raw_images = []

    for i, speaker_context in enumerate(speaker_contexts):
        raw_images = process_images(img_dir, speaker_context)
        for j in range(num_samples):
            curr_idx = i * num_samples + j
            caption = captions[curr_idx]
            prompt = construct_listener_full_prompt(processor, caption, 0, "verbose_instruction")

            prompts.append(prompt)
            all_raw_images.append(raw_images)
    return prompts, all_raw_images
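

# Minimal smoke test (nothing below comes from the original module; the PNG
# filenames are hypothetical and must exist under tangram_pngs/ to run).
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = get_processor()
    context = [f"tangram_{i}.png" for i in range(10)]  # hypothetical filenames
    batch = joint_listener_input(processor, context, "a dog facing left", device)
    for tensor in batch:
        print(tuple(tensor.shape))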