File size: 2,278 Bytes
d2525d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Training configuration for a BiQwen2 bi-encoder run.
# A `()` key gives the dotted path of a class; presumably the config loader
# instantiates that class with the sibling keys as arguments (TODO confirm
# against the loader in use).
config:
  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
  # Output location, passed through the custom `!path` tag (presumably resolved
  # relative to this file rather than the working directory — confirm).
  output_dir: !path ../../../models/biqwen2-warmup-256-newpad-0e
  # Processor: AllPurposeWrapper receives the BiQwen2Processor class (via
  # `!ext`) together with the name/path of the pretrained processor.
  processor:
    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
    # NOTE: the spelling "instanciate" is the key as written; presumably it
    # matches the wrapper's parameter name, so do not "correct" it here.
    class_to_instanciate: !ext colpali_engine.models.BiQwen2Processor
    # Alternative kept for reference: "./models/paligemma-3b-mix-448"
    pretrained_model_name_or_path: "./models/colqwen2_base"
    # max_length: 50

  # Model: AllPurposeWrapper receives the BiQwen2 class (via `!ext`) and the
  # keyword arguments below (checkpoint path, dtype, cache and attention
  # settings).
  model:
    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
    # NOTE: the spelling "instanciate" is the key as written; presumably it
    # matches the wrapper's parameter name, so do not "correct" it here.
    class_to_instanciate: !ext colpali_engine.models.BiQwen2
    pretrained_model_name_or_path: "./models/qwen2-warmup"
    torch_dtype: !ext torch.bfloat16
    use_cache: false
    attn_implementation: "flash_attention_2"
    # Disabled options, kept for reference (automatic device placement and
    # 4-bit bitsandbytes quantization):
    # device_map: "auto"
    # quantization_config:
    #   (): transformers.BitsAndBytesConfig
    #   load_in_4bit: true
    #   bnb_4bit_quant_type: "nf4"
    #   bnb_4bit_compute_dtype: "bfloat16"
    #   bnb_4bit_use_double_quant: true

  # Training data: `load_train_set` is referenced through `!ext` (presumably
  # passed as a callable rather than called here — confirm against the loader).
  dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set
  # Evaluation data: pulled in from a separate YAML file through `!import`.
  eval_dataset_loader: !import ../data/test_data.yaml

  # max_length: 50
  run_eval: true

  # Loss: pairwise cross-entropy loss for bi-encoders, built with no arguments.
  loss_func:
    (): colpali_engine.loss.bi_encoder_losses.BiPairwiseCELoss
  # Hugging Face TrainingArguments for this run.
  tr_args:
    (): transformers.training_args.TrainingArguments
    # Left null here; presumably overwritten with `config.output_dir` by the
    # training config — TODO confirm.
    output_dir: null
    overwrite_output_dir: true
    num_train_epochs: 1
    per_device_train_batch_size: 64
    gradient_checkpointing: true
    gradient_checkpointing_kwargs: { "use_reentrant": false }
    # 6 x 8 gpus = 48 batch size
    # NOTE(review): the line above does not match
    # per_device_train_batch_size: 64 — confirm the intended effective batch
    # size and update or drop the note.
    # gradient_accumulation_steps: 4
    per_device_eval_batch_size: 64
    eval_strategy: "steps"
    dataloader_num_workers: 8
    # bf16: true
    save_steps: 500
    logging_steps: 10
    # NOTE(review): per the Transformers docs a positive max_steps overrides
    # num_train_epochs, so training stops after a single step. Presumably
    # intentional given the "-0e" suffix of the output directory; confirm.
    max_steps: 1
    eval_steps: 100
    warmup_steps: 100
    # Written with a decimal point on purpose: YAML 1.1 loaders such as PyYAML
    # resolve a bare `5e-5` as a string, not a float.
    learning_rate: 5.0e-5
    save_total_limit: 1
    resume_from_checkpoint: false
    # optim: "paged_adamw_8bit"
  # LoRA adapter settings, built as a peft.LoraConfig.
  peft_config:
    (): peft.LoraConfig
    r: 32
    lora_alpha: 32
    lora_dropout: 0.1
    init_lora_weights: "gaussian"
    bias: "none"
    task_type: "FEATURE_EXTRACTION"
    # String form: per the PEFT docs a string target_modules is treated as a
    # regex over module names. This pattern selects names that contain "model"
    # followed later by one of the listed projection layers.
    target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)'
    # Alternative pattern kept for reference: requires "language_model" instead
    # of "model" and additionally matches "custom_text_proj".
    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'