import re
from random import randint

import torch
from PIL import Image
from transformers import LlavaConfig, LlavaForConditionalGeneration
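
# Note: the special-token ids hard-coded below (32001 = <|assistant|>,
# 32007 = <|end|>) match the Phi-3 tokenizer and are an assumption about the
# checkpoint in use. A more robust variant resolves them at runtime, e.g.:
#
#     end_id = processor.tokenizer.convert_tokens_to_ids("<|end|>")
#     assistant_id = processor.tokenizer.convert_tokens_to_ids("<|assistant|>")
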

class VQApair(LlavaForConditionalGeneration):
    """LLaVA wrapper that generates multiple-choice VQA pairs from an image."""

    config_class = LlavaConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        # The processor is injected at load time, e.g.
        # VQApair.from_pretrained(..., proc=processor), so the model can
        # tokenize its own prompts.
        self.processor = kwargs.pop("proc")

    def genChoice(self, question, base_prompt, img_obj):
        """Generate one correct answer and three distractors for `question`.

        Returns a dict like
        {"Choices": ["1) ...", ...], "Answers": "Correct Answer: 2) ..."}.
        """
        # Ask the model to answer its own question.
        base_prompt += "{}<|end|>\n<|user|> Suggest 1 correct answer<|end|><|assistant|> ".format(question)
        inputs = self.processor(text=base_prompt, images=img_obj, return_tensors="pt").to(self.device)
        # 32007 = <|end|> (assumed Phi-3 vocabulary); stop generation there.
        output = self.generate(**inputs, eos_token_id=32007, max_new_tokens=500)
        # Keep only the text after the last <|assistant|> tag (assumed id 32001).
        index = torch.where(output[0] == 32001)[0][-1].item()
        answer = self.processor.decode(output[0][index:], skip_special_tokens=True).strip()

        # Ask for three distractors to go with the correct answer.
        base_prompt += "{}<|end|>\n<|user|> Suggest 3 incorrect answers<|end|><|assistant|> ".format(answer)
        inputs = self.processor(text=base_prompt, images=img_obj, return_tensors="pt").to(self.device)
        output = self.generate(**inputs, eos_token_id=32007, max_new_tokens=500)
        index = torch.where(output[0] == 32001)[0][-1].item()
        choices = self.processor.decode(output[0][index:], skip_special_tokens=True)

        # The distractors come back as a numbered list; strip the "1)" / "1."
        # prefixes and drop empty lines.
        a = [re.sub(r"^\s*\d+[.):]?\s*", "", x).strip() for x in choices.split("\n")]
        a = [x for x in a if x]

        # randint is inclusive on both ends, so the correct answer can land in
        # any slot, including after the last distractor.
        correct_answer = randint(0, len(a))
        a.insert(correct_answer, answer)

        a = ["{}) {}".format(i + 1, a[i]) for i in range(len(a))]
        ans = "Correct Answer: {}".format(a[correct_answer])
        return {"Choices": a, "Answers": ans}
    def generateQn(self, img_path, n):
        """Generate `n` multiple-choice question entries for the image at `img_path`."""
        # Step 1: caption the image; the description grounds the questions.
        prompt = "<|user|>\n<image>\nDescribe this image in a passage<|end|><|assistant|> "
        artifacts = []
        img_obj = Image.open(img_path).convert("RGB")  # ensure a 3-channel image

        inputs = self.processor(text=prompt, images=img_obj, return_tensors="pt").to(self.device)
        output = self.generate(**inputs, eos_token_id=32007, max_new_tokens=500)
        index = torch.where(output[0] == 32001)[0][-1].item()
        desc = self.processor.decode(output[0][index:], skip_special_tokens=True)

        # Step 2: generate n distinct questions with diverse beam search.
        # One beam per group keeps num_return_sequences <= num_beams for any n,
        # and the diversity penalty pushes the groups apart.
        prompt += "{}<|end|>\n<|user|> Generate a simple question<|end|><|assistant|> ".format(desc)
        inputs = self.processor(text=prompt, images=img_obj, return_tensors="pt").to(self.device)
        output = self.generate(
            **inputs,
            eos_token_id=32007,
            max_new_tokens=500,
            do_sample=False,
            num_beams=n,
            num_beam_groups=n,
            diversity_penalty=10.0,
            num_return_sequences=n,
        )

        # Step 3: build one multiple-choice entry per generated question.
        for out in output:
            index = torch.where(out == 32001)[0][-1].item()
            text = self.processor.decode(out[index:], skip_special_tokens=True)
            entry = {"desc": desc, "question": text}
            entry.update(self.genChoice(text, prompt, img_obj))
            artifacts.append(entry)

        return artifacts
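
# A minimal usage sketch, assuming a LLaVA checkpoint whose chat template uses
# the Phi-3 style <|user|>/<|assistant|>/<|end|> tokens hard-coded above
# (e.g. xtuner/llava-phi-3-mini-hf); the image path and n are placeholders.
if __name__ == "__main__":
    from transformers import AutoProcessor

    model_id = "xtuner/llava-phi-3-mini-hf"  # assumed checkpoint
    processor = AutoProcessor.from_pretrained(model_id)
    # Leftover from_pretrained kwargs are forwarded to __init__, which is how
    # `proc` reaches VQApair above.
    model = VQApair.from_pretrained(model_id, proc=processor).to(0)  # assumes CUDA device 0

    for item in model.generateQn("sample.jpg", n=3):  # "sample.jpg" is a placeholder
        print(item["question"])
        print("\n".join(item["Choices"]))
        print(item["Answers"])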