# NOTE(review): the lines that were here ("Wmcs91's picture", "Duplicate from
# shi-labs/Versatile-Diffusion", "fb53ec8", "raw / history blame contribute
# delete", "No virus", "9.23 kB") were HuggingFace file-viewer page chrome
# accidentally captured with the source, not Python code. Converted to this
# comment so the module remains importable; original provenance:
# duplicated from shi-labs/Versatile-Diffusion, revision fb53ec8.
import copy
import hashlib
import itertools
import os
import os.path as osp
import pickle
import random
from multiprocessing import shared_memory

import numpy as np
import numpy.random as npr
import torch
import torch.distributed as dist
import torchvision

from ... import sync
from ...cfg_holder import cfg_unique_holder as cfguh
from ...log_service import print_log
class ds_base(torch.utils.data.Dataset):
    """Config-driven dataset base class.

    Subclasses implement ``init_load_info`` to populate ``self.load_info``
    with per-sample dicts (each carrying at least a ``'unique_id'`` key and,
    when ``cfg.pick`` is used, a ``'filename'`` key).  Each sample is then
    produced by the pipeline ``loader -> transforms -> formatter``.

    Optional behaviour controlled by ``cfg`` (all read via ``getattr``):

    * ``load_info_order_by``: sort field; ``'<field>|reverse'`` sorts
      descending, a bare ``'<field>'`` ascending, ``'default'`` sorts by
      ``'unique_id'``.
    * ``load_info_add_idx``: write each record's position into ``info['idx']``.
    * ``try_sample``: an int N (keep the first N samples) or a (start, end)
      pair — a debugging aid.
    * ``repeat``: virtual length multiplier (see ``__len__``).
    * ``pick``: whitelist of filenames.
    * ``cache_sm`` / ``cache_pct``: preload the first fraction of samples,
      placing them in POSIX shared memory shared across local ranks.
    """

    def __init__(self,
                 cfg,
                 loader=None,
                 estimator=None,
                 transforms=None,
                 formatter=None):
        self.cfg = cfg
        self.load_info = None
        self.init_load_info()  # subclass hook; fills self.load_info
        self.loader = loader
        self.transforms = transforms
        self.formatter = formatter

        if self.load_info is not None:
            load_info_order_by = getattr(self.cfg, 'load_info_order_by', 'default')
            if load_info_order_by == 'default':
                self.load_info = sorted(self.load_info, key=lambda x: x['unique_id'])
            else:
                # '<field>|reverse' sorts descending; bare '<field>' ascending.
                try:
                    load_info_order_by, reverse = load_info_order_by.split('|')
                    reverse = reverse == 'reverse'
                except ValueError:
                    # no (or more than one) '|' separator
                    # (was a bare `except:` — narrowed)
                    reverse = False
                self.load_info = sorted(
                    self.load_info, key=lambda x: x[load_info_order_by], reverse=reverse)

        load_info_add_idx = getattr(self.cfg, 'load_info_add_idx', True)
        if (self.load_info is not None) and load_info_add_idx:
            for idx, info in enumerate(self.load_info):
                info['idx'] = idx

        if estimator is not None:
            self.load_info = estimator(self.load_info)

        # try_sample: a (start, end) pair, or a plain int meaning "first N".
        self.try_sample = getattr(self.cfg, 'try_sample', None)
        if self.try_sample is not None:
            try:
                start, end = self.try_sample
            except (TypeError, ValueError):
                # not unpackable as a pair -> treat as an int upper bound
                # (was a bare `except:` — narrowed)
                start, end = 0, self.try_sample
            self.load_info = self.load_info[start:end]

        self.repeat = getattr(self.cfg, 'repeat', 1)

        pick = getattr(self.cfg, 'pick', None)
        if pick is not None:
            self.load_info = [i for i in self.load_info if i['filename'] in pick]

        #########
        # cache #
        #########
        self.cache_sm = getattr(self.cfg, 'cache_sm', False)
        self.cache_cnt = 0
        if self.cache_sm:
            self.cache_pct = getattr(self.cfg, 'cache_pct', 0)
            # One random id synchronized across the node so every local rank
            # derives identical shared-memory segment names.
            cache_unique_id = sync.nodewise_sync().random_sync_id()
            self.cache_unique_id = hashlib.sha256(pickle.dumps(cache_unique_id)).hexdigest()
            self.__cache__(self.cache_pct)

        #######
        # log #
        #######
        if self.load_info is not None:
            console_info = '{}: '.format(self.__class__.__name__)
            console_info += 'total {} unique images, '.format(len(self.load_info))
            console_info += 'total {} unique sample. Cached {}. Repeat {} times.'.format(
                len(self.load_info), self.cache_cnt, self.repeat)
        else:
            console_info = '{}: load_info not ready.'.format(self.__class__.__name__)
        print_log(console_info)

    def init_load_info(self):
        # implemented by subclass: must populate self.load_info
        pass

    def __len__(self):
        # Virtual length: every unique sample is visited `repeat` times.
        return len(self.load_info)*self.repeat

    def __cache__(self, pct):
        """Preload the first ``pct`` fraction of samples.

        Without shared memory, loaded records simply replace their
        ``load_info`` entries.  With ``cache_sm``, each record is pickled into
        a named SharedMemory segment and the entry becomes the segment *name*
        (a str), which ``__getitem__`` resolves back; loading work is striped
        across local ranks.

        NOTE(review): relies on ``self.local_rank`` / ``self.local_world_size``
        which are never assigned in this class — presumably set by a subclass
        or the sync machinery before caching runs; confirm against callers.
        """
        if pct == 0:
            self.cache_cnt = 0
            return
        self.cache_cnt = int(len(self.load_info)*pct)
        if not self.cache_sm:
            for i in range(self.cache_cnt):
                self.load_info[i] = self.loader(self.load_info[i])
            return
        for i in range(self.cache_cnt):
            shm_name = str(self.load_info[i]['unique_id']) + '_' + self.cache_unique_id
            if i % self.local_world_size == self.local_rank:
                # This rank owns sample i: load, pickle, publish to shm.
                data = pickle.dumps(self.loader(self.load_info[i]))
                datan = len(data)
                # self.print_smname_to_file(shm_name)
                shm = shared_memory.SharedMemory(
                    name=shm_name, create=True, size=datan)
                shm.buf[0:datan] = data[0:datan]
                shm.close()
                self.load_info[i] = shm_name
            else:
                # Another local rank owns this sample; just record the name.
                self.load_info[i] = shm_name
        dist.barrier()

    def __getitem__(self, idx):
        idx = idx % len(self.load_info)  # fold the `repeat` expansion back
        element = self.load_info[idx]
        if isinstance(element, str):
            # Shared-memory cache hit: the entry is the segment name.
            shm = shared_memory.SharedMemory(name=element)
            element = pickle.loads(shm.buf)
            shm.close()
        else:
            # Copy so loader/transforms cannot mutate the master record.
            # (BUGFIX: the original deep-copied this path twice — once
            # unconditionally and once again in the else branch.)
            element = copy.deepcopy(element)
        element['load_info_ptr'] = self.load_info
        if idx >= self.cache_cnt:
            # Not cached: run the loader now.
            element = self.loader(element)
        if self.transforms is not None:
            element = self.transforms(element)
        if self.formatter is not None:
            return self.formatter(element)
        else:
            return element

    def __del__(self):
        # Unlink shared-memory segments; only local rank 0 unlinks.
        # BUGFIX: guard against a partially-constructed instance — the
        # original iterated load_info unconditionally and raised TypeError
        # from __del__ when it was still None.
        if getattr(self, 'load_info', None) is None:
            return
        for infoi in self.load_info:
            if isinstance(infoi, str) and (self.local_rank == 0):
                shm = shared_memory.SharedMemory(name=infoi)
                shm.close()
                shm.unlink()

    def print_smname_to_file(self, smname):
        """Append a shared-memory segment name to a file next to the run log."""
        try:
            log_file = cfguh().cfg.train.log_file
        except Exception:
            # fall back to the eval config (was a bare `except:`)
            try:
                log_file = cfguh().cfg.eval.log_file
            except Exception:
                raise ValueError
        # a trick to use the log_file path
        sm_file = log_file.replace('.log', '.smname')
        with open(sm_file, 'a') as f:
            f.write(smname + '\n')
