# Lander San Millan
# feat: examples added
# commit 1114f49
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
import torch
from flamingo_mini_task.utils import load_url
from flamingo_mini_task import FlamingoModel, FlamingoProcessor
from datasets import load_dataset,concatenate_datasets
from PIL import Image
# Directory containing the example images referenced by the demo rows below.
EXAMPLES_DIR = 'examples'
# Minimal prompt: a single image placeholder token consumed by the model.
DEFAULT_PROMPT = "<image>"
# Display names for the model dropdown. The bracketed suffix denotes the
# task format the checkpoint was fine-tuned on (QA = direct question
# answering, COT = chain-of-thought generation).
MINI_MODEL = "flamingo-mini-bilbaocaptions-scienceQA[QA]"
TINY_MODEL = "flamingo-tiny-scienceQA[COT+QA]"
MEGATINY_MODEL = "flamingo-megatiny-opt-scienceQA[QA]"
# Map each dropdown display name to its Hugging Face checkpoint.
# Fix: the mini/tiny entries were swapped — the "mini-bilbaocaptions" key
# loaded the Flamingo-tiny_ScienceQA_COT-QA checkpoint while the
# "tiny-scienceQA[COT+QA]" key loaded the mini Bilbao-Captions checkpoint.
flamingo_megatiny_captioning_models = {
    MINI_MODEL: {
        'model': FlamingoModel.from_pretrained('TheMrguiller/Flamingo-mini-Bilbao_Captions-task_BilbaoQA-ScienceQA'),
    },
    TINY_MODEL: {
        'model': FlamingoModel.from_pretrained('TheMrguiller/Flamingo-tiny_ScienceQA_COT-QA'),
    },
    MEGATINY_MODEL: {
        'model': FlamingoModel.from_pretrained('landersanmi/flamingo-megatiny-opt-QA'),
    },
}
# Example rows for the Gradio interface. Each row matches the `inputs` list
# order: [image, question, option_a, option_b, option_c, option_d,
# cot_checkbox, model_name]. Built as one list literal instead of repeated
# .append() calls; dead commented-out rows removed.
path = EXAMPLES_DIR + "/{}"
cot = False  # examples default to plain QA (no chain-of-thought)
examples = [
    [path.format("koala.png"), "What animal is this?", "Koala", "Elephant", "Cat", "Mouse", cot, MEGATINY_MODEL],
    [path.format("townhall.jpg"), "What building is this?", "Guggenheim museum", "San mames stadium", "Alhondiga", "Bilbao townhall", cot, TINY_MODEL],
    [path.format("muniain.jpeg"), "What team is IKer Muniain associated?", "Real Madrid", "Manchester United", "Athletic Bilbao", "Rayo Vallecano", cot, TINY_MODEL],
    [path.format("lasalve.jpeg"), "What is the name of this bridge?", "La Salve", "Zubizuri", "La Ribera", "San Anton", cot, TINY_MODEL],
    [path.format("athl.jpeg"), "Football fans hold flags with what team colors?", "Athletic", "Besiktas", "Udinese", "Real Madrid", cot, TINY_MODEL],
]
def generate_text(image, question, option_a, option_b, option_c, option_d, cot_checkbox, model_name):
    """Answer a multiple-choice question about an image.

    Builds a task-formatted prompt, runs the selected Flamingo checkpoint,
    and returns the text the model produced after the ``[ANSWER]`` marker.

    Args:
        image: Input image from the Gradio image component.
        question: The question text.
        option_a / option_b / option_c / option_d: The four answer choices.
        cot_checkbox: If True, ask for a chain-of-thought ([COT]) instead
            of a direct answer ([QA]).
        model_name: Key into ``flamingo_megatiny_captioning_models``.

    Returns:
        The generated answer string (or, if the model never emitted the
        ``[ANSWER]`` marker, the full generation as a fallback).
    """
    model = flamingo_megatiny_captioning_models[model_name]['model']
    processor = FlamingoProcessor(model.config)
    # Leading task token selects the fine-tuned behaviour.
    prompt = "[COT]" if cot_checkbox else "[QA]"
    prompt += "[CONTEXT]<image>[QUESTION]{} [OPTIONS] (A) {} (B) {} (C) {} (D) {} [ANSWER]".format(question,
                                                                                                   option_a,
                                                                                                   option_b,
                                                                                                   option_c,
                                                                                                   option_d)
    print(prompt)
    prediction = model.generate_captions(images=image,
                                         processor=processor,
                                         prompt=prompt,
                                         )
    # Fix: the original did split('[ANSWER]')[1], which raises IndexError
    # whenever the marker is absent from the generation. partition() lets
    # us fall back to the full text instead of crashing the UI.
    _, marker, answer = prediction[0].partition('[ANSWER]')
    return answer if marker else prediction[0]
# UI components. Uses the modern top-level component API throughout for
# consistency (the original mixed modern `gr.Image` with the deprecated
# `gr.inputs.*` / `gr.outputs.*` namespaces, which were removed in
# Gradio 4.x; `default=` became `value=`).
image_input = gr.Image(path.format("giraffe.jpeg"))
question_input = gr.Textbox(value="What animal is this?")
opt_a_input = gr.Textbox(value="Dog")
opt_b_input = gr.Textbox(value="Giraffe")
opt_c_input = gr.Textbox(value="Elephant")
opt_d_input = gr.Textbox(value="Cocodrile")
cot_checkbox = gr.Checkbox(label="Generate COT")
select_model = gr.Dropdown(choices=list(flamingo_megatiny_captioning_models.keys()))
text_output = gr.Textbox()

# Create the Gradio interface and start the local server.
gr.Interface(
    fn=generate_text,
    inputs=[
        image_input,
        question_input,
        opt_a_input,
        opt_b_input,
        opt_c_input,
        opt_d_input,
        cot_checkbox,
        select_model,
    ],
    examples=examples,
    outputs=text_output,
    title='Generate answers from MCQ',
    description='Generate answers from Multiple Choice Questions or generate a Chain Of Though about the question and the options given',
    theme='default'
).launch()