|
# `spaces` stays the first import, as in the original file (HF ZeroGPU
# guidance is to load it before CUDA-related packages such as torch).
import spaces

import os
import random
from typing import Generator, List, Optional, Tuple

import gradio as gr
import torch

from config_generator import generate_complete_game
from dataset import get_processor, joint_speaker_input, joint_listener_input, get_index_to_token
from models import get_model
|
|
|
css=""" |
|
.radio-group .wrap { |
|
display: grid; |
|
grid-template-columns: repeat(5, 1fr); |
|
grid-template-rows: repeat(5, 1fr); |
|
width: 100%; |
|
height: 100% |
|
} |
|
""" |
|
|
|
def initialize_game() -> List[Tuple[List[str], List[str], str, str]]:
    """Generate the rounds for one game session.

    Two contexts are drawn from ``generate_complete_game``. Each contributes
    three rounds, one for each of its first three targets. The model is the
    speaker for the three rounds of the first context and the listener for
    the three rounds of the second.

    Returns:
        One ``(speaker_context, listener_context, target, model_role)`` tuple
        per round, in play order. ``model_role`` is ``"speaker"`` or
        ``"listener"``.
    """
    context_dicts = [generate_complete_game() for _ in range(2)]

    # Round order: 3 rounds with the model speaking, then 3 with it listening.
    roles = ["speaker"] * 3 + ["listener"] * 3

    speaker_images = []
    listener_images = []
    targets = []

    for context_dict in context_dicts:
        for i in range(3):
            # Repeat both views of the context per round so each tuple is self-contained.
            speaker_images.append(context_dict["speaker_context"])
            listener_images.append(context_dict["listener_context"])
            targets.append(context_dict["targets"][i])

    return list(zip(speaker_images, listener_images, targets, roles))
|
|
|
def get_model_response(
    model, adapter_name, processor, index_to_token, role: str,
    image_paths: List[str], user_message: str = "", target_image: str = ""
) -> str:
    """Let the model play one round, either as speaker or as listener.

    Args:
        model: the joint speaker/listener model; must expose ``get_listener()``
            whose ``device`` is used when building the inputs.
        adapter_name: adapter to activate before inference.
        processor: processor forwarded to the dataset input builders.
        index_to_token: vocabulary lookup forwarded to the model.
        role: ``"speaker"`` makes the model describe ``target_image``; any
            other value makes it pick an image from ``user_message``.
        image_paths: the model's own context.
        user_message: the human's description (listener role only).
        target_image: the image the model must describe (speaker role only).

    Returns:
        The first generated caption (speaker role) or the entry of
        ``image_paths`` the model chose (listener role).
    """
    if role != "speaker":
        print("Starting processing")
        (images, l_input_tokens, l_attn_mask, l_image_attn_mask,
         s_input_tokens, s_attn_mask, s_image_attn_mask,
         s_target_mask, s_target_label) = joint_listener_input(
            processor, image_paths, user_message, model.get_listener().device
        )
        print("Starting inference")
        chosen_image = get_listener_response(
            model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
            s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label,
            image_paths, adapter_name,
        )
        print("Done")
        return chosen_image

    print("Starting processing")
    speaker_inputs = joint_speaker_input(
        processor, image_paths, target_image, model.get_listener().device
    )
    input_tokens, attn_mask, images, image_attn_mask, label = speaker_inputs
    print("Starting inference")
    # The single context is wrapped in a list before generation
    # (presumably a batch of one — confirm against model.generate).
    captions = get_speaker_response(
        model, images, input_tokens, attn_mask, image_attn_mask, label,
        [image_paths], processor, "tangram_pngs", index_to_token, adapter_name,
    )
    print("Done")
    return captions[0]
|
|
|
@spaces.GPU(duration=20)
def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir, index_to_token, adapter_name):
    """Generate captions for a context with the speaker side of the model.

    Activates ``adapter_name``, moves the model and all tensor inputs to CUDA,
    and calls ``model.generate`` without gradient tracking. The call uses
    nucleus sampling (temperature 0.7, top_k 50, top_p 1, repetition penalty 1),
    at most 30 steps and ``num_samples=5``.

    The ``spaces.GPU(duration=20)`` decorator requests a GPU for this call
    (presumably ZeroGPU with a ~20 s budget — see the ``spaces`` docs).

    Returns:
        The first of the five values returned by ``model.generate`` (captions).
    """
    model.model.set_adapter(adapter_name)
    gpu_model = model.cuda()

    with torch.no_grad():
        tensor_args = [
            tensor.cuda()
            for tensor in (images, input_tokens, attn_mask, image_attn_mask, label)
        ]
        captions, _, _, _, _ = gpu_model.generate(
            *tensor_args,
            image_paths, processor, img_dir, index_to_token,
            max_steps=30, sampling_type="nucleus", temperature=0.7,
            top_k=50, top_p=1, repetition_penalty=1, num_samples=5,
        )

    return captions