def singleton(class_):
    """Decorator: make ``class_`` lazily single-instanced.

    The first call to the returned factory constructs the instance with the
    given arguments; every later call returns that same instance (any new
    arguments are ignored).
    """
    cache = {}

    def getinstance(*args, **kwargs):
        try:
            return cache[class_]
        except KeyError:
            cache[class_] = class_(*args, **kwargs)
            return cache[class_]

    return getinstance
from .ds_loader import get_loader
from .ds_transform import get_transform
from .ds_estimator import get_estimator
from .ds_formatter import get_formatter
@singleton
class get_dataset(object):
    """Singleton factory/registry mapping a dataset type name to its class.

    Dataset classes are added via ``register`` (see the ``register()``
    decorator factory defined below in this module); ``__call__`` builds the
    dataset described by a config object.
    """

    def __init__(self):
        # name -> dataset class; filled by register()
        self.dataset = {}

    def register(self, ds):
        # Keyed by class name; a later registration with the same name wins.
        self.dataset[ds.__name__] = ds

    def __call__(self, cfg):
        """Instantiate the dataset named by ``cfg.type``.

        Returns None when ``cfg`` or ``cfg.type`` is None; raises ValueError
        for an unknown type.  The ``from .. import ds_*`` statements below are
        presumably imports-for-side-effect: importing the module triggers its
        ``@register()`` decorators, populating ``self.dataset`` before the
        lookup at the end — TODO(review) confirm against the ds_* modules.
        """
        if cfg is None:
            return None
        t = cfg.type
        if t is None:
            return None
        elif t in ['laion2b', 'laion2b_dummy',
                   'laion2b_webdataset',
                   'laion2b_webdataset_sdofficial', ]:
            from .. import ds_laion2b
        elif t in ['coyo', 'coyo_dummy',
                   'coyo_webdataset', ]:
            from .. import ds_coyo_webdataset
        elif t in ['laionart', 'laionart_dummy',
                   'laionart_webdataset', ]:
            from .. import ds_laionart
        elif t in ['celeba']:
            from .. import ds_celeba
        elif t in ['div2k']:
            from .. import ds_div2k
        elif t in ['pafc']:
            from .. import ds_pafc
        elif t in ['coco_caption']:
            from .. import ds_coco
        else:
            raise ValueError
        # Sub-component factories (also singletons) map each cfg section —
        # or None — to the corresponding pipeline stage callable.
        loader = get_loader()(cfg.get('loader', None))
        transform = get_transform()(cfg.get('transform', None))
        estimator = get_estimator()(cfg.get('estimator', None))
        formatter = get_formatter()(cfg.get('formatter', None))
        return self.dataset[t](
            cfg, loader, estimator,
            transform, formatter)
