marutitecblic committed
Commit
c1393bf
1 Parent(s): 0fe8b25

Upload custom_image_to_text_pipeline.py

Files changed (1)
  1. custom_image_to_text_pipeline.py +127 -0
custom_image_to_text_pipeline.py ADDED
@@ -0,0 +1,127 @@
"""Custom image-to-text pipeline for an IDEFICS-style model: the prompt wraps
a block of <image> placeholder tokens, and images are resized to 960x960."""

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
from typing import Any, Dict

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
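

# The pipeline below runs in three stages: preprocess (tokenize the prompt
# and transform the image), generate (model.generate with the placeholder
# tokens banned), and postprocess (decode the generated ids back to text).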
class ImageToTextPipeline:
    """Map {"file": <path or file-like>} inputs to {"text": <generated str>}."""

    def __init__(self, model_path: str):
        # Load the processor and model; trust_remote_code is required because
        # the checkpoint ships custom modeling code.
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        ).to(DEVICE)
        # Number of <image> placeholder tokens each image expands to after the
        # Perceiver resampler.
        self.image_seq_len = self.model.config.perceiver_config.resampler_n_latents
        self.bos_token = self.processor.tokenizer.bos_token
        # Ban the image placeholder tokens from being generated as output.
        self.bad_words_ids = self.processor.tokenizer(
            ["<image>", "<fake_token_around_image>"], add_special_tokens=False
        ).input_ids

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        # RGB conversion is deferred to convert_to_rgb (via custom_transform)
        # so transparent images are composited onto white instead of having
        # their alpha channel silently dropped.
        image = Image.open(data["file"])
        # The prompt is the BOS token plus the image placeholder block.
        inputs = self.processor.tokenizer(
            f"{self.bos_token}<fake_token_around_image>{'<image>' * self.image_seq_len}<fake_token_around_image>",
            return_tensors="pt",
            add_special_tokens=False,
        )
        inputs["pixel_values"] = self.processor.image_processor(
            [image], transform=self.custom_transform
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        generated_ids = self.model.generate(
            **inputs, bad_words_ids=self.bad_words_ids, max_length=4096
        )
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        return {"text": generated_text}
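
    # The two helpers below run once per image inside
    # `image_processor(..., transform=...)` during __call__.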
    def convert_to_rgb(self, image):
        # PIL's plain `.convert("RGB")` drops the alpha channel, so composite
        # transparent images onto a white background instead.
        if image.mode == "RGB":
            return image
        image_rgba = image.convert("RGBA")
        background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, image_rgba)
        return alpha_composite.convert("RGB")

    def custom_transform(self, x):
        # Eval-time preprocessing: resize to 960x960, rescale to [0, 1],
        # normalize with the processor's mean/std, and move channels first.
        x = self.convert_to_rgb(x)
        x = to_numpy_array(x)
        x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
        x = self.processor.image_processor.rescale(x, scale=1 / 255)
        x = self.processor.image_processor.normalize(
            x,
            mean=self.processor.image_processor.image_mean,
            std=self.processor.image_processor.image_std,
        )
        x = to_channel_dimension_format(x, ChannelDimension.FIRST)
        return torch.tensor(x)
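

# Hypothetical wrapper, assuming the pipeline is meant to back a Hugging Face
# Inference Endpoints custom handler (whose convention is an EndpointHandler
# class with exactly this interface):
class EndpointHandler:
    def __init__(self, path: str = ""):
        self.pipeline = ImageToTextPipeline(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        return self.pipeline(data)


# Minimal usage sketch. "marutitecblic/HtmlTocode" is the checkpoint this
# pipeline was written against; "screenshot.png" is a hypothetical input path.
if __name__ == "__main__":
    pipeline = ImageToTextPipeline("marutitecblic/HtmlTocode")
    result = pipeline({"file": "screenshot.png"})
    print(result["text"])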