| import re | |
| import torch | |
| from torchvision import transforms | |
| from transformers import BlipForConditionalGeneration, BlipProcessor | |
| from internals.util.commons import download_image | |
class Image2Text:
    """Caption an image with the BLIP large image-captioning checkpoint.

    ``load`` must be called before ``process``: it is the only place where
    ``self.processor`` and ``self.model`` are created.
    """

    # Hugging Face checkpoint used for both the processor and the model.
    _MODEL_ID = "Salesforce/blip-image-captioning-large"
    # Device the model weights and the per-call inputs are moved to.
    _DEVICE = "cuda"
    # Text stripped from the decoded caption: any "</x>" tag with exactly one
    # character between "</" and ">", newline characters, and the literal
    # "[SEP]". Raw string so "\[" is a regex escape rather than an invalid
    # Python string escape.
    _CAPTION_NOISE = re.compile(r"</.>|\n|\[SEP\]")

    def load(self):
        """Instantiate the BLIP processor and the float16 model on the GPU."""
        self.processor = BlipProcessor.from_pretrained(self._MODEL_ID)
        self.model = BlipForConditionalGeneration.from_pretrained(
            self._MODEL_ID, torch_dtype=torch.float16
        ).to(self._DEVICE)

    @classmethod
    def _clean_caption(cls, text: str) -> str:
        """Return *text* with every ``_CAPTION_NOISE`` match removed.

        Only the matched text is deleted; surrounding whitespace is left
        untouched.
        """
        return cls._CAPTION_NOISE.sub("", text)

    def process(self, imageUrl: str) -> str:
        """Generate a caption for the image found at *imageUrl*.

        Args:
            imageUrl: Location handed to ``download_image``.

        Returns:
            The first decoded caption after ``_clean_caption`` has removed
            marker tokens and newlines.
        """
        # NOTE(review): download_image presumably returns a PIL image (it is
        # only required to offer ``resize``) - confirm against the helper.
        image = download_image(imageUrl).resize((512, 512))

        # Inputs are moved to the same device and cast to the same dtype as
        # the model weights set up in ``load``.
        inputs = self.processor(image, return_tensors="pt").to(
            self._DEVICE, torch.float16
        )

        # Greedy decoding. ``top_p`` is a sampling-only option, so it is not
        # passed alongside ``do_sample=False``.
        output_ids = self.model.generate(
            **inputs, do_sample=False, max_length=128
        )

        decoded = self.processor.batch_decode(output_ids)
        # Raw decoded captions (markers still present) are echoed to stdout.
        print(decoded)
        return self._clean_caption(decoded[0])