from dataclasses import dataclass, field
from typing import Optional

import transformers
@dataclass
class ModelArguments:
    """Arguments defining the model: the LLM backbone and the visual tokenizer."""
    llm_name_or_path: Optional[str] = field(default=None)
    visual_tokenizer_type: Optional[str] = field(default=None)
    visual_vocab_size: int = field(default=8192)
    visual_drop_cls_token: bool = field(default=False)
    visual_tokenize_function: str = field(default='softmax')
    visual_tau: float = field(default=1.0)  # temperature for the visual tokenize function
    visual_depths: Optional[str] = field(default=None)
    visual_hidden_stride: int = field(default=1)
    multimodal_max_length: int = field(default=2048)  # max combined text + visual sequence length
    conversation_formatter_class: Optional[str] = field(default=None)
    pad_token_id: Optional[int] = field(default=None)
    llm_attn_implementation: Optional[str] = field(default=None)  # e.g. 'flash_attention_2'
    disable_tie_weight: bool = field(default=False)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Training arguments, extending transformers.TrainingArguments with Ovis-specific options."""
    dataset_names: Optional[str] = field(default=None)
    dataset_info: Optional[str] = field(default='dataset_info_v1_6')
    ovis_pretrained_path: Optional[str] = field(default=None)
    visual_tokenizer_pretrained_path: Optional[str] = field(default=None)
    caption_template: Optional[str] = field(default=None)
    stage: Optional[int] = field(default=None)  # training stage index (see __post_init__)
    train_modules: Optional[str] = field(default=None)
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    visual_max_tau: float = field(default=5.0)  # upper bound for the visual tokenizer temperature
    visual_min_tau: float = field(default=0.05)  # lower bound for the visual tokenizer temperature
    save_safetensors: bool = field(default=True)
    monitor_step: int = field(default=100)
    vte_re_init: bool = field(default=False)  # whether to re-initialize the visual token embedding
    text_max_length: int = field(default=1024)
    max_partitions: str = field(default="9|1|1")

    def __post_init__(self):
        if self.gradient_checkpointing:
            # Use non-reentrant activation checkpointing, which tolerates frozen parameters.
            self.gradient_checkpointing_kwargs = {"use_reentrant": False}
        # Guard against an unset stage: comparing None with an int raises TypeError.
        if self.stage is not None and self.stage < 3:
            self.save_safetensors = False
        super().__post_init__()
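

# A minimal usage sketch, not part of the original module: these dataclasses follow
# the transformers.HfArgumentParser pattern, the standard way HF training scripts
# turn CLI flags into typed argument objects. The __main__ guard below is purely
# illustrative; the repo's actual entry point may wire things up differently.
if __name__ == "__main__":
    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()
    # e.g. python arguments.py --output_dir /tmp/run --stage 1 --llm_name_or_path <path>
    print(model_args)
    print(training_args.stage, training_args.save_safetensors)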