File size: 2,913 Bytes
9e0a81f
 
 
 
 
 
 
 
 
9467c14
 
 
9e0a81f
 
 
 
 
 
 
 
 
9467c14
 
9e0a81f
 
9467c14
 
 
 
 
 
 
9e0a81f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# +
from typing import  Dict, List, Any
from PIL import Image
import torch
import os
import io
import base64
from io import BytesIO
# from transformers import BlipForConditionalGeneration, BlipProcessor
# from transformers import Blip2Processor, Blip2ForConditionalGeneration
from transformers import Blip2ForConditionalGeneration, AutoProcessor
from peft import PeftModel, PeftConfig

# -

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class EndpointHandler():
    """Caption a base64-encoded image with a PEFT-adapted BLIP-2 model.

    ``__init__`` loads the model and the processor once; every call to the
    instance decodes one image from the request and returns its caption.
    """

    # Generation settings applied when the request does not override them.
    DEFAULT_GENERATION_KWARGS = {"max_length": 25}

    def __init__(self, path=""):
        """Load the 8-bit base model, attach the PEFT adapter, load the processor.

        Args:
            path: Accepted for interface compatibility; it is not used here
                because the adapter and processor identifiers are hard-coded.
        """
        print("####### Start Deploying #####")

        peft_model_id = "ChirathD/Blip-2-test-4"
        # The adapter config records which base checkpoint has to be loaded.
        config = PeftConfig.from_pretrained(peft_model_id)

        base_model = Blip2ForConditionalGeneration.from_pretrained(
            config.base_model_name_or_path,
            load_in_8bit=True,
            device_map="auto",
        )
        self.model = PeftModel.from_pretrained(base_model, peft_model_id)
        self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

    @staticmethod
    def _split_request(data):
        """Return ``(payload, parameters)`` from *data* without mutating it."""
        if not isinstance(data, dict) or "inputs" not in data:
            raise ValueError(
                'Request must be a dict with an "inputs" entry holding the '
                "base64-encoded image."
            )
        parameters = data.get("parameters") or {}
        if not isinstance(parameters, dict):
            raise ValueError('"parameters" must be a dict of generation settings.')
        # Copy so that later merging never touches the caller's object.
        return data["inputs"], dict(parameters)

    @staticmethod
    def _decode_payload(payload):
        """Decode a base64 string/bytes payload (optionally a data URI) to bytes."""
        if isinstance(payload, str):
            # Strip a "data:<mime>;base64," prefix; otherwise its characters
            # would be decoded as if they were part of the image data.
            if payload.startswith("data:") and "," in payload:
                payload = payload.split(",", 1)[1]
        elif not isinstance(payload, (bytes, bytearray)):
            raise ValueError(
                '"inputs" must be a base64-encoded string or bytes, got '
                f"{type(payload).__name__}."
            )

        try:
            image_bytes = base64.b64decode(payload)
        except ValueError as err:  # binascii.Error is a ValueError subclass
            raise ValueError('"inputs" is not valid base64 data.') from err

        if not image_bytes:
            raise ValueError('"inputs" decoded to an empty image.')
        return image_bytes

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Generate a caption for the image carried by *data*.

        Args:
            data: Request dict. ``data["inputs"]`` is the base64-encoded image
                (a plain base64 string/bytes or a ``data:`` URI).
                ``data["parameters"]`` is an optional dict of keyword
                arguments forwarded to ``self.model.generate``; they override
                ``DEFAULT_GENERATION_KWARGS``. The dict is not modified.

        Returns:
            ``{"captions": <str>}`` where the value is the single decoded
            caption string (not a list).

        Raises:
            ValueError: If the request has no usable ``"inputs"`` entry, the
                payload is not valid base64, or ``"parameters"`` is not a dict.
        """
        payload, parameters = self._split_request(data)
        image_bytes = self._decode_payload(payload)
        # Log a short summary instead of the full base64 body of the request.
        print(f"Received {len(image_bytes)} image bytes; parameters={parameters}")

        image = Image.open(io.BytesIO(image_bytes))
        model_inputs = self.processor(images=image, return_tensors="pt")

        # Request parameters win over the defaults; pixel_values is set last so
        # a stray "pixel_values" parameter cannot replace the decoded image.
        generation_kwargs = {
            **self.DEFAULT_GENERATION_KWARGS,
            **parameters,
            "pixel_values": model_inputs.pixel_values,
        }
        generated_ids = self.model.generate(**generation_kwargs)
        generated_caption = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        print(generated_caption)

        return {"captions": generated_caption}