def register():
    """Return a class decorator that adds the decorated dataset class to the
    global ``get_dataset`` registry and hands the class back unchanged.

    Usage: ``@register()`` above a dataset class definition.
    """
    def _decorator(cls):
        get_dataset().register(cls)
        return cls
    return _decorator
# some other helpers
class collate(object):
    """
    Modified from torch.utils.data._utils.collate.
    It handles lists differently from the default:
    list fields are concatenated (appended) across the batch instead of
    being stacked into a tensor.
    """
    def __init__(self):
        # stock torch collate, used for every non-list field
        self.default_collate = \
            torch.utils.data._utils.collate.default_collate

    def __call__(self, batch):
        """
        Args:
            batch: [data, data] -or- [(data1, data2, ...), (data1, data2, ...)]
            This function will not be used as induction function

        Returns the default collation for a flat batch, otherwise a list with
        one collated entry per field.

        Raises ValueError if a list field's lists are not singletons.
        """
        elem = batch[0]
        # BUGFIX: the original read `if not (elem, (tuple, list))` — a
        # non-empty tuple is always truthy, so this early return was dead and
        # flat (non-tuple) batches crashed in zip(*batch) below.
        if not isinstance(elem, (tuple, list)):
            return self.default_collate(batch)
        rv = []
        # transposed: iterate per-field tuples across the batch
        for i in zip(*batch):
            if isinstance(i[0], list):
                if len(i[0]) != 1:
                    raise ValueError
                try:
                    # collate each singleton list's payload, then re-wrap
                    i = [[self.default_collate(ii).squeeze(0)] for ii in i]
                except Exception:
                    # payload not collatable (e.g. ragged/mixed): keep raw
                    # lists (was a bare `except:` — narrowed)
                    pass
                rvi = list(itertools.chain.from_iterable(i))
                rv.append(rvi)  # list concat
            else:
                rv.append(self.default_collate(i))
        return rv