import torchdata.datapipes as dp
import json
from PIL import Image
import functools
import numpy as np
import torch
import pickle
import os
import cv2
import random
from torchvision import transforms
from braceexpand import braceexpand
import hydra
from random import choice
import tarfile
from torchdata.datapipes.iter import TarArchiveLoader
from typing import cast, IO, Iterable, Iterator, Optional, Tuple, Dict
from torchdata.datapipes import functional_datapipe
from io import BufferedIOBase
from torchdata.datapipes.utils import StreamWrapper
from torchdata.datapipes.utils.common import validate_pathname_binary_tuple
import warnings
from torchdata.datapipes.iter import IterDataPipe
import pyrootutils

pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
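
# Special tokens delimiting an image-embedding span inside the text stream.
# IMG_TOKEN expands to per-slot placeholders such as <img_00000>, <img_00001>, ...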
BOI_TOKEN = '<img>'
EOI_TOKEN = '</img>'
IMG_TOKEN = '<img_{:05d}>'

gen_prompt = [
    "Please show me a picture of ",
    "Please design an image of ",
    "Please produce a photo of ",
    "Please generate an image of ",
    "Please draw a painting of ",
    "I'd like to see a drawing of ",
    "I'd love to see an illustration of ",
    "I'd like to view an image of ",
    "I want to see a picture of ",
    "I would like to see a photo of ",
    "Show me a photo of ",
    "Generate a picture of ",
    "Show me a photograph of ",
    "Generate an image of ",
    "Generate an image: ",
    "Generate a picture: ",
    "Generate a painting: ",
    "Generate a photograph: ",
    "Show me a photograph: ",
    "Draw a picture: ",
    "Draw a painting: ",
    "Draw an image: ",
    "Can you make an image of ",
    "Can you draw a painting of ",
    "Can you produce a picture of ",
    "Can you generate a photo of ",
    "Can you depict a picture of ",
    "Can you show me an illustration of ",
]
gen_prompt_response = [
    "Here is a picture.",
    "I have designed an image.",
    "Here is a photo.",
    "I have generated an image.",
    "Here's a painting.",
    "Here's a drawing.",
    "Enjoy this illustration.",
    "Take a look at this image.",
    "Here is a picture.",
    "I have created a photo.",
    "Enjoy this photo.",
    "I have generated a picture.",
    "Here is a photograph.",
    "Here's an image.",
    "Certainly, here's an image.",
    "Absolutely, here is a painting.",
    "Sure, here is a picture.",
    "Of course, here is a photo.",
    "Certainly, please enjoy this picture.",
    "Sure, please enjoy this illustration.",
    "",
]

jdb_filter_vocab = ['watermark', 'watermark,', 'chaos 100', 'chaos 100,']
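
# Keep only samples that carry both decoded images and tokenized input ids.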
def filter_data_with_image_ids(item):
    if 'images' not in item:
        # print(item['__key__'])
        # print('filtered because no images')
        return False
    elif 'input_ids' not in item:
        return False
    else:
        return True
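
# Resize so that the shorter side equals target_size while preserving the
# aspect ratio; returns (new_height, new_width).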
def calculate_new_dimensions(height, width, target_size):
    if height < width:
        new_height = target_size
        new_width = int(width * (target_size / height))
    else:
        new_width = target_size
        new_height = int(height * (target_size / width))
    return new_height, new_width
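
# Flatten one level of nested dicts into a single flat sample dict and make
# sure a 'metadata' field is always present.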
def unwarp_data(item):
    unwrapped = {}
    for key, value in item.items():
        if isinstance(value, dict):
            unwrapped.update(value)
        elif value is not None:
            unwrapped[key] = value
    if 'metadata' not in unwrapped:
        unwrapped['metadata'] = '{}'
    # if '__key__' in unwrapped:
    #     unwrapped['__key__'] = unwrapped['__key__'].split('/')[-1]
    return unwrapped
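
# Drop samples whose image-text similarity (stored under one of several
# possible metadata keys) falls below similarity_thr, as well as samples that
# are missing images or text or whose upstream filter_flag is False.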
# def filter_data_with_similarity(item, similarity_thr=0.2, min_resolution=180, min_aspect_ratio=0.666):
def filter_data_with_similarity(item, similarity_thr=0.2, assure_text=True):
    if 'images' not in item:
        # print(item['__key__'])
        # print('filtered because no images')
        return False
    elif not item.get('filter_flag', True):
        # print(item['__key__'])
        # print('filtered because filter flag.')
        return False
    elif assure_text and ('text' not in item):
        # print(item['__key__'])
        # print('filtered because assure_text')
        return False
    else:
        metadata = json.loads(item['metadata'])
        if 'all_similarities' in metadata:
            similarity = max(metadata['all_similarities'])
        elif 'similarity' in metadata:
            similarity = metadata['similarity']
        elif 'score' in metadata:
            similarity = metadata['score']
        elif 'SCORE' in metadata:
            similarity = metadata['SCORE']
        else:
            similarity = None
        if similarity is not None and similarity < similarity_thr:
            # print(item['__key__'])
            # print('filtered because similarity')
            return False
        return True
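
# Collate a list of samples into batched tensors: image and embedding-mask
# tensors are concatenated along dim 0 (a sample may hold several images),
# other tensors are stacked, and fields that are None everywhere collapse to None.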
def single_turn_edit_collate(batch):
    results = {}
    keys = batch[0].keys()
    for key in keys:
        cur = [batch[i][key] for i in range(len(batch)) if batch[i][key] is not None]
        if len(cur) == 0:
            results[key] = None
        elif isinstance(cur[0], torch.Tensor):
            if key in ['embeds_gen_mask', 'embeds_cmp_mask', 'images']:
                results[key] = torch.cat(cur, dim=0)
            else:
                results[key] = torch.stack(cur, dim=0)
        else:
            results[key] = cur
    return results
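
# Decode one (key, json) text-to-image record into a padded training sample:
# the caption becomes the instruction, and the response is a fixed run of
# <img_xxxxx> placeholders whose positions are later filled with image
# embeddings (marked via ids_gen_mask).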
def decode_t2i_data(item,
                    image_dir,
                    tokenizer,
                    image_transform=None,
                    sd_image_transform=None,
                    max_length=128,
                    min_resolution=400,
                    instruction_prompt='[INST] {instruction} [/INST]\n',
                    turn_sep='\n',
                    system_message='',
                    min_aspect_ratio=0.666,
                    num_img_in_tokens=64,
                    num_img_out_tokens=64):
    key, value = item
    if 'image' not in value or 'caption' not in value:
        return {}
    image_path = os.path.join(image_dir, value["image"])
    try:
        image = Image.open(image_path).convert('RGB')
        width, height = image.size
        aspect_ratio = height / width
        if height < min_resolution or width < min_resolution:
            print(f'filtered because resolution: ({width},{height})')
            return {}
        if aspect_ratio < min_aspect_ratio or aspect_ratio > 1 / min_aspect_ratio:
            print(f'filtered because aspect ratio: ({width},{height})')
            return {}

        ### SD related
        image_data = {}
        if sd_image_transform is not None:
            # image_data['original_sizes'] = torch.tensor([height, width])
            sd_image_tensor = sd_image_transform(image)
            target_size = sd_image_tensor.shape[-2]
            # calculate_new_dimensions returns (height, width), in that order
            target_height, target_width = calculate_new_dimensions(height=height, width=width, target_size=target_size)
            y1 = max(0, int(round((target_height - target_size) / 2.0)))
            x1 = max(0, int(round((target_width - target_size) / 2.0)))
            # image_data['crop_top_lefts'] = torch.tensor([y1, x1])
            # SDXL-style micro-conditioning: original size, crop top-left, target size
            image_data['time_ids'] = torch.tensor([height, width, y1, x1, target_size, target_size])
            image_data['sd_images'] = sd_image_tensor
        if image_transform is not None:
            image = image_transform(image)
    except Exception as e:
        print('Error while decoding image: ', e)
        return {}

    input_ids = []
    labels = []
    input_text = ''

    if system_message != '':
        if not system_message.endswith('\n'):
            system_message += '\n'
        input_text += system_message
        item_ids = tokenizer.encode(system_message, add_special_tokens=False)
        item_labels = [-100] * len(item_ids)
        input_ids.extend(item_ids)
        labels.extend(item_labels)

    caption = value["caption"]
    image_cmp_tokens = BOI_TOKEN + ''.join(
        [IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
    image_gen_tokens = BOI_TOKEN + ''.join(
        [IMG_TOKEN.format(int(item)) for item in range(num_img_out_tokens)]) + EOI_TOKEN

    instruction = instruction_prompt.format_map({'instruction': caption})
    response = image_gen_tokens
    images = torch.stack([image], dim=0)
    # print(instruction)

    item_ids = tokenizer.encode(instruction, add_special_tokens=False)
    item_labels = [-100] * len(item_ids)
    input_text += instruction
    input_ids.extend(item_ids)
    labels.extend(item_labels)

    item_ids = tokenizer.encode(response, add_special_tokens=False)
    item_labels = item_ids
    input_text += response
    input_ids.extend(item_ids)
    labels.extend(item_labels)

    input_ids = [tokenizer.bos_token_id] + input_ids + [tokenizer.eos_token_id]
    attention_mask = [1] * len(input_ids)
    labels = [-100] + labels + [tokenizer.eos_token_id]

    boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
    eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]

    ids_cmp_mask = [False] * len(input_ids)
    ids_gen_mask = [False] * len(input_ids)
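    # A pure t2i sample contains exactly one image, and it is a generation
    # target rather than a conditioning input.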
    embeds_cmp_mask = [False]
    embeds_gen_mask = [True]

    # print(len(input_ids))
    if len(input_ids) >= max_length:
        # input_ids = input_ids[:max_length]
        # attention_mask = attention_mask[:max_length]
        # labels = labels[:max_length]
        # ids_cmp_mask = ids_cmp_mask[:max_length]
        # ids_gen_mask = ids_gen_mask[:max_length]
        # print('An edit sample has been removed because of max length. input_text: ', input_text)
        return {}
    else:
        padding_length = max_length - len(input_ids)
        input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
        attention_mask = attention_mask + [0] * padding_length
        labels = labels + [-100] * padding_length
        ids_cmp_mask = ids_cmp_mask + [False] * padding_length
        ids_gen_mask = ids_gen_mask + [False] * padding_length

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    attention_mask = torch.tensor(attention_mask, dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)
    ids_cmp_mask = torch.tensor(ids_cmp_mask, dtype=torch.bool)
    ids_gen_mask = torch.tensor(ids_gen_mask, dtype=torch.bool)
    embeds_cmp_mask = torch.tensor(embeds_cmp_mask) if embeds_cmp_mask is not None else None
    embeds_gen_mask = torch.tensor(embeds_gen_mask) if embeds_gen_mask is not None else None

    boi_idx = torch.where(input_ids == boi_token_id)[0].tolist()
    eoi_idx = torch.where(input_ids == eoi_token_id)[0].tolist()
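    # Mark the generated-image token span and exclude its placeholder ids from
    # the language-modelling loss.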
    ids_gen_mask[boi_idx[0] + 1:eoi_idx[0]] = True
    labels[boi_idx[0] + 1:eoi_idx[0] + 1] = -100

    ret = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'ids_gen_mask': ids_gen_mask,
        'ids_cmp_mask': ids_cmp_mask,
        'embeds_gen_mask': embeds_gen_mask,
        'embeds_cmp_mask': embeds_cmp_mask,
        'images': images,
        'text': input_text,
    }
    ret.update(image_data)

    return ret

def build_t2i_datapipe(data_dir,
                       image_dir,
                       tokenizer=None,
                       max_length=77,
                       batch_size=None,
                       min_resolution=180,
                       image_transform=None,
                       sd_image_transform=None,
                       instruction_prompt='[INST] {instruction} [/INST]\n',
                       turn_sep='\n',
                       system_message='',
                       min_aspect_ratio=0.666,
                       num_img_in_tokens=64,
                       num_img_out_tokens=64,
                       cycle_count=None):
    decode_partial = functools.partial(decode_t2i_data,
                                       image_dir=image_dir,
                                       tokenizer=tokenizer,
                                       image_transform=image_transform,
                                       sd_image_transform=sd_image_transform,
                                       max_length=max_length,
                                       instruction_prompt=instruction_prompt,
                                       turn_sep=turn_sep,
                                       system_message=system_message,
                                       min_resolution=min_resolution,
                                       min_aspect_ratio=min_aspect_ratio,
                                       num_img_in_tokens=num_img_in_tokens,
                                       num_img_out_tokens=num_img_out_tokens)
    filter_partial = functools.partial(filter_data_with_image_ids)

    if isinstance(data_dir, str):
        data_dir = list(braceexpand(data_dir))

    datapipe = dp.iter.FileLister(root=data_dir, masks='*.jsonl', recursive=True)
    datapipe = datapipe.shuffle()
    datapipe = datapipe.cycle(count=cycle_count)
    datapipe = datapipe.shuffle()
    # datapipe = dp.iter.FileLister(root=data_dir, masks='0000000.tar', recursive=True)
    datapipe = datapipe.sharding_filter()
    # datapipe = datapipe.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
    datapipe = datapipe.open_files(mode='r')
    datapipe = datapipe.parse_jsonl_files()
    datapipe = datapipe.map(decode_partial)
    datapipe = datapipe.filter(filter_partial)
    # datapipe = datapipe.shuffle(buffer_size=1024)
    if batch_size is not None:
        datapipe = datapipe.batch(batch_size)
        datapipe = datapipe.collate(single_turn_edit_collate)
    return datapipe
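
# Example usage (a minimal sketch; the tokenizer path and data locations are
# placeholders, and we assume a HuggingFace tokenizer whose vocabulary already
# contains BOI_TOKEN, EOI_TOKEN and the <img_xxxxx> placeholder tokens):
#
#   from transformers import AutoTokenizer
#   from torchvision import transforms
#
#   tokenizer = AutoTokenizer.from_pretrained('path/to/tokenizer')
#   image_transform = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor()])
#   datapipe = build_t2i_datapipe(data_dir='data/t2i', image_dir='data/images',
#                                 tokenizer=tokenizer, max_length=128,
#                                 batch_size=4, image_transform=image_transform)
#   for batch in datapipe:
#       print(batch['input_ids'].shape)
#       break


# Decode one multi-frame story record: a random number of leading frames are
# given as conditioning inputs, and the following frame is the generation target.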
def decode_long_story_data(item,
                           image_dir,
                           tokenizer,
                           story_len,
                           image_transform=None,
                           sd_image_transform=None,
                           max_length=128,
                           min_resolution=400,
                           instruction_prompt='{instruction}',
                           turn_sep='\n',
                           system_message='',
                           min_aspect_ratio=0.666,
                           num_img_in_tokens=64,
                           num_img_out_tokens=64):
    key, value = item
    if 'images' not in value or 'captions' not in value:
        return {}
    image_paths = [os.path.join(image_dir, image_path) for image_path in value["images"]]
    # assert len(image_paths) == story_len
    story_len = len(image_paths)
    # Randomly choose how many frames beyond the first are given as context.
    num_image_given = random.randint(0, story_len - 2)
    try:
        images = []
        for image_path in image_paths:
            image = Image.open(image_path).convert('RGB')
            images.append(image)
            width, height = image.size
            aspect_ratio = height / width
            if height < min_resolution or width < min_resolution:
                print(f'filtered because resolution: ({width},{height})')
                return {}
            if aspect_ratio < min_aspect_ratio or aspect_ratio > 1 / min_aspect_ratio:
                print(f'filtered because aspect ratio: ({width},{height})')
                return {}

        image_data = {}
        # The frame right after the given context is the generation target.
        sd_image = images[num_image_given + 1]
        if sd_image_transform is not None:
            # image_data['original_sizes'] = torch.tensor([height, width])
            width, height = sd_image.size  # size of the target frame, not the last loaded one
            sd_image_tensor = sd_image_transform(sd_image)
            target_size = sd_image_tensor.shape[-2]
            # calculate_new_dimensions returns (height, width), in that order
            target_height, target_width = calculate_new_dimensions(height=height, width=width, target_size=target_size)
            y1 = max(0, int(round((target_height - target_size) / 2.0)))
            x1 = max(0, int(round((target_width - target_size) / 2.0)))
            # image_data['crop_top_lefts'] = torch.tensor([y1, x1])
            image_data['time_ids'] = torch.tensor([height, width, y1, x1, target_size, target_size])
            image_data['sd_images'] = sd_image_tensor
        if image_transform is not None:
            for i in range(len(images)):
                images[i] = image_transform(images[i])
        images = torch.stack(images, dim=0)
    except Exception as e:
        print('Error while decoding image: ', e)
        return {}

    input_ids = []
    labels = []
    input_text = ''

    if system_message != '':
        if not system_message.endswith('\n'):
            system_message += '\n'
        input_text += system_message
        item_ids = tokenizer.encode(system_message, add_special_tokens=False)
        item_labels = [-100] * len(item_ids)
        input_ids.extend(item_ids)
        labels.extend(item_labels)

    captions_all = []
    for i in range(story_len):
        caption = value["captions"][i]
        captions_all.append(caption)

    image_cmp_tokens = BOI_TOKEN + ''.join(
        [IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
    image_gen_tokens = BOI_TOKEN + ''.join(
        [IMG_TOKEN.format(int(item)) for item in range(num_img_out_tokens)]) + EOI_TOKEN

    instruction = instruction_prompt.format_map({'instruction': captions_all[0] + image_cmp_tokens})
    for i in range(num_image_given):
        instruction = instruction + "[INST]" + captions_all[i + 1] + image_cmp_tokens
    response = "[INST]" + captions_all[num_image_given + 1] + image_gen_tokens
    images = images[:num_image_given + 2]
    # print(instruction)

    item_ids = tokenizer.encode(instruction, add_special_tokens=False)
    item_labels = [-100] * len(item_ids)
    input_text += instruction
    input_ids.extend(item_ids)
    labels.extend(item_labels)

    item_ids = tokenizer.encode(response, add_special_tokens=False)
    item_labels = item_ids
    input_text += response
    input_ids.extend(item_ids)
    labels.extend(item_labels)

    input_ids = [tokenizer.bos_token_id] + input_ids + [tokenizer.eos_token_id]
    attention_mask = [1] * len(input_ids)
    labels = [-100] + labels + [tokenizer.eos_token_id]

    boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
    eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
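    # The first 1 + num_image_given images condition the model (cmp); the final
    # image is the generation target (gen).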
    ids_cmp_mask = [False] * len(input_ids)
    ids_gen_mask = [False] * len(input_ids)
    embeds_cmp_mask = [True] + [True] * num_image_given + [False]
    embeds_gen_mask = [False] + [False] * num_image_given + [True]

    # print(len(input_ids))
    if len(input_ids) >= max_length:
        # input_ids = input_ids[:max_length]
        # attention_mask = attention_mask[:max_length]
        # labels = labels[:max_length]
        # ids_cmp_mask = ids_cmp_mask[:max_length]
        # ids_gen_mask = ids_gen_mask[:max_length]
        # print('An edit sample has been removed because of max length. input_text: ', input_text)
        return {}
    else:
        padding_length = max_length - len(input_ids)
        input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
        attention_mask = attention_mask + [0] * padding_length
        labels = labels + [-100] * padding_length
        ids_cmp_mask = ids_cmp_mask + [False] * padding_length
        ids_gen_mask = ids_gen_mask + [False] * padding_length

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    attention_mask = torch.tensor(attention_mask, dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)
    ids_cmp_mask = torch.tensor(ids_cmp_mask, dtype=torch.bool)
    ids_gen_mask = torch.tensor(ids_gen_mask, dtype=torch.bool)
    embeds_cmp_mask = torch.tensor(embeds_cmp_mask) if embeds_cmp_mask is not None else None
    embeds_gen_mask = torch.tensor(embeds_gen_mask) if embeds_gen_mask is not None else None
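    # Locate every <img> ... </img> span: the conditioning spans receive
    # ids_cmp_mask, while the last span is the generation target and is
    # excluded from the LM loss.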
    boi_idx = torch.where(input_ids == boi_token_id)[0].tolist()
    eoi_idx = torch.where(input_ids == eoi_token_id)[0].tolist()

    ids_cmp_mask[boi_idx[0] + 1:eoi_idx[0]] = True
    for i in range(num_image_given):
        ids_cmp_mask[boi_idx[i + 1] + 1:eoi_idx[i + 1]] = True
    ids_gen_mask[boi_idx[-1] + 1:eoi_idx[-1]] = True
    labels[boi_idx[-1] + 1:eoi_idx[-1] + 1] = -100

    ret = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'ids_gen_mask': ids_gen_mask,
        'ids_cmp_mask': ids_cmp_mask,
        'embeds_gen_mask': embeds_gen_mask,
        'embeds_cmp_mask': embeds_cmp_mask,
        'images': images,
        'text': input_text,
    }
    ret.update(image_data)

    return ret

def build_long_story_datapipe(data_dir,
                              image_dir,
                              tokenizer=None,
                              story_len=30,
                              max_length=77,
                              batch_size=None,
                              min_resolution=180,
                              image_transform=None,
                              sd_image_transform=None,
                              instruction_prompt='{instruction}',
                              turn_sep='\n',
                              system_message='',
                              min_aspect_ratio=0.666,
                              num_img_in_tokens=64,
                              num_img_out_tokens=64,
                              cycle_count=None):
    decode_partial = functools.partial(decode_long_story_data,
                                       image_dir=image_dir,
                                       tokenizer=tokenizer,
                                       story_len=story_len,
                                       image_transform=image_transform,
                                       sd_image_transform=sd_image_transform,
                                       max_length=max_length,
                                       instruction_prompt=instruction_prompt,
                                       turn_sep=turn_sep,
                                       system_message=system_message,
                                       min_resolution=min_resolution,
                                       min_aspect_ratio=min_aspect_ratio,
                                       num_img_in_tokens=num_img_in_tokens,
                                       num_img_out_tokens=num_img_out_tokens)
    filter_partial = functools.partial(filter_data_with_image_ids)

    if isinstance(data_dir, str):
        data_dir = list(braceexpand(data_dir))

    datapipe = dp.iter.FileLister(root=data_dir, masks='*.jsonl', recursive=True)
    datapipe = datapipe.shuffle()
    datapipe = datapipe.cycle(count=cycle_count)
    datapipe = datapipe.shuffle()
    # datapipe = dp.iter.FileLister(root=data_dir, masks='0000000.tar', recursive=True)
    datapipe = datapipe.sharding_filter()
    # datapipe = datapipe.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
    datapipe = datapipe.open_files(mode='r')
    datapipe = datapipe.parse_jsonl_files()
    datapipe = datapipe.map(decode_partial)
    datapipe = datapipe.filter(filter_partial)
    # datapipe = datapipe.shuffle(buffer_size=1024)
    if batch_size is not None:
        datapipe = datapipe.batch(batch_size)
        datapipe = datapipe.collate(single_turn_edit_collate)
    return datapipe
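
# Mix several datapipes by weighted sampling: each datapipe config is
# instantiated via hydra, then SampleMultiplexer draws from them in proportion
# to sample_weights.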
def build_multi_datapipes(datapipes, tokenizer=None, image_transform=None, sd_image_transform=None,
                          sample_weights=None):
    # assert concat_type in ['concat', 'mux_longest', 'sample']
    if sample_weights is None:
        sample_weights = [1] * len(datapipes)
    else:
        assert len(sample_weights) == len(datapipes)

    datapipes = [
        hydra.utils.instantiate(datapipe, tokenizer=tokenizer, image_transform=image_transform,
                                sd_image_transform=sd_image_transform) for datapipe in datapipes
    ]

    datasets_to_weights_dict = {}
    for dataset, sample_weight in zip(datapipes, sample_weights):
        datasets_to_weights_dict[dataset] = sample_weight
    datapipe = dp.iter.SampleMultiplexer(datasets_to_weights_dict)

    return datapipe
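
# Example usage (a minimal sketch; t2i_cfg and story_cfg are hypothetical hydra
# configs whose _target_ points at build_t2i_datapipe / build_long_story_datapipe):
#
#   datapipe = build_multi_datapipes([t2i_cfg, story_cfg],
#                                    tokenizer=tokenizer,
#                                    image_transform=image_transform,
#                                    sd_image_transform=sd_image_transform,
#                                    sample_weights=[2, 1])
#   for batch in datapipe:
#       ...  # roughly two t2i batches for every story batch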