|
|
|
@spaces.GPU(duration=20)
def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
                          s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths, adapter_name):
    """Pick the image the model believes a description refers to.

    Activates ``adapter_name``, moves the model and tensor inputs to CUDA, and
    runs ``model.comprehension_side`` without gradient tracking. Inputs are
    passed as a single list in the order:
    listener tensors, ``index_to_token``, speaker tensors.

    Returns:
        The entry of ``image_paths`` at the argmax of the first row of the
        joint log-probabilities.
    """
    model.model.set_adapter(adapter_name)
    gpu_model = model.cuda()

    with torch.no_grad():
        listener_tensors = [
            tensor.cuda()
            for tensor in (images, l_input_tokens, l_attn_mask, l_image_attn_mask)
        ]
        speaker_tensors = [
            tensor.cuda()
            for tensor in (s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label)
        ]
        batch = listener_tensors + [index_to_token] + speaker_tensors
        _, _, joint_log_probs = gpu_model.comprehension_side(batch)

    best_index = int(joint_log_probs[0].argmax())
    return image_paths[best_index]
|
|
|
def interaction(model, processor, index_to_token, model_iteration: str) -> Tuple[List[str], List[str]]: |
|
image_role_pairs = initialize_game() |
|
conversation = [] |
|
turn = 0 |
|
num_correct = 0 |
|
human_role = None |
|
adapter_name = "initial" if model_iteration == "Initial System" else "final" |
|
internal_model = model |
|
|
|
for speaker_image, listener_image, target_image, model_role in image_role_pairs: |
|
acc_message = f"{num_correct}/{turn}" |
|
if model_role == "speaker": |
|
human_role = "Listener" |
|
turn += 1 |
|
turn_message = f"{turn}/6" |
|
human_context = listener_image |
|
model_context = speaker_image |
|
target_idx = human_context.index(target_image) |
|
|
|
conversation.extend([ |
|
f"TURN: {turn}/6", |
|
f"Guess the target image given the speaker's description. ", |
|
]) |
|
model_message = get_model_response(internal_model, adapter_name, processor, index_to_token, model_role, model_context, target_image=target_image) |
|
conversation.append(f"Model: {model_message}") |
|
conversation.append("You: The target is Image ") |
|
user_message = yield human_context, conversation, human_role, turn_message, acc_message |
|
|
|
conversation[-1] += f"{user_message}" |
|
if int(user_message) == target_idx + 1: |
|
conversation.append("Correct!\n") |
|
num_correct += 1 |
|
else: |
|
conversation.append(f"Incorrect!\n") |
|
else: |
|
|
|
human_role = "Speaker" |
|
turn += 1 |
|
turn_message = f"{turn}/6" |
|
human_context = speaker_image |
|
model_context = listener_image |
|
target_idx = human_context.index(target_image) |
|
|
|
conversation.extend([ |
|
f"TURN: {turn}/6", |
|
f"Generate a description for the target image. Your target is Image {target_idx + 1}", |
|
]) |
|
|
|
user_message = yield human_context, conversation, human_role, turn_message, acc_message |
|
conversation.append(f"You: {user_message}") |
|
model_message = get_model_response(internal_model, adapter_name, processor, index_to_token, model_role, model_context, user_message=user_message) |
|
model_idx = human_context.index(model_message) |
|
|
|
if int(model_idx) == int(target_idx): |
|
conversation.append("The model guessed correctly!\n") |
|
num_correct += 1 |
|
else: |
|
conversation.append(f"The model guessed incorrectly.\n") |
|
|
|
acc_message = f"{num_correct}/{turn}" |
|
conversation.append("The game is over!") |
|
yield human_context, conversation, human_role, turn_message, acc_message |
|
|
|
def create_app():
    """Build the Gradio interface for the tangram reference game.

    The model, processor and vocabulary are loaded once here. "Start Game"
    creates a new game generator via ``interaction`` and shows its first
    round. "Send" forwards the player's description or guess to that
    generator and shows the next state.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(css=css) as app:
        gr.Markdown("# Tangram Reference Game")
        gr.Markdown(
            '### You will be playing a sequence of reference games against a model. To start a game, first select whether ' +\
            'you wish to play against our initial trained model ("Initial System") or our model at the end of deployment ("Final System") ' +\
            'and press the "Start Game" button. There will be 6 rounds of reference games. You will take on a "listener" or a "speaker" role at each round.'
        )
        gr.Markdown(
            '### In the speaker role, you will be assigned a target image. Your goal will be to describe this image (via a message in the textbox) ' +\
            'so that your partner can guess what it is.'
        )
        gr.Markdown(
            '### In the listener role, you will be given a description. Your goal will be ' +\
            'to select the image that the description best describes (by clicking on the relevant button).'
        )
        gr.Markdown(
            '### Press "Send" to submit your action in either role and make the game proceed.'
        )

        with gr.Row():
            model_iteration = gr.Radio(["Initial System", "Final System"], label="Model Iteration")
            start_btn = gr.Button("Start Game")

        with gr.Row():
            current_role = gr.Textbox(label="YOUR ROLE")
            current_turn = gr.Textbox(label="TURN")
            accuracy = gr.Textbox(label="FINAL ACCURACY")

        with gr.Row():
            image_output = gr.Gallery(
                label="CONTEXT", show_label=False, elem_id="gallery",
                columns=5, rows=2, object_fit="contain", height="250px",
                allow_preview=False, container=True
            )

        with gr.Row():
            conversation_output = gr.Textbox(label="Interaction History")

        with gr.Column():
            user_input = gr.Textbox(label="Your Message as Speaker", interactive=False)
            radio_buttons = gr.Radio(
                label="Your Guess as Listener",
                elem_classes="radio-group",
                choices=list(range(1, 11)),
                interactive=False,
            )

        send_btn = gr.Button("Send")

        # Both handlers must return exactly one value per component, in this order.
        game_outputs = [image_output, conversation_output, current_role, current_turn, accuracy,
                        user_input, radio_buttons, send_btn]

        # NOTE(review): these closure variables live as long as the app and are
        # shared by every browser session, so concurrent players overwrite each
        # other's game. Per-session state (e.g. gr.State) would avoid that —
        # confirm it can hold a generator before switching.
        interaction_generator = None
        # Role of the human in the round currently on screen.
        displayed_role = None

        model = get_model()
        processor = get_processor()
        index_to_token = get_index_to_token()

        print("Heyo!")

        def _gallery(images):
            # (file path, caption) pairs; captions are 1-based to match the guess buttons.
            return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)]

        def start_interaction(model_iteration):
            nonlocal interaction_generator, displayed_role
            if model_iteration is None:
                return [], "Please select a model iteration.", "", "", "", gr.update(interactive=False), \
                    gr.update(interactive=False), gr.update(interactive=False)

            interaction_generator = interaction(model, processor, index_to_token, model_iteration)
            images, conversation, role, turn, acc_message = next(interaction_generator)
            displayed_role = role
            human_listener = role == "Listener"
            return _gallery(images), "\n".join(conversation), role, turn, acc_message, \
                gr.update(interactive=not human_listener), gr.update(interactive=human_listener), gr.update(interactive=True)

        def send_message(message, radio_choice):
            nonlocal displayed_role
            if interaction_generator is None:
                # One value per output component (previously only five were returned,
                # which Gradio rejects).
                return [], "Please start the interaction first.", "", "", "", \
                    gr.update(interactive=False), gr.update(interactive=False, value=None), gr.update(interactive=False)

            if displayed_role == "Listener" and radio_choice is None:
                # No image selected yet: leave the screen as is rather than sending an
                # empty guess, which the game cannot interpret as an image number.
                return [gr.update() for _ in game_outputs]

            user_output = message if radio_choice is None else radio_choice
            try:
                images, conversation, role, turn, acc_message = interaction_generator.send(user_output)
            except StopIteration:
                # The game already ended: keep what is displayed and lock the inputs.
                # (Reading component `.value` here would return the initial values and
                # wipe the history.)
                return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), \
                    gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)

            displayed_role = role
            human_listener = role == "Listener"
            return _gallery(images), "\n".join(conversation), role, turn, acc_message, \
                gr.update(interactive=not human_listener, value=""), gr.update(interactive=human_listener, value=None), gr.update(interactive=True)

        start_btn.click(
            start_interaction,
            inputs=[model_iteration],
            outputs=game_outputs,
        )
        send_btn.click(send_message, inputs=[user_input, radio_buttons], outputs=game_outputs)

    return app
|
|
|
app = create_app()

if __name__ == "__main__":
    # Start the server only when run as a script (e.g. `python app.py`), so
    # importing this module builds the app without launching it.
    app.launch()
|
|