|
from dataclasses import dataclass, field
|
|
from .. import models
|
|
|
|
@dataclass
|
|
class RetroDataModelArguments:
|
|
pass
|
|
|
|
@dataclass
|
|
class DataArguments(RetroDataModelArguments):
|
|
max_seq_length: int = field(
|
|
default=512,
|
|
metadata={
|
|
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
|
"than this will be truncated, sequences shorter will be padded."
|
|
},
|
|
)
|
|
max_answer_length: int = field(
|
|
default=30,
|
|
metadata={
|
|
"help": "Maximum length of an answer (in tokens) to be generated. This is not "
|
|
"a hard limit but the model's internal length limit."
|
|
},
|
|
)
|
|
doc_stride: int = field(
|
|
default=128,
|
|
metadata={
|
|
"help": "When splitting up a long document into chunks, how much stride to take between chunks."
|
|
},
|
|
)
|
|
return_token_type_ids: bool = field(
|
|
default=True,
|
|
metadata={
|
|
"help": "Whether to return token type ids."
|
|
},
|
|
)
|
|
pad_to_max_length: bool = field(
|
|
default=True,
|
|
metadata={
|
|
"help": "Whether to pad all samples to `max_seq_length`. "
|
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
|
|
"be faster on GPU but will be slower on TPU)."
|
|
},
|
|
)
|
|
preprocessing_num_workers: int = field(
|
|
default=5,
|
|
metadata={
|
|
"help": "The number of processes to use for the preprocessing."
|
|
},
|
|
)
|
|
overwrite_cache: bool = field(
|
|
default=False,
|
|
metadata={
|
|
"help": "Overwrite the cached training and evaluation sets"
|
|
},
|
|
)
|
|
version_2_with_negative: bool = field(
|
|
default=True,
|
|
metadata={
|
|
"help": ""
|
|
},
|
|
)
|
|
null_score_diff_threshold: float = field(
|
|
default=0.0,
|
|
metadata={
|
|
"help": "If null_score - best_non_null is greater than the threshold predict null."
|
|
},
|
|
)
|
|
rear_threshold: float = field(
|
|
default=0.0,
|
|
metadata={
|
|
"help": "Rear threshold."
|
|
},
|
|
)
|
|
n_best_size: int = field(
|
|
default=20,
|
|
metadata={
|
|
"help": "The total number of n-best predictions to generate when looking for an answer."
|
|
},
|
|
)
|
|
use_choice_logits: bool = field(
|
|
default=False,
|
|
metadata={
|
|
"help": "Whether to use choice logits."
|
|
},
|
|
)
|
|
start_n_top: int = field(
|
|
default=-1,
|
|
metadata={
|
|
"help": ""
|
|
},
|
|
)
|
|
end_n_top: int = field(
|
|
default=-1,
|
|
metadata={
|
|
"help": ""
|
|
},
|
|
)
|
|
beta1: int = field(
|
|
default=1,
|
|
metadata={
|
|
"help": ""
|
|
},
|
|
)
|
|
beta2: int = field(
|
|
default=1,
|
|
metadata={
|
|
"help": ""
|
|
},
|
|
)
|
|
best_cof: int = field(
|
|
default=1,
|
|
metadata={
|
|
"help": ""
|
|
},
|
|
)
|
|
|
|
@dataclass
|
|
class ModelArguments(RetroDataModelArguments):
|
|
use_auth_token: bool = field(
|
|
default=False,
|
|
metadata={
|
|
|
|
|
|
"help": ""
|
|
},
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class SketchModelArguments(ModelArguments):
|
|
sketch_revision: str = field(
|
|
default="main",
|
|
metadata={
|
|
"help": "The revision of the pretrained sketch model."
|
|
},
|
|
)
|
|
sketch_model_name: str = field(
|
|
default="monologg/koelectra-small-v3-discriminator",
|
|
metadata={
|
|
"help": "The name of the pretrained sketch model."
|
|
},
|
|
)
|
|
sketch_model_mode: str = field(
|
|
default="finetune",
|
|
metadata={
|
|
"help": "Choices = ['finetune', 'transfer']"
|
|
},
|
|
)
|
|
sketch_tokenizer_name: str = field(
|
|
default=None,
|
|
metadata={
|
|
"help": "The name of the pretrained sketch tokenizer."
|
|
},
|
|
)
|
|
sketch_architectures: str = field(
|
|
default="ElectraForSequenceClassification",
|
|
metadata={
|
|
"help": ""
|
|
},
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class IntensiveModelArguments(ModelArguments):
|
|
intensive_revision: str = field(
|
|
default="main",
|
|
metadata={
|
|
"help": "The revision of the pretrained intensive model."
|
|
},
|
|
)
|
|
intensive_model_name: str = field(
|
|
default="monologg/koelectra-base-v3-discriminator",
|
|
metadata={
|
|
"help": "The name of the pretrained intensive model."
|
|
},
|
|
)
|
|
intensive_model_mode: str = field(
|
|
default="finetune",
|
|
metadata={
|
|
"help": "Choices = ['finetune', 'transfer']"
|
|
},
|
|
)
|
|
intensive_tokenizer_name: str = field(
|
|
default=None,
|
|
metadata={
|
|
"help": "The name of the pretrained intensive tokenizer."
|
|
},
|
|
)
|
|
intensive_architectures: str = field(
|
|
default="ElectraForQuestionAnsweringAVPool",
|
|
metadata={
|
|
"help": ""
|
|
},
|
|
)
|
|
|
|
@dataclass
|
|
class RetroArguments(DataArguments, SketchModelArguments, IntensiveModelArguments):
|
|
def __post_init__(self):
|
|
|
|
model_cls = getattr(models, self.sketch_architectures, None)
|
|
if model_cls is None:
|
|
raise ValueError(f"The sketch architecture '{self.sketch_architectures}' is not supported.")
|
|
|
|
self.sketch_model_cls = model_cls
|
|
self.sketch_model_type = model_cls.model_type
|
|
if self.sketch_tokenizer_name is None:
|
|
self.sketch_tokenizer_name = self.sketch_model_name
|
|
|
|
|
|
model_cls = getattr(models, self.intensive_architectures, None)
|
|
if model_cls is None:
|
|
raise AttributeError
|
|
self.intensive_model_cls = model_cls
|
|
self.intensive_model_type = model_cls.model_type
|
|
|
|
|
|
if self.intensive_tokenizer_name is None:
|
|
self.intensive_tokenizer_name = self.intensive_model_name
|
|
|
|
|