# Gradio Space: Flamingo mini/tiny ScienceQA multiple-choice QA demo.
# (The scraped "Spaces:" / "Runtime error" header lines were page artifacts, not code.)
import gradio as gr
import matplotlib.pyplot as plt
import torch
from datasets import load_dataset, concatenate_datasets
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

from flamingo_mini_task import FlamingoModel, FlamingoProcessor
from flamingo_mini_task.utils import load_url
# Directory that holds the example images shipped with the demo.
EXAMPLES_DIR = 'examples'
# Default prompt: a bare image token (no question).
DEFAULT_PROMPT = "<image>"

# Display names for the selectable checkpoints. The bracketed suffix names the
# task tokens the checkpoint was trained with ([QA] = direct answer,
# [COT] = chain of thought).
MINI_MODEL = "flamingo-mini-bilbaocaptions-scienceQA[QA]"
TINY_MODEL = "flamingo-tiny-scienceQA[COT+QA]"
MEGATINY_MODEL = "flamingo-megatiny-opt-scienceQA[QA]"
# Display name -> loaded pretrained Flamingo checkpoint.
# FIX: the original paired MINI_MODEL with the *tiny* checkpoint and
# TINY_MODEL with the *mini* checkpoint; the hub repo names are matched to
# their keys here.
flamingo_megatiny_captioning_models = {
    MINI_MODEL: {
        'model': FlamingoModel.from_pretrained('TheMrguiller/Flamingo-mini-Bilbao_Captions-task_BilbaoQA-ScienceQA'),
    },
    TINY_MODEL: {
        'model': FlamingoModel.from_pretrained('TheMrguiller/Flamingo-tiny_ScienceQA_COT-QA'),
    },
    MEGATINY_MODEL: {
        'model': FlamingoModel.from_pretrained('landersanmi/flamingo-megatiny-opt-QA'),
    },
}
# Example rows for the Gradio interface. Column order mirrors the `inputs`
# list of gr.Interface below:
# (image path, question, option A, option B, option C, option D, cot flag, model name).
path = EXAMPLES_DIR + "/{}"  # template for example image paths
cot = False  # examples demonstrate direct answers, not chain-of-thought

examples = [
    [path.format("koala.png"), "What animal is this?", "Koala", "Elephant", "Cat", "Mouse", cot, MEGATINY_MODEL],
    [path.format("townhall.jpg"), "What building is this?", "Guggenheim museum", "San mames stadium", "Alhondiga", "Bilbao townhall", cot, TINY_MODEL],
    [path.format("muniain.jpeg"), "What team is IKer Muniain associated?", "Real Madrid", "Manchester United", "Athletic Bilbao", "Rayo Vallecano", cot, TINY_MODEL],
    [path.format("lasalve.jpeg"), "What is the name of this bridge?", "La Salve", "Zubizuri", "La Ribera", "San Anton", cot, TINY_MODEL],
    [path.format("athl.jpeg"), "Football fans hold flags with what team colors?", "Athletic", "Besiktas", "Udinese", "Real Madrid", cot, TINY_MODEL],
]
def generate_text(image, question, option_a, option_b, option_c, option_d, cot_checkbox, model_name):
    """Answer a multiple-choice question about an image with a Flamingo model.

    Args:
        image: Input image (as provided by the Gradio Image component).
        question: The question to ask about the image.
        option_a, option_b, option_c, option_d: The four answer choices.
        cot_checkbox: If True, prompt with [COT] to request a chain of
            thought; otherwise prompt with [QA] for a direct answer.
        model_name: Key into ``flamingo_megatiny_captioning_models``.

    Returns:
        The text the model generated after the [ANSWER] tag, or the full
        generation if the model did not emit the tag.
    """
    model = flamingo_megatiny_captioning_models[model_name]['model']
    processor = FlamingoProcessor(model.config)

    # Task token selects the behavior the checkpoint was trained with.
    task_token = "[COT]" if cot_checkbox else "[QA]"
    prompt = task_token + "[CONTEXT]<image>[QUESTION]{} [OPTIONS] (A) {} (B) {} (C) {} (D) {} [ANSWER]".format(
        question, option_a, option_b, option_c, option_d)
    print(prompt)

    prediction = model.generate_captions(images=image,
                                         processor=processor,
                                         prompt=prompt)

    # Keep only the part after [ANSWER]; the original indexed [1] blindly,
    # which raises IndexError when the model omits the tag.
    _, sep, answer = prediction[0].partition('[ANSWER]')
    return answer if sep else prediction[0]
# Build the Gradio UI.
# FIX: the gr.inputs / gr.outputs namespaces (and the ``default=`` keyword)
# were removed in modern Gradio and raise at import time — the likely cause of
# the Space's runtime error. Use the top-level components with ``value=``.
image_input = gr.Image(value=path.format("giraffe.jpeg"))
question_input = gr.Textbox(value="What animal is this?")
opt_a_input = gr.Textbox(value="Dog")
opt_b_input = gr.Textbox(value="Giraffe")
opt_c_input = gr.Textbox(value="Elephant")
opt_d_input = gr.Textbox(value="Cocodrile")
cot_checkbox = gr.Checkbox(label="Generate COT")
select_model = gr.Dropdown(choices=list(flamingo_megatiny_captioning_models.keys()))
text_output = gr.Textbox()

# Wire components to generate_text and launch the app.
gr.Interface(
    fn=generate_text,
    inputs=[image_input,
            question_input,
            opt_a_input,
            opt_b_input,
            opt_c_input,
            opt_d_input,
            cot_checkbox,
            select_model,
            ],
    examples=examples,
    outputs=text_output,
    title='Generate answers from MCQ',
    description='Generate answers from Multiple Choice Questions or generate a Chain Of Though about the question and the options given',
    theme='default',
).launch()