Singularity666 committed
Commit b9658ca · verified · 1 Parent(s): 883703e

Update main.py

Files changed (1)
  1. main.py +369 -98
main.py CHANGED
@@ -1,98 +1,369 @@
- import os
- import shutil
- import json
- import torch
- import random
- from pathlib import Path
- from torch.utils.data import Dataset
- from torchvision import transforms
- from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel, AutoencoderKL, DDPMScheduler
- from transformers import CLIPTextModel, CLIPTokenizer
- from accelerate import Accelerator
- from tqdm.auto import tqdm
- from PIL import Image
-
- class CustomDataset(Dataset):
-     def __init__(self, data_dir, prompt, tokenizer, size=512, center_crop=False):
-         self.data_dir = Path(data_dir)
-         self.prompt = prompt
-         self.tokenizer = tokenizer
-         self.size = size
-         self.center_crop = center_crop
-
-         self.image_transforms = transforms.Compose([
-             transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-             transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-             transforms.ToTensor(),
-             transforms.Normalize([0.5], [0.5])
-         ])
-
-         self.images = [f for f in self.data_dir.iterdir() if f.is_file() and not str(f).endswith(".txt")]
-
-     def __len__(self):
-         return len(self.images)
-
-     def __getitem__(self, idx):
-         image_path = self.images[idx]
-         image = Image.open(image_path)
-         if not image.mode == "RGB":
-             image = image.convert("RGB")
-
-         image = self.image_transforms(image)
-         prompt_ids = self.tokenizer(
-             self.prompt, padding="max_length", truncation=True, max_length=self.tokenizer.model_max_length
-         ).input_ids
-
-         return {"image": image, "prompt_ids": prompt_ids}
-
- def fine_tune_model(instance_data_dir, instance_prompt, model_name, output_dir, seed=1337, resolution=512, train_batch_size=1, max_train_steps=800):
-     # Setup
-     accelerator = Accelerator(cpu=True)
-     set_seed(seed)
-
-     tokenizer = CLIPTokenizer.from_pretrained(model_name)
-     text_encoder = CLIPTextModel.from_pretrained(model_name)
-     vae = AutoencoderKL.from_pretrained(model_name)
-     unet = UNet2DConditionModel.from_pretrained(model_name)
-     noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
-
-     dataset = CustomDataset(instance_data_dir, instance_prompt, tokenizer, resolution)
-     dataloader = torch.utils.data.DataLoader(dataset, batch_size=train_batch_size, shuffle=True)
-
-     optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-6)
-
-     unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)
-     vae.to(accelerator.device)
-     text_encoder.to(accelerator.device)
-
-     global_step = 0
-     for step, batch in tqdm(enumerate(dataloader), total=max_train_steps):
-         latents = vae.encode(batch["image"].to(accelerator.device)).latent_dist.sample() * 0.18215
-         noise = torch.randn_like(latents)
-         timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device).long()
-         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-         encoder_hidden_states = text_encoder(batch["prompt_ids"].to(accelerator.device))[0]
-
-         model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-         loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
-         accelerator.backward(loss)
-
-         optimizer.step()
-         optimizer.zero_grad()
-         global_step += 1
-         if global_step >= max_train_steps:
-             break
-
-     # Save model
-     unet = accelerator.unwrap_model(unet)
-     unet.save_pretrained(output_dir)
-     vae.save_pretrained(output_dir)
-     text_encoder.save_pretrained(output_dir)
-     tokenizer.save_pretrained(output_dir)
-
- def set_seed(seed):
-     random.seed(seed)
-     torch.manual_seed(seed)
-     if torch.cuda.is_available():
-         torch.cuda.manual_seed_all(seed)
+ # modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/conversation.py
+
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Tuple
+
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+     SINGLE = auto()
+     TWO = auto()
+     MPT = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "###"
+     sep2: str = None
+     version: str = "Unknown"
+
+     skip_next: bool = False
+
+     def get_prompt(self):
+         if self.sep_style == SeparatorStyle.SINGLE:
+             ret = self.system + self.sep
+             for role, message in self.messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + self.sep
+                 else:
+                     ret += role + ":"
+             return ret
+         elif self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+             return ret
+         if self.sep_style == SeparatorStyle.MPT:
+             ret = self.system + self.sep
+             for role, message in self.messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + message + self.sep
+                 else:
+                     ret += role
+             return ret
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def get_images(self, return_pil=False):
+         images = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     from PIL import Image
+                     msg, image, image_process_mode = msg
+                     if image_process_mode == "Pad":
+                         def expand2square(pil_img, background_color=(122, 116, 104)):
+                             width, height = pil_img.size
+                             if width == height:
+                                 return pil_img
+                             elif width > height:
+                                 result = Image.new(pil_img.mode, (width, width), background_color)
+                                 result.paste(pil_img, (0, (width - height) // 2))
+                                 return result
+                             else:
+                                 result = Image.new(pil_img.mode, (height, height), background_color)
+                                 result.paste(pil_img, ((height - width) // 2, 0))
+                                 return result
+                         image = expand2square(image)
+                     elif image_process_mode == "Crop":
+                         pass
+                     elif image_process_mode == "Resize":
+                         image = image.resize((224, 224))
+                     else:
+                         raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if H > W:
+                         H, W = longest_edge, shortest_edge
+                     else:
+                         H, W = shortest_edge, longest_edge
+                     image = image.resize((W, H))
+                     if return_pil:
+                         images.append(image)
+                     else:
+                         buffered = BytesIO()
+                         image.save(buffered, format="JPEG")
+                         img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                         images.append(img_b64_str)
+         return images
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     msg, image, image_process_mode = msg
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if H > W:
+                         H, W = longest_edge, shortest_edge
+                     else:
+                         H, W = shortest_edge, longest_edge
+                     image = image.resize((W, H))
+                     # image = image.resize((224, 224))
+                     buffered = BytesIO()
+                     image.save(buffered, format="JPEG")
+                     img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                     img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                     msg = msg.replace('<image>', img_str)
+                 ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2)
+
+     def dict(self):
+         if len(self.get_images()) > 0:
+             return {
+                 "system": self.system,
+                 "roles": self.roles,
+                 "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                 "offset": self.offset,
+                 "sep": self.sep,
+                 "sep2": self.sep2,
+             }
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+         }
+
+
+ conv_v1 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "Give three tips for staying healthy."),
+         ("Assistant",
+             "Sure, here are three tips for staying healthy:\n"
+             "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
+             "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
+             "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
+             "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
+             "activities at least two days per week.\n"
+             "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
+             "vegetables, whole grains, lean proteins, and healthy fats can help support "
+             "your overall health. Try to limit your intake of processed and high-sugar foods, "
+             "and aim to drink plenty of water throughout the day.\n"
+             "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
+             "and mental health. Adults should aim for seven to nine hours of sleep per night. "
+             "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
+             "help improve the quality of your sleep.")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_v1_2 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+         ("Assistant",
+             "Renewable energy sources are those that can be replenished naturally in a relatively "
+             "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+             "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+             "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+             "renewable and non-renewable energy sources:\n"
+             "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+             "energy sources are finite and will eventually run out.\n"
+             "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+             "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+             "and other negative effects.\n"
+             "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+             "have lower operational costs than non-renewable sources.\n"
+             "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+             "locations than non-renewable sources.\n"
+             "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+             "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+             "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+             "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_vicuna_v1_1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_mpt = Conversation(
+     system="""<|im_start|>system
+ - You are a helpful language and vision assistant.
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+ - You should follow the instructions carefully and explain your answers in detail.""",
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>",
+ )
+
+ conv_mpt_text = Conversation(
+     system="""<|im_start|>system
+ - You are a helpful assistant chatbot trained by MosaicML.
+ - You answer questions.
+ - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+ - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>",
+ )
+
+ conv_bair_v1 = Conversation(
+     system="BEGINNING OF CONVERSATION:",
+     roles=("USER", "GPT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ simple_conv = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "Hi!"),
+         ("Assistant", "Hi there! How can I help you today?")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ simple_conv_multimodal = Conversation(
+     system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+            "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+            "Follow the instructions carefully and explain your answers in detail.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "Hi!"),
+         ("Assistant", "Hi there! How can I help you today?\n")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ simple_conv_mpt_multimodal = Conversation(
+     system="""<|im_start|>system
+ - You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+ - You should follow the instructions carefully and explain your answers in detail.""",
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>",
+ )
+
+ simple_conv_legacy = Conversation(
+     system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
+            "You are designed to assist human with a variety of tasks using natural language."
+            "Follow the instructions carefully.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "Hi!\n\n### Response:"),
+         ("Assistant", "Hi there! How can I help you today?\n")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_llava_v1 = Conversation(
+     system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+            "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+            "Follow the instructions carefully and explain your answers in detail.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ default_conversation = conv_v1_2
+ conv_templates = {
+     "default": conv_v1_2,
+     "simple": simple_conv,
+     "simple_legacy": simple_conv_legacy,
+     "multimodal": simple_conv_multimodal,
+     "mpt_multimodal": simple_conv_mpt_multimodal,
+     "llava_v1": conv_llava_v1,
+
+     # fastchat
+     "v1": conv_v1_2,
+     "bair_v1": conv_bair_v1,
+     "vicuna_v1_1": conv_vicuna_v1_1,
+     "mpt": conv_mpt,
+     "mpt_text": conv_mpt_text,
+ }
+
+
+ if __name__ == "__main__":
+     print(default_conversation.get_prompt())
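For reference, a template in `conv_templates` is normally used by copying it, appending alternating role turns, and calling `get_prompt()` to build the string sent to the model. The snippet below is a minimal, illustrative sketch of that flow, under the assumption that this file is importable as `main`; the question text is a placeholder and is not part of this commit.

# Minimal usage sketch (assumption: the module is importable as `main`).
from main import conv_templates

# Copy the template so the shared module-level object is not mutated in place.
conv = conv_templates["vicuna_v1_1"].copy()

# One user turn plus an empty assistant turn; with SeparatorStyle.TWO the
# prompt then ends with "ASSISTANT:" so the model continues from there.
conv.append_message(conv.roles[0], "Describe renewable energy in one sentence.")  # hypothetical question
conv.append_message(conv.roles[1], None)

print(conv.get_prompt())
# Expected shape: "<system prompt> USER: Describe renewable energy in one sentence. ASSISTANT:"

Choosing "vicuna_v1_1" here is only an example; any key in `conv_templates` works the same way, with the separator behaviour determined by its `sep_style`.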