jennysun's picture
Duplicate from gligen/demo
81ba850
raw
history blame
7.29 kB
import os
import os.path as op
import gc
import json
from typing import List
import logging
try:
from .blob_storage import BlobStorage, disk_usage
except:
class BlobStorage:
pass
def generate_lineidx(filein: str, idxout: str) -> None:
idxout_tmp = idxout + '.tmp'
with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
fsize = os.fstat(tsvin.fileno()).st_size
fpos = 0
while fpos != fsize:
tsvout.write(str(fpos) + "\n")
tsvin.readline()
fpos = tsvin.tell()
os.rename(idxout_tmp, idxout)
def read_to_character(fp, c):
result = []
while True:
s = fp.read(32)
assert s != ''
if c in s:
result.append(s[: s.index(c)])
break
else:
result.append(s)
return ''.join(result)
class TSVFile(object):
def __init__(self,
tsv_file: str,
if_generate_lineidx: bool = False,
lineidx: str = None,
class_selector: List[str] = None,
blob_storage: BlobStorage = None):
self.tsv_file = tsv_file
self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' \
if not lineidx else lineidx
self.linelist = op.splitext(tsv_file)[0] + '.linelist'
self.chunks = op.splitext(tsv_file)[0] + '.chunks'
self._fp = None
self._lineidx = None
self._sample_indices = None
self._class_boundaries = None
self._class_selector = class_selector
self._blob_storage = blob_storage
self._len = None
# the process always keeps the process which opens the file.
# If the pid is not equal to the currrent pid, we will re-open the file.
self.pid = None
# generate lineidx if not exist
if not op.isfile(self.lineidx) and if_generate_lineidx:
generate_lineidx(self.tsv_file, self.lineidx)
def __del__(self):
self.gcidx()
if self._fp:
self._fp.close()
# physically remove the tsv file if it is retrieved by BlobStorage
if self._blob_storage and 'azcopy' in self.tsv_file and os.path.exists(self.tsv_file):
try:
original_usage = disk_usage('/')
os.remove(self.tsv_file)
logging.info("Purged %s (disk usage: %.2f%% => %.2f%%)" %
(self.tsv_file, original_usage, disk_usage('/') * 100))
except:
# Known issue: multiple threads attempting to delete the file will raise a FileNotFound error.
# TODO: try Threadling.Lock to better handle the race condition
pass
def __str__(self):
return "TSVFile(tsv_file='{}')".format(self.tsv_file)
def __repr__(self):
return str(self)
def gcidx(self):
logging.debug('Run gc collect')
self._lineidx = None
self._sample_indices = None
#self._class_boundaries = None
return gc.collect()
def get_class_boundaries(self):
return self._class_boundaries
def num_rows(self, gcf=False):
if (self._len is None):
self._ensure_lineidx_loaded()
retval = len(self._sample_indices)
if (gcf):
self.gcidx()
self._len = retval
return self._len
def seek(self, idx: int):
self._ensure_tsv_opened()
self._ensure_lineidx_loaded()
try:
pos = self._lineidx[self._sample_indices[idx]]
except:
logging.info('=> {}-{}'.format(self.tsv_file, idx))
raise
self._fp.seek(pos)
return [s.strip() for s in self._fp.readline().split('\t')]
def seek_first_column(self, idx: int):
self._ensure_tsv_opened()
self._ensure_lineidx_loaded()
pos = self._lineidx[idx]
self._fp.seek(pos)
return read_to_character(self._fp, '\t')
def get_key(self, idx: int):
return self.seek_first_column(idx)
def __getitem__(self, index: int):
return self.seek(index)
def __len__(self):
return self.num_rows()
def _ensure_lineidx_loaded(self):
if self._lineidx is None:
logging.debug('=> loading lineidx: {}'.format(self.lineidx))
with open(self.lineidx, 'r') as fp:
lines = fp.readlines()
lines = [line.strip() for line in lines]
self._lineidx = [int(line) for line in lines]
# read the line list if exists
linelist = None
if op.isfile(self.linelist):
with open(self.linelist, 'r') as fp:
linelist = sorted(
[
int(line.strip())
for line in fp.readlines()
]
)
if op.isfile(self.chunks):
self._sample_indices = []
self._class_boundaries = []
class_boundaries = json.load(open(self.chunks, 'r'))
for class_name, boundary in class_boundaries.items():
start = len(self._sample_indices)
if class_name in self._class_selector:
for idx in range(boundary[0], boundary[1] + 1):
# NOTE: potentially slow when linelist is long, try to speed it up
if linelist and idx not in linelist:
continue
self._sample_indices.append(idx)
end = len(self._sample_indices)
self._class_boundaries.append((start, end))
else:
if linelist:
self._sample_indices = linelist
else:
self._sample_indices = list(range(len(self._lineidx)))
def _ensure_tsv_opened(self):
if self._fp is None:
if self._blob_storage:
self._fp = self._blob_storage.open(self.tsv_file)
else:
self._fp = open(self.tsv_file, 'r')
self.pid = os.getpid()
if self.pid != os.getpid():
logging.debug('=> re-open {} because the process id changed'.format(self.tsv_file))
self._fp = open(self.tsv_file, 'r')
self.pid = os.getpid()
class TSVWriter(object):
def __init__(self, tsv_file):
self.tsv_file = tsv_file
self.lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
self.tsv_file_tmp = self.tsv_file + '.tmp'
self.lineidx_file_tmp = self.lineidx_file + '.tmp'
self.tsv_fp = open(self.tsv_file_tmp, 'w')
self.lineidx_fp = open(self.lineidx_file_tmp, 'w')
self.idx = 0
def write(self, values, sep='\t'):
v = '{0}\n'.format(sep.join(map(str, values)))
self.tsv_fp.write(v)
self.lineidx_fp.write(str(self.idx) + '\n')
self.idx = self.idx + len(v)
def close(self):
self.tsv_fp.close()
self.lineidx_fp.close()
os.rename(self.tsv_file_tmp, self.tsv_file)
os.rename(self.lineidx_file_tmp, self.lineidx_file)