sguna commited on
Commit
3cf4b90
1 Parent(s): 5c77eb1

Upload handler.py (#5)

Browse files

- Upload handler.py (cb666b7a0c24f5a2ee8497a4176d94d869aaa20f)

Files changed (1) hide show
  1. handler.py +9 -13
handler.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Any, Dict
2
  from transformers import BlipProcessor, BlipForConditionalGeneration
3
  from PIL import Image
@@ -31,20 +32,15 @@ class EndpointHandler():
31
  """
32
  logger.debug(f"Received data keys: {data.keys()}")
33
 
34
- images = data.pop("inputs", data)
35
- text = data.get("text", "")
36
- parameters = data.pop("parameters", {})
 
 
37
 
38
- try:
39
- # Ensure inputs is a list of image bytes, even if only a single image is provided
40
- if isinstance(images, bytes): # Single image as bytes
41
- raw_images = [Image.open(BytesIO(images))]
42
- elif isinstance(images, list): # Multiple images as list of bytes
43
- raw_images = [Image.open(BytesIO(_img)) for _img in images if isinstance(_img, bytes)]
44
- else:
45
- raise ValueError("Invalid image input format. Expected bytes or list of bytes.")
46
- except Exception as e:
47
- return {"error": f"Error fetching or processing image: {str(e)}"}
48
 
49
  processed_image = self.processor(images=images, text=text, return_tensors="pt")
50
  processed_image["pixel_values"] = processed_image["pixel_values"].to(device)
 
1
+ import base64
2
  from typing import Any, Dict
3
  from transformers import BlipProcessor, BlipForConditionalGeneration
4
  from PIL import Image
 
32
  """
33
  logger.debug(f"Received data keys: {data.keys()}")
34
 
35
+ image_base64 = data["inputs"].get("image")
36
+ image_data = base64.b64decode(image_base64)
37
+
38
+ # Convert image data to PIL Image
39
+ images = Image.open(BytesIO(image_data))
40
 
41
+ # Optional text input
42
+ text = data["inputs"].get("text", "")
43
+ parameters = data.pop("parameters", {})
 
 
 
 
 
 
 
44
 
45
  processed_image = self.processor(images=images, text=text, return_tensors="pt")
46
  processed_image["pixel_values"] = processed_image["pixel_values"].to(device)