# AutoConverter / main.py
# Author: yuhuizhang
# Uploaded via huggingface_hub (revision 49e4fbb, verified)
import base64
import io
import random
from textwrap import dedent
import gradio as gr
from openai import OpenAI
from PIL import Image
from pydantic import BaseModel
from prompts import (
concept_generation_system_prompt,
data_processing_generation_system_prompt,
evaluator_system_prompt,
fusion_generation_system_prompt,
question_bias_generation_system_prompt,
reasoning_generation_system_prompt,
refine_system_prompt_concept,
refine_system_prompt_data,
refine_system_prompt_question_bias,
refine_system_prompt_reason,
refine_system_prompt_visual,
refiner_system_prompt,
review_system_prompt,
visual_interpretation_generation_system_prompt,
)
class Distractor(BaseModel):
    """One candidate wrong answer plus the model's rationale for it."""

    text: str  # the distractor option text shown to the test taker
    reason: str  # explanation of why this distractor is plausible/tempting
class Distractors(BaseModel):
    """Structured-output wrapper: a batch of generated distractors."""

    distractors: list[Distractor]
class Comment(BaseModel):
    """A single reviewer remark tied to one answer option."""

    option: str  # presumably the option text being reviewed — confirm against review prompt
    comment: str  # reviewer feedback on that option
class CommentFormat(BaseModel):
    """Structured-output wrapper: the reviewer's list of comments."""

    comments: list[Comment]
class Judgement(BaseModel):
    """Evaluator verdict on a finished multiple-choice question."""

    reasoning: str  # the evaluator's chain of reasoning
    correctness: int  # NOTE(review): appears to be a numeric correctness score — confirm range in evaluator prompt
    improvement: str  # suggested fixes, fed to the refiner stage
class Question(BaseModel):
    """Refiner output: reasoning plus the final distractor texts."""

    reasoning: str  # rationale for the refined option set
    distractors: list[str]  # plain option strings (no per-item reasons here)
def base64_to_image(base64_str):
    """Decode a base64 string of raw image bytes into a PIL Image.

    Args:
        base64_str: Base64-encoded image payload (e.g. PNG bytes).

    Returns:
        A ``PIL.Image.Image`` opened from the decoded bytes.
    """
    raw_bytes = base64.b64decode(base64_str)
    return Image.open(io.BytesIO(raw_bytes))
def get_reply(client, system_prompt, user_prompt, image_base64, output_format):
    """Send one vision chat request and parse the reply into `output_format`.

    Both prompts are dedented before sending, so callers may use indented
    triple-quoted templates freely.

    Args:
        client: OpenAI client instance.
        system_prompt: System prompt text (will be dedented).
        user_prompt: User prompt text (will be dedented).
        image_base64: Base64-encoded PNG attached to the user turn.
        output_format: Pydantic model class used for structured parsing.

    Returns:
        The parsed structured response as a plain dict.
    """
    system_message = {"role": "system", "content": dedent(system_prompt)}
    user_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": dedent(user_prompt)},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{image_base64}"},
            },
        ],
    }
    completion = client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=[system_message, user_message],
        response_format=output_format,
        # temperature=0, # Set to 0 for deterministic responses
    )
    return completion.choices[0].message.parsed.dict()
def convert_to_multi_choice(client, question, answer, image_base64, reviewer):
    """Generate fused multiple-choice distractors for an image question.

    Five specialized generators each propose distractors targeting a
    different error type (conceptual, reasoning, visual interpretation,
    data processing, question bias). When `reviewer` is True, every set is
    critiqued by a reviewer pass and then regenerated with the matching
    refine prompt. Finally a fusion pass selects the best distractors from
    the pooled candidates.

    Args:
        client: OpenAI client used for all chat completions.
        question: Open-ended question text.
        answer: Correct answer text.
        image_base64: Base64-encoded PNG the question refers to.
        reviewer: Whether to run the review-and-refine stage.

    Returns:
        List of distractor dicts with "text" and "reason" keys.
    """
    # One entry per error type: (generation prompt, refine prompt, review label).
    # This replaces five near-identical copy-pasted stanzas in the original.
    strategies = [
        (concept_generation_system_prompt, refine_system_prompt_concept, "conceptual"),
        (reasoning_generation_system_prompt, refine_system_prompt_reason, "reasoning"),
        (
            visual_interpretation_generation_system_prompt,
            refine_system_prompt_visual,
            "visual interpretation",
        ),
        (
            data_processing_generation_system_prompt,
            refine_system_prompt_data,
            "data processing",
        ),
        (
            question_bias_generation_system_prompt,
            refine_system_prompt_question_bias,
            "question bias",
        ),
    ]

    generation_prompt = f"""
    Question: {question}
    Correct Answer: {answer}
    """
    # Stage 1: independent generation per error type (original call order kept).
    candidate_sets = [
        get_reply(client, gen_prompt, generation_prompt, image_base64, Distractors)[
            "distractors"
        ]
        for gen_prompt, _refine_prompt, _label in strategies
    ]

    if reviewer:
        review_template = """
        Question: {question}
        Correct Answer: {answer}
        Distractions and Reasonings: {distractors}
        """
        # Stage 2a: review every candidate set first (all reviews before any
        # refines, matching the original's sequencing of API calls).
        review_sets = [
            get_reply(
                client,
                review_system_prompt.format(type=label),
                review_template.format(
                    question=question, answer=answer, distractors=candidates
                ),
                image_base64,
                CommentFormat,
            )["comments"]
            for candidates, (_gen, _refine, label) in zip(candidate_sets, strategies)
        ]

        refine_template = """
        Question: {question}
        Correct Answer: {answer}
        Distractions and Reviewer Comments: {reviews}
        """
        # Stage 2b: regenerate each set from its reviewer comments.
        candidate_sets = [
            get_reply(
                client,
                refine_prompt,
                refine_template.format(
                    question=question, answer=answer, reviews=reviews
                ),
                image_base64,
                Distractors,
            )["distractors"]
            for reviews, (_gen, refine_prompt, _label) in zip(review_sets, strategies)
        ]

    # Stage 3: pool all candidates and let the fusion prompt pick the final set.
    pooled = [item for candidates in candidate_sets for item in candidates]
    fusion_prompt = f"""
    Question: {question}
    Correct Answer: {answer}
    All Distractors: {pooled}
    """
    return get_reply(
        client,
        fusion_generation_system_prompt,
        fusion_prompt,
        image_base64,
        Distractors,
    )["distractors"]
