# SEED-Story/src/data/datapipes.py
import json
import os
import tarfile
import warnings
from io import BufferedIOBase
from typing import cast, IO, Iterator, Optional, Tuple, Dict

import torchdata.datapipes as dp
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe, TarArchiveLoader
from torchdata.datapipes.utils import StreamWrapper
from torchdata.datapipes.utils.common import validate_pathname_binary_tuple


@functional_datapipe("load_from_tar_wo_exception")
class TarArchiveLoaderWoException(TarArchiveLoader):
def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
for data in self.datapipe:
validate_pathname_binary_tuple(data)
pathname, data_stream = data
try:
if isinstance(data_stream, StreamWrapper) and isinstance(data_stream.file_obj, tarfile.TarFile):
tar = data_stream.file_obj
else:
reading_mode = (self.mode if hasattr(data_stream, "seekable") and data_stream.seekable() else
self.mode.replace(":", "|"))
# typing.cast is used here to silence mypy's type checker
tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode=reading_mode)
for tarinfo in tar:
if not tarinfo.isfile():
continue
extracted_fobj = tar.extractfile(tarinfo)
if extracted_fobj is None:
warnings.warn(f"failed to extract file {tarinfo.name} from source tarfile {pathname}")
raise tarfile.ExtractError
inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
yield inner_pathname, StreamWrapper(extracted_fobj, data_stream,
name=inner_pathname) # type: ignore[misc]
except Exception as e:
warnings.warn(f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!")
# raise e
finally:
if isinstance(data_stream, StreamWrapper):
data_stream.autoclose()
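

# Minimal usage sketch for the datapipe above. The shard directory, glob mask,
# and helper name are hypothetical; FileLister/FileOpener are standard torchdata
# datapipes, and ".load_from_tar_wo_exception()" is the functional form
# registered by the decorator on TarArchiveLoaderWoException.
def _example_tar_pipeline(shard_dir: str = "data/shards") -> IterDataPipe:
    from torchdata.datapipes.iter import FileLister, FileOpener

    shards = FileLister(shard_dir, masks="*.tar")
    streams = FileOpener(shards, mode="b")
    # Yields (inner_pathname, StreamWrapper) for every regular file inside every
    # readable shard; corrupted shards are skipped with a warning.
    return streams.load_from_tar_wo_exception()
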
@functional_datapipe("parse_jsonl_files")
class JsonlParserIterDataPipe(IterDataPipe[Tuple[str, Dict]]):
def __init__(self, source_datapipe: IterDataPipe[Tuple[str, IO]], **kwargs) -> None:
self.source_datapipe: IterDataPipe[Tuple[str, IO]] = source_datapipe
self.kwargs = kwargs
def __iter__(self) -> Iterator[Tuple[str, Dict]]:
for file_name, stream in self.source_datapipe:
for idx, line in enumerate(stream):
if line.strip() != '':
try:
yield f'{file_name}_line{idx}', json.loads(line)
except Exception as e:
warnings.warn(f"Error occured when parsing string to json due to: {e} abort!")