sguna committed
Commit
138ad6d
1 Parent(s): 8668e51

Upload handler.py (#4)


- Upload handler.py (4f6f56e6ade5adb4fc5e89079d346f6007fffa0c)

Files changed (1)
  1. handler.py +7 -3
handler.py CHANGED
@@ -3,12 +3,15 @@ from transformers import BlipProcessor, BlipForConditionalGeneration
 from PIL import Image
 from io import BytesIO
 import torch
+import base64
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+
 class EndpointHandler():
     def __init__(self, path=""):
-        self.model = BlipForConditionalGeneration.from_pretrained("quadranttechnologies/qhub-blip-image-captioning-finetuned").to(device)
+        self.model = BlipForConditionalGeneration.from_pretrained(
+            "quadranttechnologies/qhub-blip-image-captioning-finetuned").to(device)
         self.processor = BlipProcessor.from_pretrained("quadranttechnologies/qhub-blip-image-captioning-finetuned")
         self.model.eval()
         self.model = self.model.to(device).to(device)
@@ -27,9 +30,9 @@ class EndpointHandler():
         text = data.get("text", "")
         parameters = data.pop("parameters", {})
 
-        raw_images = Image.open(BytesIO(inputs)).convert("")
+        raw_images = Image.open(BytesIO(base64.b64decode(inputs))).convert("RGB")
 
-        processed_image = self.processor(images=raw_images, text = text, return_tensors="pt")
+        processed_image = self.processor(images=raw_images, text=text, return_tensors="pt")
         processed_image["pixel_values"] = processed_image["pixel_values"].to(device)
         processed_image = {**processed_image, **parameters}
@@ -41,4 +44,5 @@ class EndpointHandler():
 
         return {"description": description}
 
+
 handler = EndpointHandler()
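
After this change the handler expects the image to arrive as a base64-encoded string rather than raw bytes. A minimal local sketch of calling the updated handler is below; it assumes the usual Inference Endpoints payload shape ("inputs"/"text"/"parameters" keys consumed by the handler's __call__, which is not shown in the hunks above) and a hypothetical local image file example.jpg.

import base64

from handler import handler  # module-level handler = EndpointHandler() from handler.py

# example.jpg is a placeholder; any local image file would do.
with open("example.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")  # base64 string, matching base64.b64decode(inputs)

payload = {
    "inputs": encoded,                      # assumed payload key for the image
    "text": "a photograph of",              # optional caption prefix passed to BlipProcessor
    "parameters": {"max_new_tokens": 40},   # merged into the processed inputs before generation
}

result = handler(payload)
print(result["description"])                # the handler returns {"description": ...}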