# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Any, Dict, List

from fairseq import metrics, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from torch.nn.modules.loss import _Loss


class FairseqCriterion(_Loss):
    """Base class for criterions (losses) that are constructed from a task.

    Subclasses implement :meth:`forward` and should implement
    :meth:`reduce_metrics`. The remaining methods are class-level helpers for
    argument registration and construction, so they are callable on the class
    itself without an instance.
    """

    def __init__(self, task):
        """Store *task* and derive the padding index from its target dictionary.

        ``padding_idx`` is only assigned when *task* has a
        ``target_dictionary`` attribute: it is ``target_dictionary.pad()``,
        or ``-100`` when that dictionary is ``None``.
        """
        super().__init__()
        self.task = task
        if hasattr(task, "target_dictionary"):
            tgt_dict = task.target_dictionary
            self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100

    @classmethod
    def add_args(cls, parser):
        """Add criterion-specific arguments to the parser."""
        # The lookup uses the literal attribute name "__dataclass"; nothing is
        # added when the class has no such attribute (or it is None).
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @classmethod
    def build_criterion(cls, cfg: FairseqDataclass, task):
        """Construct a criterion by matching ``__init__`` arguments to *cfg*.

        Each constructor parameter is resolved in this order: ``task`` gets
        *task*, ``cfg`` gets *cfg*, a name that exists as an attribute on
        *cfg* gets that attribute's value, and a parameter with a default is
        left to its default.

        Raises:
            NotImplementedError: if the constructor uses positional-only,
                ``*args`` or ``**kwargs`` parameters, or if a required
                parameter cannot be resolved by the rules above.
        """
        # Criterions can override this, but for convenience we also try to
        # automatically map cfg attributes to the corresponding
        # arguments in the __init__.
        init_args = {}
        for p in inspect.signature(cls).parameters.values():
            if (
                p.kind == p.POSITIONAL_ONLY
                or p.kind == p.VAR_POSITIONAL
                or p.kind == p.VAR_KEYWORD
            ):
                # we haven't implemented inference for these argument types,
                # but PRs welcome :)
                raise NotImplementedError("{} not supported".format(p.kind))

            assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}

            if p.name == "task":
                init_args["task"] = task
            elif p.name == "cfg":
                init_args["cfg"] = cfg
            elif hasattr(cfg, p.name):
                init_args[p.name] = getattr(cfg, p.name)
            elif p.default is not p.empty:
                # `empty` is a sentinel, so compare by identity: a default
                # with an overloaded `!=` (e.g. an array) must not be
                # evaluated for truthiness here.
                pass  # we'll use the default value
            else:
                raise NotImplementedError(
                    "Unable to infer Criterion arguments, please implement "
                    "{}.build_criterion".format(cls.__name__)
                )
        return cls(**init_args)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        raise NotImplementedError

    @staticmethod
    def aggregate_logging_outputs(
        logging_outputs: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Aggregate logging outputs from data parallel training.

        Deprecated: the base implementation emits a deprecation warning and
        then raises ``NotImplementedError``.
        """
        utils.deprecation_warning(
            "The aggregate_logging_outputs API is deprecated. "
            "Please use the reduce_metrics API instead."
        )
        raise NotImplementedError

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        """Aggregate logging outputs from data parallel training.

        The base implementation is a fallback: it warns, delegates to the
        deprecated :meth:`aggregate_logging_outputs`, and logs every
        aggregated value as a scalar except the bookkeeping keys
        ``nsentences``, ``ntokens`` and ``sample_size``.
        """
        utils.deprecation_warning(
            "Criterions should implement the reduce_metrics API. "
            "Falling back to deprecated aggregate_logging_outputs API."
        )
        agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
        for k, v in agg_logging_outputs.items():
            if k in {"nsentences", "ntokens", "sample_size"}:
                continue
            metrics.log_scalar(k, v)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
class LegacyFairseqCriterion(FairseqCriterion):
    """Deprecated criterion base class that keeps the whole ``args`` object.

    Instantiating it emits a deprecation warning asking authors to extend
    ``FairseqCriterion`` with explicit constructor arguments instead.
    """

    def __init__(self, args, task):
        """Initialise the base criterion with *task* and store *args*."""
        super().__init__(task=task)
        self.args = args

        utils.deprecation_warning(
            "Criterions should take explicit arguments instead of an "
            "argparse.Namespace object, please update your criterion by "
            "extending FairseqCriterion instead of LegacyFairseqCriterion."
        )

    @classmethod
    def build_criterion(cls, args, task):
        """Construct a criterion from command-line args."""
        # Legacy criterions receive the args object as-is; no signature
        # inspection is performed.
        return cls(args, task)