from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_transforms import resize, to_channel_dimension_format
from transformers.image_utils import (
    ChannelDimension,
    PILImageResampling,
    to_numpy_array,
)


class EndpointHandler:
    """Custom inference handler that generates text from a single image.

    The model and processor are loaded once at construction. Each request
    builds an image-placeholder prompt, preprocesses the image with a fixed
    resize/rescale/normalize pipeline, runs ``generate`` and returns the
    decoded text.
    """

    # Placeholder tokens the prompt is built from. They are also banned during
    # generation so the model cannot emit them as output text.
    # NOTE(review): these are the Idefics-style image tokens this handler was
    # written for -- confirm they exist in the deployed checkpoint's vocabulary.
    IMAGE_TOKEN = "<image>"
    FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
    # Side length (pixels) every input image is resized to.
    IMAGE_SIZE = 960
    # ``max_length`` passed to ``generate`` (prompt + generated tokens).
    MAX_LENGTH = 4096

    def __init__(self, model_path: str):
        """Load the processor and model from ``model_path`` onto GPU if available.

        Args:
            model_path: Local path or hub id understood by ``from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        ).to(self.device)

        # Number of image-placeholder tokens the prompt must contain.
        self.image_seq_len = self.model.config.perceiver_config.resampler_n_latents
        self.bos_token = self.processor.tokenizer.bos_token
        # One token-id sequence per placeholder token; passed to ``generate``
        # so these tokens are never produced in the output.
        self.bad_words_ids = self.processor.tokenizer(
            [self.IMAGE_TOKEN, self.FAKE_IMAGE_TOKEN], add_special_tokens=False
        ).input_ids

    def convert_to_rgb(self, image: Image.Image) -> Image.Image:
        """Return ``image`` as RGB, flattening any transparency onto white.

        RGB images are returned unchanged. Any other mode is converted to RGBA
        and alpha-composited over an opaque white background before dropping
        the alpha channel.
        """
        if image.mode == "RGB":
            return image

        image_rgba = image.convert("RGBA")
        # Fully opaque white, so transparent pixels become white rather than
        # whatever colour values they happen to carry.
        background = Image.new("RGBA", image_rgba.size, (255, 255, 255, 255))
        alpha_composite = Image.alpha_composite(background, image_rgba)
        return alpha_composite.convert("RGB")

    def custom_transform(self, image: Image.Image) -> torch.Tensor:
        """Preprocess one image into a channels-first tensor for the model.

        Steps: convert to RGB, resize to ``IMAGE_SIZE`` x ``IMAGE_SIZE``
        (bilinear), rescale by 1/255, normalize with the processor's
        mean/std, then move channels first.
        """
        image = self.convert_to_rgb(image)
        image = to_numpy_array(image)
        image = resize(
            image,
            (self.IMAGE_SIZE, self.IMAGE_SIZE),
            resample=PILImageResampling.BILINEAR,
        )
        image_processor = self.processor.image_processor
        image = image_processor.rescale(image, scale=1 / 255)
        image = image_processor.normalize(
            image,
            mean=image_processor.image_mean,
            std=image_processor.image_std,
        )
        image = to_channel_dimension_format(image, ChannelDimension.FIRST)
        return torch.tensor(image)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run generation for the image in ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` is either a PIL image or
                a filesystem path (str) that ``Image.open`` can read.

        Returns:
            A one-element list ``[{"label": <generated text>, "score": 1.0}]``.

        Raises:
            ValueError: If the payload has no ``"inputs"`` entry.
        """
        image = data.get("inputs")
        if image is None:
            raise ValueError("Request payload must contain an 'inputs' image.")
        if isinstance(image, str):
            image = Image.open(image)

        # BOS, then ``image_seq_len`` image tokens wrapped in fake-image tokens.
        # These placeholders are where the model attends to the pixel values.
        prompt = (
            f"{self.bos_token}{self.FAKE_IMAGE_TOKEN}"
            f"{self.IMAGE_TOKEN * self.image_seq_len}{self.FAKE_IMAGE_TOKEN}"
        )
        inputs = self.processor.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )
        inputs["pixel_values"] = self.processor.image_processor(
            [image], transform=self.custom_transform
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        generated_ids = self.model.generate(
            **inputs,
            bad_words_ids=self.bad_words_ids,
            max_length=self.MAX_LENGTH,
        )
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]

        # A single label/score record holding the whole decoded text (not one
        # record per character of it).
        return [{"label": generated_text, "score": 1.0}]