# Uploaded by Clyine1 ("Upload 2 files"), commit c1400fc (verified).
import torch
from torch import nn
from transformers import LlavaForConditionalGeneration, LlavaConfig
import re
from PIL import Image
from random import randint
class VQApair(LlavaForConditionalGeneration):
    """LLaVA model that turns an image into multiple-choice question items.

    Prompts are written with ``<|user|>`` / ``<|end|>`` / ``<|assistant|>``
    chat tags, so the checkpoint presumably uses a Phi-3-style chat
    template -- TODO confirm against the processor's tokenizer.
    """

    config_class = LlavaConfig

    # Token id generation stops on. NOTE(review): presumably ``<|end|>`` in
    # the tokenizer vocabulary -- confirm against the processor.
    END_TOKEN_ID = 32007
    # Number of diverse-beam-search groups used when proposing questions.
    QUESTION_BEAM_GROUPS = 3
    # Leading list marker on a generated line: "1.", "12)", "-", "*" or a bullet.
    _LIST_MARKER = re.compile(r"^\s*(?:\d+[.)]|[-*\u2022])\s*")

    def __init__(self, config, **kwargs):
        """Build the model from *config*.

        Args:
            config: ``LlavaConfig`` for the underlying LLaVA model.
            proc: (required keyword) processor used to encode prompt/image
                pairs and decode generated token ids.

        Raises:
            KeyError: if ``proc`` is not supplied.
        """
        super().__init__(config)
        self.processor = kwargs.pop("proc")

    def _generate_texts(self, prompt, img_obj, **gen_kwargs):
        """Generate from (*prompt*, *img_obj*) and return the new text(s).

        Returns one stripped string per returned sequence. Only the tokens
        produced after the prompt are decoded, so the result does not depend
        on locating a particular special-token id inside the output.
        Extra keyword arguments are forwarded to ``generate``.
        """
        # Keywords instead of positional args: the processor's positional
        # order (text/images vs images/text) differs between versions.
        inputs = self.processor(
            text=prompt, images=img_obj, return_tensors="pt"
        ).to(self.device)
        output = self.generate(
            **inputs,
            eos_token_id=self.END_TOKEN_ID,
            max_new_tokens=500,
            **gen_kwargs,
        )
        # generate() returns prompt ids followed by the new tokens.
        prompt_len = inputs["input_ids"].shape[1]
        decoded = self.processor.batch_decode(
            output[:, prompt_len:], skip_special_tokens=True
        )
        return [text.strip() for text in decoded]

    @classmethod
    def _parse_choices(cls, text):
        """Split a generated list into bare answers (markers and blank lines removed)."""
        items = (
            cls._LIST_MARKER.sub("", line, count=1).strip()
            for line in text.splitlines()
        )
        return [item for item in items if item]

    def genChoice(self, question, base_prompt, img_obj):
        """Build a shuffled multiple-choice item for *question*.

        The model is asked for one correct answer, then (in the same
        conversation) for 3 incorrect answers; the correct answer is inserted
        at a random position among whatever distractors were parsed.

        Args:
            question: question text to answer.
            base_prompt: conversation so far, ending with an open assistant
                turn into which *question* is appended.
            img_obj: image passed to the processor alongside the prompt.

        Returns:
            ``{"Choices": ["1) ...", "2) ...", ...],
            "Answers": "Correct Answer: k) ..."}``.
        """
        base_prompt += "{}<|end|>\n<|user|> Suggest 1 correct answer<|end|><|assistant|> ".format(question)
        answer = self._generate_texts(base_prompt, img_obj)[0]

        base_prompt += "{}<|end|>\n<|user|> Suggest 3 incorrect answers<|end|><|assistant|> ".format(answer)
        distractors = self._parse_choices(self._generate_texts(base_prompt, img_obj)[0])

        # randint is inclusive on both ends, so every insertion slot
        # 0..len(distractors) is a possible position for the correct answer.
        correct_index = randint(0, len(distractors))
        options = distractors[:correct_index] + [answer] + distractors[correct_index:]
        choices = ["{}) {}".format(i, text) for i, text in enumerate(options, start=1)]
        return {
            "Choices": choices,
            "Answers": "Correct Answer: {}".format(choices[correct_index]),
        }

    def generateQn(self, img_path, n):
        """Generate up to *n* multiple-choice questions about an image.

        The model first describes the image, then proposes *n* diverse
        questions via diverse beam search; each question is turned into a
        multiple-choice item by ``genChoice``.

        Args:
            img_path: path (or file object) accepted by ``PIL.Image.open``.
            n: number of questions to return; ``n < 1`` returns ``[]``.

        Returns:
            List of dicts with keys ``"desc"``, ``"question"``,
            ``"Choices"`` and ``"Answers"``.
        """
        if n < 1:
            return []

        prompt = "\n<|user|>\n<image>\nDescribe this image in a passage<|end|><|assistant|>\n"

        # convert() fully loads pixels into a new image, so the underlying
        # file can be closed immediately instead of leaking the handle.
        with Image.open(img_path) as img:
            img_obj = img.convert("RGB")

        desc = self._generate_texts(prompt, img_obj)[0]

        prompt += "{}<|end|>\n<|user|> {}<|end|><|assistant|> ".format(desc, "Generate a simple question")
        groups = self.QUESTION_BEAM_GROUPS
        # Diverse beam search requires num_beams % num_beam_groups == 0 and
        # num_return_sequences <= num_beams: round n up to a multiple of groups.
        num_beams = -(-n // groups) * groups
        questions = self._generate_texts(
            prompt,
            img_obj,
            do_sample=False,
            num_beams=num_beams,
            num_beam_groups=groups,
            diversity_penalty=10.0,
            num_return_sequences=n,
        )

        artifacts = []
        for question in questions:
            entry = {"desc": desc, "question": question}
            entry.update(self.genChoice(question, prompt, img_obj))
            artifacts.append(entry)
        return artifacts