"""Gradio demo: multiple-choice visual question answering with Flamingo models.

Shows an image, a question and four options; builds a task-tagged prompt
("[COT]" for chain-of-thought, "[QA]" for a plain answer) and returns the
text the selected Flamingo checkpoint generates after the "[ANSWER]" tag.
"""
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
from flamingo_mini_task.utils import load_url
from flamingo_mini_task import FlamingoModel, FlamingoProcessor
from datasets import load_dataset, concatenate_datasets
from PIL import Image

EXAMPLES_DIR = 'examples'
DEFAULT_PROMPT = ""

# Dropdown display names; the bracketed suffix documents the task tags each
# checkpoint was fine-tuned with.
MINI_MODEL = "flamingo-mini-bilbaocaptions-scienceQA[QA]"
TINY_MODEL = "flamingo-tiny-scienceQA[COT+QA]"
MEGATINY_MODEL = "flamingo-megatiny-opt-scienceQA[QA]"

# NOTE(review): the original mapped MINI_MODEL to the *tiny* checkpoint and
# TINY_MODEL to the *mini* checkpoint; each display name now loads the
# matching Hugging Face repository. Confirm against the intended deployment.
flamingo_megatiny_captioning_models = {
    MINI_MODEL: {
        'model': FlamingoModel.from_pretrained(
            'TheMrguiller/Flamingo-mini-Bilbao_Captions-task_BilbaoQA-ScienceQA'),
    },
    TINY_MODEL: {
        'model': FlamingoModel.from_pretrained(
            'TheMrguiller/Flamingo-tiny_ScienceQA_COT-QA'),
    },
    MEGATINY_MODEL: {
        'model': FlamingoModel.from_pretrained(
            'landersanmi/flamingo-megatiny-opt-QA'),
    },
}

# Example rows, one per Interface input:
# [image path, question, option A, option B, option C, option D, COT flag, model name]
path = EXAMPLES_DIR + "/{}"
cot = False
examples = [
    [path.format("koala.png"), "What animal is this?",
     "Koala", "Elephant", "Cat", "Mouse", cot, MEGATINY_MODEL],
    [path.format("townhall.jpg"), "What building is this?",
     "Guggenheim museum", "San mames stadium", "Alhondiga", "Bilbao townhall", cot, TINY_MODEL],
    [path.format("muniain.jpeg"), "What team is IKer Muniain associated?",
     "Real Madrid", "Manchester United", "Athletic Bilbao", "Rayo Vallecano", cot, TINY_MODEL],
    [path.format("lasalve.jpeg"), "What is the name of this bridge?",
     "La Salve", "Zubizuri", "La Ribera", "San Anton", cot, TINY_MODEL],
    [path.format("athl.jpeg"), "Football fans hold flags with what team colors?",
     "Athletic", "Besiktas", "Udinese", "Real Madrid", cot, TINY_MODEL],
]


def generate_text(image, question, option_a, option_b, option_c, option_d,
                  cot_checkbox, model_name):
    """Answer a multiple-choice question about *image* with the chosen model.

    Args:
        image: input image as provided by the Gradio ``Image`` component.
        question: the question text.
        option_a/option_b/option_c/option_d: the four candidate answers.
        cot_checkbox: when True, ask for a chain-of-thought ("[COT]") instead
            of a direct answer ("[QA]").
        model_name: key into ``flamingo_megatiny_captioning_models``.

    Returns:
        The text the model generated after the "[ANSWER]" tag, or "" when the
        tag is missing from the generation.
    """
    model = flamingo_megatiny_captioning_models[model_name]['model']
    processor = FlamingoProcessor(model.config)

    # The task tag selects chain-of-thought vs. plain answer generation.
    task_tag = "[COT]" if cot_checkbox else "[QA]"
    prompt = task_tag + "[CONTEXT][QUESTION]{} [OPTIONS] (A) {} (B) {} (C) {} (D) {} [ANSWER]".format(
        question, option_a, option_b, option_c, option_d)
    print(prompt)

    prediction = model.generate_captions(images=image,
                                         processor=processor,
                                         prompt=prompt)
    # Everything after the [ANSWER] tag is the model's answer. partition()
    # (unlike split()[1]) does not raise when the tag is absent.
    return prediction[0].partition('[ANSWER]')[2]


# UI components. Unified on the modern Gradio component API — the original
# mixed gr.Image with the deprecated gr.inputs.* / gr.outputs.* namespaces,
# which were removed in Gradio 4.x.
image_input = gr.Image(path.format("giraffe.jpeg"))
question_input = gr.Textbox(value="What animal is this?")
opt_a_input = gr.Textbox(value="Dog")
opt_b_input = gr.Textbox(value="Giraffe")
opt_c_input = gr.Textbox(value="Elephant")
opt_d_input = gr.Textbox(value="Cocodrile")
cot_checkbox = gr.Checkbox(label="Generate COT")
select_model = gr.Dropdown(choices=list(flamingo_megatiny_captioning_models.keys()))
text_output = gr.Textbox()

# Create and launch the Gradio interface.
gr.Interface(
    fn=generate_text,
    inputs=[image_input,
            question_input,
            opt_a_input,
            opt_b_input,
            opt_c_input,
            opt_d_input,
            cot_checkbox,
            select_model],
    examples=examples,
    outputs=text_output,
    title='Generate answers from MCQ',
    description='Generate answers from Multiple Choice Questions or generate a Chain Of Though about the question and the options given',
    theme='default',
).launch()