"""FastAPI service that runs a vision-language model on an uploaded image.

POST an image file to ``/``. The response streams one newline-terminated JSON
object: ``{"label": <generated text>, "score": 1.0}`` on success or
``{"error": <message>}`` if generation fails.
"""
import asyncio
import io
import json
import logging
import os
import threading
from typing import Any, AsyncIterator, Dict

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_transforms import resize, to_channel_dimension_format
from transformers.image_utils import ChannelDimension, PILImageResampling, to_numpy_array

logger = logging.getLogger(__name__)

app = FastAPI()

# Placeholder tokens of the Idefics-style prompt format used with this model.
# NOTE(review): these were empty strings in the previous revision (most likely
# stripped as HTML-looking tags), which left the prompt without any image
# tokens; confirm both strings exist in the loaded tokenizer's vocabulary.
IMAGE_TOKEN = "<image>"
FAKE_IMAGE_TOKEN = "<fake_token_around_image>"


class EndpointHandler:
    """Hold a loaded processor/model pair and generate text for images."""

    def __init__(self, model_path: str, max_length: int = 2048) -> None:
        """Load the processor and model from *model_path*.

        Args:
            model_path: Path or hub id passed to ``from_pretrained``.
            max_length: ``max_length`` forwarded to ``model.generate``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
        ).to(self.device)
        self.max_length = max_length
        # Number of image placeholder tokens placed in the prompt; presumably the
        # number of latents the perceiver resampler produces per image.
        self.image_seq_len = self.model.config.perceiver_config.resampler_n_latents
        self.bos_token = self.processor.tokenizer.bos_token
        # Token-id sequences that generate() must never emit.
        self.bad_words_ids = self.processor.tokenizer(
            [IMAGE_TOKEN, FAKE_IMAGE_TOKEN], add_special_tokens=False
        ).input_ids
        # Generation runs in executor threads; this lock keeps at most one
        # generation on the device at a time.
        self._generate_lock = threading.Lock()

    def convert_to_rgb(self, image: Image.Image) -> Image.Image:
        """Return *image* as RGB, compositing any transparency onto white.

        A plain ``convert("RGB")`` would drop the alpha channel and expose
        whatever colour the transparent pixels carry; compositing onto a white
        RGBA background avoids that.
        """
        if image.mode == "RGB":
            return image
        image_rgba = image.convert("RGBA")
        background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, image_rgba)
        return alpha_composite.convert("RGB")

    def custom_transform(self, image: Image.Image) -> torch.Tensor:
        """Preprocess *image* into a channels-first tensor for the model.

        Steps: convert to RGB, resize to 960x960 with bilinear resampling,
        rescale by 1/255, normalize with the processor's mean/std, then move
        channels first.
        """
        image = self.convert_to_rgb(image)
        image = to_numpy_array(image)
        image = resize(image, (960, 960), resample=PILImageResampling.BILINEAR)
        image = self.processor.image_processor.rescale(image, scale=1 / 255)
        image = self.processor.image_processor.normalize(
            image,
            mean=self.processor.image_processor.image_mean,
            std=self.processor.image_processor.image_std,
        )
        image = to_channel_dimension_format(image, ChannelDimension.FIRST)
        return torch.tensor(image)

    def _generate_text(self, image: Image.Image) -> str:
        """Run the blocking tokenize -> preprocess -> generate -> decode pipeline."""
        prompt = (
            f"{self.bos_token}{FAKE_IMAGE_TOKEN}"
            f"{IMAGE_TOKEN * self.image_seq_len}{FAKE_IMAGE_TOKEN}"
        )
        inputs = self.processor.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )
        inputs["pixel_values"] = self.processor.image_processor(
            [image], transform=self.custom_transform
        )
        with self._generate_lock:
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            generated_ids = self.model.generate(
                **inputs,
                bad_words_ids=self.bad_words_ids,
                max_length=self.max_length,
                early_stopping=True,
            )
        return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    async def generate_responses(self, image: Image.Image) -> AsyncIterator[str]:
        """Yield one newline-terminated JSON object with the generation result.

        On success yields ``{"label": <text>, "score": 1.0}``. On failure yields
        ``{"error": <message>}``; because this feeds a StreamingResponse, the
        error is reported in the body while the HTTP status stays 200.
        """
        loop = asyncio.get_running_loop()
        try:
            # generate() is blocking; run it in the default executor so the
            # event loop keeps serving other requests meanwhile.
            generated_text = await loop.run_in_executor(None, self._generate_text, image)
        except torch.cuda.CudaError as e:
            logger.exception("CUDA error during generation")
            yield json.dumps({"error": f"CUDA error: {e}"}) + '\n'
        except Exception as e:
            logger.exception("Unexpected error during generation")
            yield json.dumps({"error": f"Unexpected error: {e}"}) + '\n'
        else:
            yield json.dumps({"label": generated_text, "score": 1.0}) + '\n'


# Model location; override with the MODEL_PATH environment variable.
MODEL_PATH = os.environ.get("MODEL_PATH", "path/to/your/model")
handler = EndpointHandler(model_path=MODEL_PATH)


@app.post("/")
async def handle_request(file: UploadFile = File(...)):
    """Decode the uploaded image and stream the model's JSON response.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    data = await file.read()
    try:
        image = Image.open(io.BytesIO(data))
        # Image.open is lazy; force decoding now so corrupt or truncated
        # uploads are rejected with 400 instead of failing mid-stream.
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e
    return StreamingResponse(handler.generate_responses(image), media_type="application/json")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8080)