aa / ovis /train /arguments.py
root
Ajout du module Ovis
b4942cf
raw
history blame
2 kB
from dataclasses import dataclass, field
from typing import Optional
import transformers
@dataclass
class ModelArguments:
    """Arguments configuring the Ovis model: the backbone LLM and the visual
    tokenizer that converts image features into discrete visual tokens.

    Fields annotated ``Optional[str]`` default to ``None``; some of them
    (e.g. ``visual_tokenizer_type``) are presumably required at runtime even
    though the parser does not enforce it — confirm against the trainer.
    """

    # Path or hub id of the backbone language model.
    llm_name_or_path: Optional[str] = field(default=None)
    # Visual tokenizer backbone identifier. Annotated Optional to match the
    # None default (was `str = field(default=None)`, which lies to type
    # checkers and to HfArgumentParser).
    visual_tokenizer_type: Optional[str] = field(default=None)
    # Size of the discrete visual vocabulary produced by the tokenizer.
    visual_vocab_size: int = field(default=8192)
    # Drop the [CLS] token from the visual backbone's output sequence.
    visual_drop_cls_token: bool = field(default=False)
    # Function mapping visual logits to token probabilities (default softmax).
    visual_tokenize_function: str = field(default='softmax')
    # Temperature applied by the visual tokenize function.
    visual_tau: float = field(default=1.0)
    # Which backbone depths/layers to tap for visual features; format is
    # backend-specific — verify against the visual tokenizer implementation.
    visual_depths: Optional[str] = field(default=None)
    # Stride for spatially merging visual hidden states (1 = no merging).
    visual_hidden_stride: int = field(default=1)
    # Maximum combined (text + visual) sequence length.
    multimodal_max_length: int = field(default=2048)
    # Name of the conversation formatter class; Optional to match the None
    # default (was annotated plain `str`).
    conversation_formatter_class: Optional[str] = field(default=None)
    # Explicit pad token id override for the tokenizer/model.
    pad_token_id: Optional[int] = field(default=None)
    # Attention implementation passed to the LLM (e.g. an SDPA/flash variant).
    llm_attn_implementation: Optional[str] = field(default=None)
    # Disable weight tying between input embeddings and the LM head.
    disable_tie_weight: bool = field(default=False)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Training arguments for Ovis, extending ``transformers.TrainingArguments``
    with dataset selection, staged-training controls, and visual-tokenizer
    schedule options."""

    # '|'-separated list of dataset names, e.g. "a|b|c".
    dataset_names: Optional[str] = field(default=None)
    # Key selecting which dataset-info registry to use.
    dataset_info: Optional[str] = field(default='dataset_info_v1_6')
    # Path to a pretrained Ovis checkpoint to resume/finetune from.
    ovis_pretrained_path: Optional[str] = field(default=None)
    # Path to pretrained visual tokenizer weights.
    visual_tokenizer_pretrained_path: Optional[str] = field(default=None)
    # Template used to wrap captions during caption-style training.
    caption_template: Optional[str] = field(default=None)
    # Training stage index; stages < 3 skip safetensors saving (see
    # __post_init__). Presumably required despite the None default — confirm.
    stage: Optional[int] = field(default=None)
    # '|'-separated names of modules to train; others stay frozen — verify
    # exact semantics against the trainer.
    train_modules: Optional[str] = field(default=None)
    # Cache directory for downloaded models/datasets.
    cache_dir: Optional[str] = field(default=None)
    # Optimizer name (HF TrainingArguments convention).
    optim: str = field(default="adamw_torch")
    # Upper/lower bounds for the visual tokenizer temperature schedule.
    visual_max_tau: float = field(default=5.0)
    visual_min_tau: float = field(default=0.05)
    # Save checkpoints as safetensors (forced off for stages < 3).
    save_safetensors: bool = field(default=True)
    # Interval (in steps) for monitoring/logging hooks.
    monitor_step: int = field(default=100)
    # Re-initialize the visual token embedding (VTE) table before training.
    vte_re_init: bool = field(default=False)
    # Maximum text-only sequence length.
    text_max_length: int = field(default=1024)
    # '|'-separated partition ratios, e.g. train|val|test = "9|1|1".
    max_partitions: str = field(default="9|1|1")

    def __post_init__(self):
        if self.gradient_checkpointing:
            # Non-reentrant checkpointing is required for compatibility with
            # newer torch autograd; set explicitly rather than relying on the
            # backend default.
            self.gradient_checkpointing_kwargs = {"use_reentrant": False}
        # Fix: `stage` defaults to None, and the original unguarded
        # `self.stage < 3` raised TypeError when stage was not provided.
        if self.stage is not None and self.stage < 3:
            self.save_safetensors = False
        super().__post_init__()