def judge_multichoice_correctness_with_image(
    client, question, choices, answer, image_base64
):
    """Ask the evaluator model to judge a finished multiple-choice question.

    Args:
        client: OpenAI client instance.
        question: The question text.
        choices: The full list of answer options.
        answer: The intended correct answer.
        image_base64: Base64-encoded PNG the question refers to.

    Returns:
        Parsed Judgement dict: "reasoning", "correctness", "improvement".
    """
    prompt = f"""
    Question: {question}
    Choices: {choices}
    Correct Answer: {answer}
    """
    return get_reply(
        client,
        evaluator_system_prompt,
        prompt,
        image_base64,
        Judgement,
    )
def improve_multichoice_correctness_with_image(
    client,
    question,
    choices,
    answer,
    issue,
    improvement,
    image_base64,
):
    """Regenerate a question's options based on the evaluator's feedback.

    Args:
        client: OpenAI client instance.
        question: The question text.
        choices: Current answer options.
        answer: The intended correct answer.
        issue: Evaluator's identified problems (its "reasoning" field).
        improvement: Evaluator's suggested fixes.
        image_base64: Base64-encoded PNG the question refers to.

    Returns:
        Parsed Question dict: "reasoning" and refined "distractors".
    """
    prompt = f"""
    Question: {question}
    Choices: {choices}
    Correct Answer: {answer}
    Identified Issues: {issue}
    Suggested Improvements: {improvement}
    """
    return get_reply(
        client,
        refiner_system_prompt,
        prompt,
        image_base64,
        Question,
    )
def process_one_question(api_key, image, question, answer, components):
    """Gradio callback: convert an open-ended VQA pair into multiple choice.

    Args:
        api_key: OpenAI API key entered by the user.
        image: Uploaded image as an array (Gradio's default image format).
        question: Open-ended question text.
        answer: Correct answer text.
        components: Selected pipeline stages; may contain "Reviewer" and/or
            "Refiner".

    Returns:
        A formatted string with the question, lettered options, and the
        letter of the correct answer.
    """
    reviewer = "Reviewer" in components
    refiner = "Refiner" in components

    # Re-encode the uploaded array as PNG base64 for the vision API.
    pil_image = Image.fromarray(image)
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    buffer.seek(0)
    image_base64 = base64.b64encode(buffer.read()).decode("utf-8")

    random.seed(1234)  # fixed seed so option order is reproducible per run
    client = OpenAI(api_key=api_key)

    # Fixed typo: was `distactors`.
    distractors = convert_to_multi_choice(
        client, question, answer, image_base64, reviewer
    )
    choices = [item["text"] for item in distractors] + [answer]
    random.shuffle(choices)

    if refiner:
        judgement = judge_multichoice_correctness_with_image(
            client, question, choices, answer, image_base64
        )
        refined = improve_multichoice_correctness_with_image(
            client,
            question,
            choices,
            answer,
            judgement["reasoning"],
            judgement["improvement"],
            image_base64,
        )
        choices = refined["distractors"] + [answer]
        random.shuffle(choices)

    # Letter every option (A, B, C, ...) instead of assuming exactly four:
    # the original indexed choices[0..3] and would raise IndexError (or drop
    # options) whenever the fusion/refine pass returned != 3 distractors.
    option_lines = "\n".join(
        f"{chr(ord('A') + i)}. {choice}" for i, choice in enumerate(choices)
    )
    answer_letter = chr(ord("A") + choices.index(answer))
    return f"Question: {question}\n\n{option_lines}\n\nAnswer: {answer_letter}"
def main_gradio():
    """Build and launch the Gradio UI for AutoConverter.

    Wires `process_one_question` to five inputs (API key, image, question,
    answer, component checkboxes) and a single text output, then starts the
    app server.
    """
    demo = gr.Interface(
        fn=process_one_question,
        inputs=[
            gr.Textbox(label="OpenAI API Key"),
            gr.Image(label="Upload an Image"),
            gr.Textbox(label="Question"),
            gr.Textbox(label="Answer"),
            gr.CheckboxGroup(["Reviewer", "Refiner"], label="Components"),
        ],
        outputs=gr.Textbox(label="Output"),
        title="AutoConverter: Automated Generation of Challenging Multiple-Choice Questions for Vision Language Model Evaluation",
    )
    demo.launch()
# Launch the Gradio app when executed as a script (e.g. on a HF Space).
if __name__ == "__main__":
    main_gradio()