File size: 1,667 Bytes
b4942cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import deepspeed
import torch
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

from ovis.util.constants import END_LINE, BEGIN_LINE
from ovis.util.utils import rank0_print


class TuneTauCallback(TrainerCallback):
    """Linearly anneals the visual tokenizer's temperature parameter ``tau``.

    At the start of each optimization step, ``tau`` is interpolated from
    ``args.visual_max_tau`` (at step 0) down to ``args.visual_min_tau``
    (at ``state.max_steps``) and written into the visual tokenizer's config.
    NOTE(review): assumes the model exposes ``get_visual_tokenizer()`` and
    that its config has a ``tau`` field — both defined elsewhere in ovis.
    """

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        visual_tokenizer = kwargs['model'].get_visual_tokenizer()
        current_step = state.global_step
        # Guard against max_steps == 0 (HF Trainer reports 0 before the
        # schedule is computed) to avoid a ZeroDivisionError.
        max_step = max(state.max_steps, 1)
        ratio = current_step / max_step
        visual_tokenizer.config.tau = args.visual_max_tau - (args.visual_max_tau - args.visual_min_tau) * ratio


class MonitorCallback(TrainerCallback):
    """Periodically prints the model's monitored tensors on rank 0.

    Tensors are taken from ``model.get_monitor_tensors()`` and gathered via
    DeepSpeed ZeRO (parameters may be sharded across ranks) before printing,
    so the full values are visible. Runs every ``args.monitor_step`` steps,
    once at step 10 for a fast sanity check, and at the end of each epoch.
    """

    def _monitoring(self, model, step):
        # Fetch the dict once so the tensors we gather and the tensors we
        # print are guaranteed to be the same objects (the original called
        # get_monitor_tensors() twice).
        monitor_tensors = model.get_monitor_tensors()
        with torch.no_grad():
            # Materialize ZeRO-partitioned parameters before reading them.
            with deepspeed.zero.GatheredParameters(monitor_tensors.values()):
                for k, v in monitor_tensors.items():
                    rank0_print(BEGIN_LINE)
                    rank0_print(f'{k} @ step {step} with sum: {v.sum().item()} and content: ')
                    rank0_print(v)
                    rank0_print(END_LINE)

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        model = kwargs['model']
        step = state.global_step
        if step % args.monitor_step == 0 or step == 10:  # monitor at step 10 for fast check
            self._monitoring(model, step)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Always emit a snapshot at epoch boundaries, regardless of monitor_step.
        model = kwargs['model']
        step = state.global_step
        self._monitoring(model, step)