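"""Gradio demo for AutoConverter: converts an open-ended VQA example (image,
question, answer) into a challenging multiple-choice question with GPT-4o."""
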
import base64
import io
import random
from textwrap import dedent

import gradio as gr
from openai import OpenAI
from PIL import Image
from pydantic import BaseModel

from prompts import (
    concept_generation_system_prompt,
    data_processing_generation_system_prompt,
    evaluator_system_prompt,
    fusion_generation_system_prompt,
    question_bias_generation_system_prompt,
    reasoning_generation_system_prompt,
    refine_system_prompt_concept,
    refine_system_prompt_data,
    refine_system_prompt_question_bias,
    refine_system_prompt_reason,
    refine_system_prompt_visual,
    refiner_system_prompt,
    review_system_prompt,
    visual_interpretation_generation_system_prompt,
)
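

# Pydantic schemas passed as `response_format` so GPT-4o returns structured,
# machine-parseable output at each pipeline stage.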
class Distractor(BaseModel):
    text: str
    reason: str


class Distractors(BaseModel):
    distractors: list[Distractor]


class Comment(BaseModel):
    option: str
    comment: str


class CommentFormat(BaseModel):
    comments: list[Comment]


class Judgement(BaseModel):
    reasoning: str
    correctness: int
    improvement: str


class Question(BaseModel):
    reasoning: str
    distractors: list[str]


def base64_to_image(base64_str):
    """Decode a base64-encoded image string into a PIL image."""
    image_data = base64.b64decode(base64_str)
    image = Image.open(io.BytesIO(image_data))
    return image


def get_reply(client, system_prompt, user_prompt, image_base64, output_format):
    """Send a system prompt, user prompt, and image to GPT-4o and parse the
    response into the given Pydantic `output_format`."""
    completion = client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": dedent(system_prompt)},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": dedent(user_prompt)},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                    },
                ],
            },
        ],
        response_format=output_format,
        # temperature=0,  # Set to 0 for deterministic responses
    )
    parsed_output = completion.choices[0].message.parsed.dict()
    return parsed_output
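

# Distractor generation pipeline: five category-specific generators (concept,
# reasoning, visual interpretation, data processing, question bias), an
# optional review-and-refine round per category, and a final fusion step that
# merges the candidates into the distractor set used for the question.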
def convert_to_multi_choice(client, question, answer, image_base64, reviewer):
    """Generate distractors for an open-ended question/answer pair, optionally
    run a review-and-refine round, then fuse them into the final set."""
    user_prompt = f"""
    Question: {question}
    Correct Answer: {answer}
    """
    # Generate candidate distractors targeting five different error categories.
    distractors_concept = get_reply(
        client, concept_generation_system_prompt, user_prompt, image_base64, Distractors
    )["distractors"]
    distractors_reasoning = get_reply(
        client,
        reasoning_generation_system_prompt,
        user_prompt,
        image_base64,
        Distractors,
    )["distractors"]
    distractors_visual_interpretation = get_reply(
        client,
        visual_interpretation_generation_system_prompt,
        user_prompt,
        image_base64,
        Distractors,
    )["distractors"]
    distractors_data_processing = get_reply(
        client,
        data_processing_generation_system_prompt,
        user_prompt,
        image_base64,
        Distractors,
    )["distractors"]
    distractors_question_bias = get_reply(
        client,
        question_bias_generation_system_prompt,
        user_prompt,
        image_base64,
        Distractors,
    )["distractors"]
    # print(distractors_concept)
    if reviewer:
        # Ask a category-specific reviewer to comment on each distractor.
        user_prompt = """
        Question: {question}
        Correct Answer: {answer}
        Distractors and Reasonings: {distractors}
        """
        reviews_concept = get_reply(
            client,
            review_system_prompt.format(type="conceptual"),
            user_prompt.format(
                question=question, answer=answer, distractors=distractors_concept
            ),
            image_base64,
            CommentFormat,
        )["comments"]
        reviews_reasoning = get_reply(
            client,
            review_system_prompt.format(type="reasoning"),
            user_prompt.format(
                question=question, answer=answer, distractors=distractors_reasoning
            ),
            image_base64,
            CommentFormat,
        )["comments"]
        reviews_visual_interpretation = get_reply(
            client,
            review_system_prompt.format(type="visual interpretation"),
            user_prompt.format(
                question=question,
                answer=answer,
                distractors=distractors_visual_interpretation,
            ),
            image_base64,
            CommentFormat,
        )["comments"]
        reviews_data_processing = get_reply(
            client,
            review_system_prompt.format(type="data processing"),
            user_prompt.format(
                question=question,
                answer=answer,
                distractors=distractors_data_processing,
            ),
            image_base64,
            CommentFormat,
        )["comments"]
        reviews_question_bias = get_reply(
            client,
            review_system_prompt.format(type="question bias"),
            user_prompt.format(
                question=question, answer=answer, distractors=distractors_question_bias
            ),
            image_base64,
            CommentFormat,
        )["comments"]
        # print(reviews_concept)
user_prompt = """ | |
Question: {question} | |
Correct Answer: {answer} | |
Distractions and Reviewer Comments: {reviews} | |
""" | |
distractors_concept = get_reply( | |
client, | |
refine_system_prompt_concept, | |
user_prompt.format( | |
question=question, answer=answer, reviews=reviews_concept | |
), | |
image_base64, | |
Distractors, | |
)["distractors"] | |
distractors_reasoning = get_reply( | |
client, | |
refine_system_prompt_reason, | |
user_prompt.format( | |
question=question, answer=answer, reviews=reviews_reasoning | |
), | |
image_base64, | |
Distractors, | |
)["distractors"] | |
distractors_visual_interpretation = get_reply( | |
client, | |
refine_system_prompt_visual, | |
user_prompt.format( | |
question=question, answer=answer, reviews=reviews_visual_interpretation | |
), | |
image_base64, | |
Distractors, | |
)["distractors"] | |
distractors_data_processing = get_reply( | |
client, | |
refine_system_prompt_data, | |
user_prompt.format( | |
question=question, answer=answer, reviews=reviews_data_processing | |
), | |
image_base64, | |
Distractors, | |
)["distractors"] | |
distractors_question_bias = get_reply( | |
client, | |
refine_system_prompt_question_bias, | |
user_prompt.format( | |
question=question, answer=answer, reviews=reviews_question_bias | |
), | |
image_base64, | |
Distractors, | |
)["distractors"] | |
# print(distractors_concept) | |
    # Pool all candidate distractors and fuse them into the final set.
    distractors = (
        distractors_concept
        + distractors_reasoning
        + distractors_visual_interpretation
        + distractors_data_processing
        + distractors_question_bias
    )
    user_prompt = f"""
    Question: {question}
    Correct Answer: {answer}
    All Distractors: {distractors}
    """
    distractors = get_reply(
        client, fusion_generation_system_prompt, user_prompt, image_base64, Distractors
    )["distractors"]
    return distractors
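

# Optional refiner loop: an evaluator judges the assembled question, and a
# refiner rewrites the distractors based on its feedback.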
def judge_multichoice_correctness_with_image(
    client, question, choices, answer, image_base64
):
    """Ask the evaluator to judge the assembled multiple-choice question and
    suggest improvements."""
    user_prompt = f"""
    Question: {question}
    Choices: {choices}
    Correct Answer: {answer}
    """
    response = get_reply(
        client,
        evaluator_system_prompt,
        user_prompt,
        image_base64,
        Judgement,
    )
    return response


def improve_multichoice_correctness_with_image(
    client,
    question,
    choices,
    answer,
    issue,
    improvement,
    image_base64,
):
    """Rewrite the distractors based on the evaluator's identified issues and
    suggested improvements."""
    user_prompt = f"""
    Question: {question}
    Choices: {choices}
    Correct Answer: {answer}
    Identified Issues: {issue}
    Suggested Improvements: {improvement}
    """
    response = get_reply(
        client,
        refiner_system_prompt,
        user_prompt,
        image_base64,
        Question,
    )
    return response
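

# Gradio callback: encodes the uploaded image as base64 PNG, generates and
# shuffles the answer choices, optionally runs the refiner loop, and formats
# the result as an A-D multiple-choice question.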
def process_one_question(api_key, image, question, answer, components):
    reviewer = "Reviewer" in components
    refiner = "Refiner" in components

    # Encode the uploaded image (a numpy array from Gradio) as base64 PNG.
    pil_image = Image.fromarray(image)
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    buffer.seek(0)
    image_base64 = base64.b64encode(buffer.read()).decode("utf-8")

    random.seed(1234)
    client = OpenAI(api_key=api_key)

    distractors = convert_to_multi_choice(
        client, question, answer, image_base64, reviewer
    )
    choices = [item["text"] for item in distractors] + [answer]
    random.shuffle(choices)

    if refiner:
        # Judge the current choices, then rewrite the distractors accordingly.
        judgement = judge_multichoice_correctness_with_image(
            client, question, choices, answer, image_base64
        )
        distractors = improve_multichoice_correctness_with_image(
            client,
            question,
            choices,
            answer,
            judgement["reasoning"],
            judgement["improvement"],
            image_base64,
        )
        choices = distractors["distractors"] + [answer]
        random.shuffle(choices)

    # Format the final question, assuming three distractors plus the answer.
    output = (
        f"Question: {question}\n\n"
        f"A. {choices[0]}\nB. {choices[1]}\nC. {choices[2]}\nD. {choices[3]}\n\n"
        f"Answer: {'ABCD'[choices.index(answer)]}"
    )
    return output
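

# Running locally (a minimal sketch; assumes this file is saved as `app.py`
# next to the `prompts.py` module that defines the imported system prompts):
#   pip install gradio openai pillow pydantic
#   python app.py
# The OpenAI API key is entered through the UI rather than read from the
# environment.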
def main_gradio():
    interface = gr.Interface(
        fn=process_one_question,
        inputs=[
            gr.Textbox(label="OpenAI API Key"),
            gr.Image(label="Upload an Image"),
            gr.Textbox(label="Question"),
            gr.Textbox(label="Answer"),
            gr.CheckboxGroup(["Reviewer", "Refiner"], label="Components"),
        ],
        outputs=gr.Textbox(label="Output"),
        title="AutoConverter: Automated Generation of Challenging Multiple-Choice Questions for Vision Language Model Evaluation",
    )
    interface.launch()


if __name__ == "__main__":
    main_gradio()