Upload handler.py
Browse files- handler.py +6 -1
handler.py
CHANGED
@@ -3,13 +3,17 @@ from transformers import BlipProcessor, BlipForConditionalGeneration
|
|
3 |
from PIL import Image
|
4 |
from io import BytesIO
|
5 |
import torch
|
6 |
-
import
|
7 |
|
|
|
|
|
|
|
8 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
9 |
|
10 |
|
11 |
class EndpointHandler():
|
12 |
def __init__(self, path=""):
|
|
|
13 |
self.model = BlipForConditionalGeneration.from_pretrained(
|
14 |
"quadranttechnologies/qhub-blip-image-captioning-finetuned").to(device)
|
15 |
self.processor = BlipProcessor.from_pretrained("quadranttechnologies/qhub-blip-image-captioning-finetuned")
|
@@ -25,6 +29,7 @@ class EndpointHandler():
|
|
25 |
A :obj:`dict`:. The object returned should be a dict of one list like {"descriptions": ["Description of the image"]} containing :
|
26 |
- "description": A string corresponding to the generated description.
|
27 |
"""
|
|
|
28 |
|
29 |
images = data.pop("inputs", data)
|
30 |
text = data.get("text", "")
|
|
|
3 |
from PIL import Image
|
4 |
from io import BytesIO
|
5 |
import torch
|
6 |
+
import logging
|
7 |
|
8 |
+
|
9 |
+
logging.basicConfig(level=logging.DEBUG)
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
|
13 |
|
14 |
class EndpointHandler():
|
15 |
def __init__(self, path=""):
|
16 |
+
logger.debug("Initializing model and processor.")
|
17 |
self.model = BlipForConditionalGeneration.from_pretrained(
|
18 |
"quadranttechnologies/qhub-blip-image-captioning-finetuned").to(device)
|
19 |
self.processor = BlipProcessor.from_pretrained("quadranttechnologies/qhub-blip-image-captioning-finetuned")
|
|
|
29 |
A :obj:`dict`:. The object returned should be a dict of one list like {"descriptions": ["Description of the image"]} containing :
|
30 |
- "description": A string corresponding to the generated description.
|
31 |
"""
|
32 |
+
logger.debug(f"Received data keys: {data.keys()}")
|
33 |
|
34 |
images = data.pop("inputs", data)
|
35 |
text = data.get("text", "")
|