import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    BlipForQuestionAnswering,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
)
class vit_gpt2:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision is only safe on GPU; fall back to float32 on CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    max_length = 16
    num_beams = 4
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

    def __init__(self,
                 model_pretrain: str = "nlpconnect/vit-gpt2-image-captioning",
                 vqa_pretrain: str = "Salesforce/blip-vqa-base"):
        # ViT encoder + GPT-2 decoder for image captioning.
        self.model = VisionEncoderDecoderModel.from_pretrained(
            model_pretrain, torch_dtype=self.dtype
        ).to(self.device)
        self.feature_extractor = ViTImageProcessor.from_pretrained(model_pretrain)
        self.tokenizer = AutoTokenizer.from_pretrained(model_pretrain)
        # BLIP processor/model for visual question answering. The original code
        # referenced self.processor without ever loading it; "Salesforce/blip-vqa-base"
        # is an assumed default checkpoint, not one pinned by this file.
        self.processor = AutoProcessor.from_pretrained(vqa_pretrain)
        self.vqa_model = BlipForQuestionAnswering.from_pretrained(
            vqa_pretrain, torch_dtype=self.dtype
        ).to(self.device)
    def image_captioning(self, image: Image.Image) -> str:
        # Preprocess the image and match the captioning model's device and dtype.
        pixel_values = self.feature_extractor(images=[image], return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self.device, self.dtype)
        output_ids = self.model.generate(pixel_values, **self.gen_kwargs)
        preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return preds[0]
    def visual_question_answering(self, image: Image.Image, prompt: str) -> str:
        # BLIP conditions generation on both the image and the question text.
        inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device, self.dtype)
        generated_ids = self.vqa_model.generate(**inputs)
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        return generated_text
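
# Minimal usage sketch, not part of the original file. Assumptions: a local
# "example.jpg" exists and the default checkpoints above are downloadable.
if __name__ == "__main__":
    model = vit_gpt2()
    image = Image.open("example.jpg").convert("RGB")
    print(model.image_captioning(image))
    print(model.visual_question_answering(image, "What is shown in the image?"))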