jadechoghari committed on
Commit ba1211e
1 parent: 7ba7e0e

Create conversation.py

Files changed (1)
conversation.py +275 -0
conversation.py ADDED
@@ -0,0 +1,275 @@
import dataclasses
from enum import auto, Enum
from typing import List, Optional, Tuple

VOCAB_IMAGE_W = 1000  # 224
VOCAB_IMAGE_H = 1000  # 224
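# Both constants define a virtual coordinate vocabulary: per the system
# prompts below, points and boxes in prompts are written on a
# VOCAB_IMAGE_W x VOCAB_IMAGE_H grid regardless of the real image size.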


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
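
# How each style renders (see get_prompt below): SINGLE and TWO delimit
# turns with one or two separator strings, MPT joins role tags directly to
# the text, PLAIN drops role tags entirely, and LLAMA_2 uses Llama-2 chat's
# [INST] ... [/INST] and <<SYS>> wrapping.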


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: Optional[str] = None
    version: str = "Unknown"

    skip_next: bool = False
    first_round: bool = True
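
    # A message is either a plain string or a tuple of
    # (text, PIL.Image, image_process_mode) when the turn carries an image;
    # the methods below unpack the tuple form where needed.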

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            # First turn carries an image: strip the raw "<image>" tag and
            # re-insert it according to the template version.
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            # lstrip with a string argument strips *characters*, not a
            # prefix; here it removes the leading separator characters.
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret
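
    # For example, with SeparatorStyle.TWO, sep=" " and sep2="</s>" (the
    # templates below), two turns render roughly as:
    #   "<system> USER: <image>\nHi ASSISTANT: Hello!</s>"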

    def append_message(self, role, message):
        self.messages.append([role, message])
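
    # Typical serving flow (assumed): append the user turn, then append
    # (role, None) as an assistant placeholder and fill it in once the
    # model replies; to_gradio_chatbot() relies on that pairing.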

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:  # only user turns can carry images
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            # Pad the shorter side so the image becomes
                            # square, centered on a solid background.
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode == "Crop":
                        pass
                    elif image_process_mode == "Raw+Processor":
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")

                    if image_process_mode != "Raw+Processor":
                        # Downscale, preserving aspect ratio, so the short
                        # edge is at most 400 px (never upscaling) and the
                        # long edge at most 800 px.
                        max_hw, min_hw = max(image.size), min(image.size)
                        aspect_ratio = max_hw / min_hw
                        max_len, min_len = 800, 400
                        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                        longest_edge = int(shortest_edge * aspect_ratio)
                        W, H = image.size
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    print('Input Image Size:{}'.format(image.size))

                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images
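
    # Builds [user, assistant] message pairs in the format expected by a
    # Gradio Chatbot widget; images are inlined as base64 <img> tags so the
    # UI needs no separate file hosting.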
    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    if image_process_mode != "Raw+Processor":
                        max_hw, min_hw = max(image.size), min(image.size)
                        aspect_ratio = max_hw / min_hw
                        max_len, min_len = 800, 400
                        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                        longest_edge = int(shortest_edge * aspect_ratio)
                        W, H = image.size
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    # JPEG-encoded above, so use a matching image/jpeg MIME type.
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    ret.append([img_str, None])
                    msg = msg.replace('<image>', '').strip()
                    if len(msg) > 0:
                        ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)
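
    # Note: skip_next and first_round are not passed through, so a copy
    # starts from their default values.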

    def dict(self):
        if len(self.get_images()) > 0:
            # Images are not serializable; keep only the text part of
            # image-bearing messages.
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


ferret_conv_vicuna_v1_original_system = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "Assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. "
           "In images, points are represented by coordinates [x, y]. The top-left corner is [0, 0]. The bottom-right corner is [width-1, height-1]. "
           "Increasing x moves right across the image while increasing y moves down. "
           "A bounding box is marked by [x1, y1, x2, y2] with the top-left and bottom-right points being [x1, y1] and [x2, y2] respectively. "
           f"The image size is assumed to be ({VOCAB_IMAGE_W}, {VOCAB_IMAGE_H}), i.e., width={VOCAB_IMAGE_W}, height={VOCAB_IMAGE_H}. "
           "Follow the instructions carefully. ",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

ferret_conv_vicuna_v1 = Conversation(
    system="A chat between a human and an AI that understands visuals. "
           "In images, [x, y] denotes points: top-left [0, 0], bottom-right [width-1, height-1]. "
           "Increasing x moves right; y moves down. "
           f"Bounding box: [x1, y1, x2, y2]. Image size: {VOCAB_IMAGE_W}x{VOCAB_IMAGE_H}. "
           "Follow instructions. ",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)


default_conversation = ferret_conv_vicuna_v1
conv_templates = {
    "v1": ferret_conv_vicuna_v1,
    "ferret_v1": ferret_conv_vicuna_v1,
}
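
# The templates above are module-level singletons: call .copy() before
# appending messages so shared state is not mutated across requests.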


if __name__ == "__main__":
    print(default_conversation.get_prompt())
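
    # Minimal usage sketch (assumed workflow, mirroring how LLaVA-style
    # serving code drives this class): copy a template, append both turns,
    # then render the prompt string.
    conv = conv_templates["ferret_v1"].copy()
    conv.append_message(conv.roles[0], "What is in region [100, 100, 300, 300]?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())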