Html / handler.py
Jaykintecblic's picture
Update handler.py
7100343 verified
raw
history blame
3.5 kB
from typing import Dict, Any
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
import json
import io
# ASGI application object; the POST "/" route in this module is registered on
# it and it is what uvicorn serves when the file is run as a script.
app = FastAPI()
class EndpointHandler:
    """Image-to-text inference wrapper around a pretrained causal LM.

    The processor and model are loaded once in ``__init__``; the remaining
    methods preprocess a PIL image and generate text for it.
    """

    def __init__(self, model_path: str) -> None:
        """Load the processor and model and cache prompt-related values.

        Args:
            model_path: Location handed to ``from_pretrained`` for both the
                processor and the model.
        """
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

        self.processor = AutoProcessor.from_pretrained(model_path)

        # NOTE(review): trust_remote_code=True lets the checkpoint supply its
        # own modelling code -- only point model_path at a trusted source.
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        self.model = loaded_model.to(self.device)

        # Number of "<image>" placeholders repeated in the prompt.
        self.image_seq_len = self.model.config.perceiver_config.resampler_n_latents

        tokenizer = self.processor.tokenizer
        self.bos_token = tokenizer.bos_token
        # Token ids of the image placeholder strings; passed to generate() as
        # bad_words_ids so they are never produced in the output.
        placeholder_tokens = ["<image>", "<fake_token_around_image>"]
        self.bad_words_ids = tokenizer(
            placeholder_tokens,
            add_special_tokens=False,
        ).input_ids

    def convert_to_rgb(self, image: Image.Image) -> Image.Image:
        """Return ``image`` in RGB mode.

        Images that are already RGB are returned unchanged. Anything else is
        converted to RGBA, composited over a white background and then
        converted to RGB.
        """
        if image.mode != "RGB":
            with_alpha = image.convert("RGBA")
            white_canvas = Image.new("RGBA", with_alpha.size, (255, 255, 255))
            flattened = Image.alpha_composite(white_canvas, with_alpha)
            return flattened.convert("RGB")
        return image

    def custom_transform(self, image: Image.Image) -> torch.Tensor:
        """Turn a PIL image into a normalized, channels-first tensor.

        Steps: convert to RGB, resize to 960x960 with bilinear resampling,
        rescale by 1/255, normalize with the image processor's mean/std and
        move the channel axis first.
        """
        pixels = to_numpy_array(self.convert_to_rgb(image))
        pixels = resize(pixels, (960, 960), resample=PILImageResampling.BILINEAR)

        image_processor = self.processor.image_processor
        pixels = image_processor.rescale(pixels, scale=1 / 255)
        pixels = image_processor.normalize(
            pixels,
            mean=image_processor.image_mean,
            std=image_processor.image_std,
        )

        channels_first = to_channel_dimension_format(pixels, ChannelDimension.FIRST)
        return torch.tensor(channels_first)

    async def generate_responses(self, image: Image.Image):
        """Yield one newline-terminated JSON string describing ``image``.

        On success the line is ``{"label": <generated text>, "score": 1.0}``.
        Failures are reported in-band instead of being raised: the line is
        ``{"error": "CUDA error: ..."}`` for ``torch.cuda.CudaError`` and
        ``{"error": "Unexpected error: ..."}`` for any other ``Exception``.
        """
        try:
            # Prompt layout: BOS, then the image placeholders wrapped in
            # <fake_token_around_image> markers.
            bos = self.bos_token
            image_slots = "<image>" * self.image_seq_len
            prompt = f"{bos}<fake_token_around_image>{image_slots}<fake_token_around_image>"

            encoded = self.processor.tokenizer(
                prompt,
                return_tensors="pt",
                add_special_tokens=False,
            )
            encoded["pixel_values"] = self.processor.image_processor(
                [image],
                transform=self.custom_transform,
            )

            # Move every input tensor onto the model's device.
            model_inputs = {
                name: tensor.to(self.device) for name, tensor in encoded.items()
            }

            output_ids = self.model.generate(
                **model_inputs,
                bad_words_ids=self.bad_words_ids,
                max_length=2048,
                early_stopping=True,
            )
            decoded = self.processor.batch_decode(output_ids, skip_special_tokens=True)
            yield json.dumps({"label": decoded[0], "score": 1.0}) + "\n"
        except torch.cuda.CudaError as cuda_err:
            yield json.dumps({"error": f"CUDA error: {cuda_err}"}) + "\n"
        except Exception as err:
            yield json.dumps({"error": f"Unexpected error: {err}"}) + "\n"
# Module-level handler instance used by the request route. Constructing it
# loads the processor and model, so this work happens when the module is
# imported.
# NOTE(review): "path/to/your/model" looks like a placeholder -- confirm the
# real model location before deploying.
handler = EndpointHandler(model_path="path/to/your/model")
@app.post("/")
async def handle_request(file: UploadFile = File(...)):
    """Run the handler on an uploaded image and stream back its JSON lines.

    Args:
        file: The uploaded image file.

    Returns:
        A ``StreamingResponse`` carrying the newline-terminated JSON produced
        by ``handler.generate_responses``. If the upload cannot be opened as
        an image, a 400 response containing a single ``{"error": ...}`` line
        is returned instead, matching the in-band error format used by
        ``generate_responses``.
    """
    payload = await file.read()
    try:
        image = Image.open(io.BytesIO(payload))
    except (OSError, Image.DecompressionBombError) as e:
        # The upload is untrusted: unparsable data raises
        # UnidentifiedImageError (an OSError subclass) and oversized images
        # raise DecompressionBombError. Report a client error rather than
        # letting the exception surface as an unhandled 500.
        error_line = json.dumps({"error": f"Invalid image: {e}"}) + "\n"
        return StreamingResponse(
            iter([error_line]),
            status_code=400,
            media_type="application/json",
        )
    return StreamingResponse(
        handler.generate_responses(image),
        media_type="application/json",
    )
if __name__ == "__main__":
    # Script entry point: serve the FastAPI app with uvicorn on port 8080,
    # bound to 0.0.0.0. uvicorn is imported here so it is only needed when
    # the file is executed directly.
    import uvicorn

    uvicorn.run(app, port=8080, host="0.0.0.0")