"""Image-embedding extraction using the ViT encoder of the manga-ocr model."""

from PIL import Image
import torch
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,  # image pre-processing (resize / normalize)
    ViTModel,  # type of the encoder pulled out of the OCR model below
)

MODEL = "kha-white/manga-ocr-base"

print("Loading models")
# NOTE(review): the original passed requires_grad=False here, but
# ViTImageProcessor.from_pretrained has no such parameter — the kwarg was
# silently stored as an unused attribute. Dropped; gradient tracking is
# already disabled by torch.inference_mode() in get_embeddings.
feature_extractor: ViTImageProcessor = ViTImageProcessor.from_pretrained(MODEL)
# Only the vision encoder is needed; the OCR text decoder is discarded.
encoder: ViTModel = VisionEncoderDecoderModel.from_pretrained(MODEL).encoder
encoder.eval()  # defensive: ensure eval mode (no dropout) for inference

if torch.cuda.is_available():
    print("Using CUDA")
    encoder.cuda()
else:
    print("Using CPU")


def get_embeddings(images: list[Image.Image]) -> torch.Tensor:
    """Process *images* and return their embeddings.

    Args:
        images: PIL images in any mode; each is converted to RGB first.

    Returns:
        The encoder's pooler output moved to the CPU — presumably of shape
        (len(images), hidden_size); confirm against the ViT config.
    """
    images_rgb = [image.convert("RGB") for image in images]
    with torch.inference_mode():
        pixel_values: torch.Tensor = feature_extractor(
            images_rgb, return_tensors="pt"
        )["pixel_values"]
        # Run on whichever device the encoder lives on, return on CPU.
        return encoder(pixel_values.to(encoder.device))["pooler_output"].cpu()