Jaykintecblic committed on
Commit
7100343
1 Parent(s): e64b9d9

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +20 -18
handler.py CHANGED
@@ -1,9 +1,15 @@
1
- from typing import Dict, Any, List
 
 
2
  from PIL import Image
3
  import torch
4
  from transformers import AutoModelForCausalLM, AutoProcessor
5
  from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
6
  from transformers.image_transforms import resize, to_channel_dimension_format
 
 
 
 
7
 
8
  class EndpointHandler:
9
  def __init__(self, model_path: str):
@@ -36,17 +42,7 @@ class EndpointHandler:
36
  image = to_channel_dimension_format(image, ChannelDimension.FIRST)
37
  return torch.tensor(image)
38
 
39
- def generate_responses(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
40
- results = []
41
- image = data.get("inputs")
42
-
43
- if isinstance(image, str):
44
- try:
45
- image = Image.open(image)
46
- except Exception as e:
47
- results.append({"error": f"Failed to open image: {e}"})
48
- return results
49
-
50
  try:
51
  inputs = self.processor.tokenizer(
52
  f"{self.bos_token}<fake_token_around_image>{'<image>' * self.image_seq_len}<fake_token_around_image>",
@@ -58,14 +54,20 @@ class EndpointHandler:
58
 
59
  generated_ids = self.model.generate(**inputs, bad_words_ids=self.bad_words_ids, max_length=2048, early_stopping=True)
60
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
61
- results.append({"label": generated_text, "score": 1.0})
62
 
63
  except torch.cuda.CudaError as e:
64
- results.append({"error": f"CUDA error: {e}"})
65
  except Exception as e:
66
- results.append({"error": f"Unexpected error: {e}"})
 
 
67
 
68
- return results
 
 
 
69
 
70
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
71
- return self.generate_responses(data)
 
 
1
+ from typing import Dict, Any
2
+ from fastapi import FastAPI, File, UploadFile
3
+ from fastapi.responses import StreamingResponse
4
  from PIL import Image
5
  import torch
6
  from transformers import AutoModelForCausalLM, AutoProcessor
7
  from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
8
  from transformers.image_transforms import resize, to_channel_dimension_format
9
+ import json
10
+ import io
11
+
12
+ app = FastAPI()
13
 
14
  class EndpointHandler:
15
  def __init__(self, model_path: str):
 
42
  image = to_channel_dimension_format(image, ChannelDimension.FIRST)
43
  return torch.tensor(image)
44
 
45
+ async def generate_responses(self, image: Image.Image):
 
 
 
 
 
 
 
 
 
 
46
  try:
47
  inputs = self.processor.tokenizer(
48
  f"{self.bos_token}<fake_token_around_image>{'<image>' * self.image_seq_len}<fake_token_around_image>",
 
54
 
55
  generated_ids = self.model.generate(**inputs, bad_words_ids=self.bad_words_ids, max_length=2048, early_stopping=True)
56
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
57
+ yield json.dumps({"label": generated_text, "score": 1.0}) + '\n'
58
 
59
  except torch.cuda.CudaError as e:
60
+ yield json.dumps({"error": f"CUDA error: {e}"}) + '\n'
61
  except Exception as e:
62
+ yield json.dumps({"error": f"Unexpected error: {e}"}) + '\n'
63
+
64
+ handler = EndpointHandler(model_path="path/to/your/model")
65
 
66
+ @app.post("/")
67
+ async def handle_request(file: UploadFile = File(...)):
68
+ image = Image.open(io.BytesIO(await file.read()))
69
+ return StreamingResponse(handler.generate_responses(image), media_type="application/json")
70
 
71
+ if __name__ == "__main__":
72
+ import uvicorn
73
+ uvicorn.run(app, host="0.0.0.0", port=8080)