# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A common dataset reader.""" | |
from typing import Any, Callable, List, Optional | |
import tensorflow as tf | |
import tensorflow_datasets as tfds | |
from official.modeling.hyperparams import config_definitions as cfg | |


class InputReader:
  """Input reader that returns a tf.data.Dataset instance."""

  def __init__(self,
               params: cfg.DataConfig,
               shards: Optional[List[str]] = None,
               dataset_fn=tf.data.TFRecordDataset,
               decoder_fn: Optional[Callable[..., Any]] = None,
               parser_fn: Optional[Callable[..., Any]] = None,
               dataset_transform_fn: Optional[Callable[[tf.data.Dataset],
                                                       tf.data.Dataset]] = None,
               postprocess_fn: Optional[Callable[..., Any]] = None):
"""Initializes an InputReader instance. | |
Args: | |
params: A config_definitions.DataConfig object. | |
shards: A list of files to be read. If given, read from these files. | |
Otherwise, read from params.input_path. | |
dataset_fn: A `tf.data.Dataset` that consumes the input files. For | |
example, it can be `tf.data.TFRecordDataset`. | |
decoder_fn: An optional `callable` that takes the serialized data string | |
and decodes them into the raw tensor dictionary. | |
parser_fn: An optional `callable` that takes the decoded raw tensors dict | |
and parse them into a dictionary of tensors that can be consumed by the | |
model. It will be executed after decoder_fn. | |
dataset_transform_fn: An optional `callable` that takes a | |
`tf.data.Dataset` object and returns a `tf.data.Dataset`. It will be | |
executed after parser_fn. | |
postprocess_fn: A optional `callable` that processes batched tensors. It | |
will be executed after batching. | |
""" | |
    if params.input_path and params.tfds_name:
      raise ValueError('At most one of `input_path` and `tfds_name` can be '
                       'specified, but got %s and %s.' % (
                           params.input_path, params.tfds_name))
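
    # The reader consumes exactly one of three sources: an explicit list of
    # shard files, file patterns from `params.input_path`, or a TFDS dataset
    # named by `params.tfds_name`.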
    self._shards = shards
    self._tfds_builder = None
    if self._shards:
      self._num_files = len(self._shards)
    elif not params.tfds_name:
      self._input_patterns = params.input_path.strip().split(',')
      self._num_files = 0
      for input_pattern in self._input_patterns:
        input_pattern = input_pattern.strip()
        if not input_pattern:
          continue
        matched_files = tf.io.gfile.glob(input_pattern)
        if not matched_files:
          raise ValueError('%s does not match any files.' % input_pattern)
        else:
          self._num_files += len(matched_files)
      if self._num_files == 0:
        raise ValueError('%s does not match any files.' % params.input_path)
    else:
      if not params.tfds_split:
        raise ValueError(
            '`tfds_name` is %s, but `tfds_split` is not specified.' %
            params.tfds_name)
      self._tfds_builder = tfds.builder(
          params.tfds_name, data_dir=params.tfds_data_dir)

    self._global_batch_size = params.global_batch_size
    self._is_training = params.is_training
    self._drop_remainder = params.drop_remainder
    self._shuffle_buffer_size = params.shuffle_buffer_size
    self._cache = params.cache
    self._cycle_length = params.cycle_length
    self._block_length = params.block_length
    self._sharding = params.sharding
    self._examples_consume = params.examples_consume
    self._tfds_split = params.tfds_split
    self._tfds_download = params.tfds_download
    self._tfds_as_supervised = params.tfds_as_supervised
    self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
    self._parser_fn = parser_fn
    self._dataset_transform_fn = dataset_transform_fn
    self._postprocess_fn = postprocess_fn

  def _read_sharded_files(
      self,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Reads a dataset from sharded files."""
    # Read from `self._shards` if it is provided.
    if self._shards:
      dataset = tf.data.Dataset.from_tensor_slices(self._shards)
    else:
      dataset = tf.data.Dataset.list_files(
          self._input_patterns, shuffle=self._is_training)

    if self._sharding and input_context and (
        input_context.num_input_pipelines > 1):
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    if self._is_training:
      dataset = dataset.repeat()
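
    # Interleave records from `cycle_length` files at a time so that reads
    # from different shards proceed in parallel; AUTOTUNE picks the degree of
    # parallelism.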
    dataset = dataset.interleave(
        map_func=self._dataset_fn,
        cycle_length=self._cycle_length,
        block_length=self._block_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset

  def _read_single_file(
      self,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Reads a dataset from a single file."""
    # Read from `self._shards` if it is provided.
    dataset = self._dataset_fn(self._shards or self._input_patterns)

    # When the input is a single file, disable auto-sharding so that the same
    # input file is sent to all workers.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF)
    dataset = dataset.with_options(options)
    if self._sharding and input_context and (
        input_context.num_input_pipelines > 1):
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    if self._is_training:
      dataset = dataset.repeat()
    return dataset

  def _read_tfds(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Reads a dataset from tfds."""
    if self._tfds_download:
      self._tfds_builder.download_and_prepare()
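
    # Pass the per-worker `input_context` through `tfds.ReadConfig` so that
    # TFDS shards the source files across input pipelines.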
    read_config = tfds.ReadConfig(
        interleave_cycle_length=self._cycle_length,
        interleave_block_length=self._block_length,
        input_context=input_context)

    decoders = {}
    if self._tfds_skip_decoding_feature:
      for skip_feature in self._tfds_skip_decoding_feature.split(','):
        decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
    dataset = self._tfds_builder.as_dataset(
        split=self._tfds_split,
        shuffle_files=self._is_training,
        as_supervised=self._tfds_as_supervised,
        decoders=decoders,
        read_config=read_config)
    return dataset

  def tfds_info(self) -> tfds.core.DatasetInfo:
    """Returns TFDS dataset info, if available."""
    if self._tfds_builder:
      return self._tfds_builder.info
    else:
      raise ValueError('tfds_info is not available, because the dataset '
                       'is not loaded from tfds.')

  def read(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object."""
    if self._tfds_builder:
      dataset = self._read_tfds(input_context)
    elif self._num_files > 1:
      dataset = self._read_sharded_files(input_context)
    else:
      assert self._num_files == 1
      dataset = self._read_single_file(input_context)
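
    # Cache the raw records in memory so that later epochs skip file I/O;
    # only useful when the dataset fits in memory.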
    if self._cache:
      dataset = dataset.cache()

    if self._is_training:
      dataset = dataset.shuffle(self._shuffle_buffer_size)

    if self._examples_consume > 0:
      dataset = dataset.take(self._examples_consume)
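
    # Decode and parse examples in parallel; each map is skipped when the
    # corresponding callable was not provided.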
    def maybe_map_fn(dataset, fn):
      return dataset if fn is None else dataset.map(
          fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    dataset = maybe_map_fn(dataset, self._decoder_fn)
    dataset = maybe_map_fn(dataset, self._parser_fn)

    if self._dataset_transform_fn is not None:
      dataset = self._dataset_transform_fn(dataset)
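
    # Each input pipeline batches to the per-replica size so that the combined
    # batch across replicas matches `params.global_batch_size`.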
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._global_batch_size) if input_context else self._global_batch_size
    dataset = dataset.batch(
        per_replica_batch_size, drop_remainder=self._drop_remainder)

    dataset = maybe_map_fn(dataset, self._postprocess_fn)

    return dataset.prefetch(tf.data.experimental.AUTOTUNE)
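

# A minimal usage sketch (illustrative only): build a reader from a DataConfig
# and distribute the resulting dataset with a tf.distribute strategy. The field
# values and the `decode_fn`/`parse_fn` callables below are hypothetical
# placeholders, not part of this module.
#
#   params = cfg.DataConfig(
#       input_path='/path/to/train-*.tfrecord',
#       global_batch_size=64,
#       is_training=True,
#       shuffle_buffer_size=1000)
#   reader = InputReader(params, decoder_fn=decode_fn, parser_fn=parse_fn)
#   dataset = strategy.experimental_distribute_datasets_from_function(
#       reader.read)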