import os
import gc
import time
import json
import math
import collections
from datetime import datetime
from typing import Optional, List, Dict, Tuple, Callable, Any, Union

import torch
import numpy as np

from transformers import (
    is_datasets_available,
    is_torch_xla_available,
)

from transformers.trainer_utils import (
    PredictionOutput,
    EvalPrediction,
    EvalLoopOutput,
    denumpify_detensorize,
    speed_metrics,
)

from transformers.utils import logging
from transformers.debug_utils import DebugOption

if is_datasets_available():
    import datasets
    
# if is_torch_xla_available():
#     import torch_xla.core.xla_model as xm       # type: ignore
#     import torch_xla.debug.metrics as met       # type: ignore

from transformers import Trainer

logger = logging.get_logger(__name__)

class ToMixin:
    def _optimizer_to(self, device: str = "cpu") -> None:
        """Move the optimizer state to the specified device.

        Args:
            device (str, optional): The device to move the optimizer state to. Defaults to "cpu".
        """
        for param in self.optimizer.state.values():
            if isinstance(param, torch.Tensor):
                param.data = param.data.to(device)
                if param._grad is not None:
                    param._grad.data = param._grad.data.to(device)
            elif isinstance(param, dict):
                for subparam in param.values():
                    if isinstance(subparam, torch.Tensor):
                        subparam.data = subparam.data.to(device)
                        if subparam._grad is not None:
                            subparam._grad.data = subparam._grad.data.to(device)
    
    def _scheduler_to(self, device: str = "cpu") -> None:
        """Move the scheduler state to the specified device.

        Args:
            device (str, optional): The device to move the scheduler state to. Defaults to "cpu".

        Returns:
            None
        """
        for param in self.lr_scheduler.__dict__.values():
            if isinstance(param, torch.Tensor):
                param.data = param.data.to(device)
                if param._grad is not None:
                    param._grad.data = param._grad.data.to(device)
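
# Usage sketch (illustrative only, not part of the original module): ToMixin
# assumes the host class exposes `optimizer` and `lr_scheduler` attributes, as
# `transformers.Trainer` does. A hypothetical standalone use could look like:
#
#     class _OffloadDemo(ToMixin):
#         def __init__(self):
#             model = torch.nn.Linear(4, 4)
#             self.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#             self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1)
#
#     demo = _OffloadDemo()
#     demo._optimizer_to("cpu")   # moves any optimizer state tensors to the CPU
#     demo._scheduler_to("cpu")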
                
class BaseReader(Trainer, ToMixin):
    name: Optional[str] = None
    
    def __init__(
        self,
        *args,  # Positional arguments passed to Trainer.__init__
        data_args: Dict[str, Any] = {},  # Additional arguments for data loading
        eval_examples: Optional[datasets.Dataset] = None,  # Evaluation examples
        **kwargs,  # Keyword arguments passed to Trainer.__init__
    ):
        """Initializes the BaseReader.

        Args:
            *args: Positional arguments passed to Trainer.__init__.
            data_args (dict): Additional arguments for data loading.
            eval_examples (datasets.Dataset): Evaluation examples.
            **kwargs: Keyword arguments passed to Trainer.__init__.
        """
        # Call the parent class's __init__ method with the given arguments
        super().__init__(*args, **kwargs)

        # Store the data arguments and evaluation examples
        self.data_args = data_args
        self.eval_examples = eval_examples
    
    def free_memory(self):
        """Move the model, optimizer and scheduler state to the CPU, empty the
        CUDA cache and garbage collect.

        This method is useful to free up GPU memory before checkpointing the model
        or saving it to disk.
        """
        self.model.to("cpu")
        self._optimizer_to("cpu")
        self._scheduler_to("cpu")
        torch.cuda.empty_cache()
        gc.collect()

        
    def postprocess(
        self,
        output: EvalLoopOutput,
        eval_examples: Optional[datasets.Dataset] = None,
        eval_dataset: Optional[datasets.Dataset] = None,
        mode: str = "evaluate",
    ) -> Union[Any, PredictionOutput]:
        """Postprocess the evaluation loop output.

        This method is called after the evaluation loop has finished and before the
        evaluation metrics are computed. It receives the output of the evaluation
        loop, together with the original examples and the processed dataset, and can
        be overridden to convert raw predictions into the format expected by the
        compute_metrics function.

        Args:
            output (EvalLoopOutput): The output of the evaluation loop.
            eval_examples (Optional[datasets.Dataset], optional): The original (unprocessed) examples. Defaults to None.
            eval_dataset (Optional[datasets.Dataset], optional): The processed dataset the predictions were made on. Defaults to None.
            mode (str, optional): Either "evaluate" or "predict". Defaults to "evaluate".

        Returns:
            Union[Any, PredictionOutput]: The modified output that will be passed to the compute_metrics function.
        """
        return output

    
    def evaluate(
        self,
        eval_dataset: Optional[datasets.Dataset] = None,
        eval_examples: Optional[datasets.Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """Evaluate the model on the given dataset.

        Args:
            eval_dataset (Optional[datasets.Dataset], optional): The evaluation dataset. Defaults to None.
            eval_examples (Optional[datasets.Dataset], optional): The evaluation examples. Defaults to None.
            ignore_keys (Optional[List[str]], optional): Keys to ignore when calculating metrics. Defaults to None.
            metric_key_prefix (str, optional): The prefix for metric keys. Defaults to "eval".

        Returns:
            Dict[str, float]: The evaluation metrics.
        """
        
        # Start tracking memory usage
        self._memory_tracker.start()
        
        # Set eval_dataset and eval_dataloader
        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        
        # Set eval_examples
        eval_examples = self.eval_examples if eval_examples is None else eval_examples
        
        # Start timing
        start_time = time.time()
        
        # Set compute_metrics
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        
        # Set eval_loop
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        try:
            # Run evaluation loop
            output = eval_loop(
                eval_dataloader,
                description="Evaluation",
                # Only gather predictions if there are metrics to compute
                prediction_loss_only=True if compute_metrics is None else None,
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            # Restore compute_metrics
            self.compute_metrics = compute_metrics
        
        # Set eval_dataset format
        if isinstance(eval_dataset, datasets.Dataset):
            eval_dataset.set_format(
                type=eval_dataset.format["type"],
                columns=list(eval_dataset.features.keys()),
            )
            
        # Postprocess output
        eval_preds = self.postprocess(output, eval_examples, eval_dataset, mode="evaluate")
        
        # Compute metrics
        metrics = {}
        if self.compute_metrics is not None:
            metrics = self.compute_metrics(eval_preds)
            
            # Make metrics JSON-serializable
            metrics = denumpify_detensorize(metrics)
            
            # Prefix all keys with metric_key_prefix + '_'
            for key in list(metrics.keys()):
                if not key.startswith(f"{metric_key_prefix}_"):
                    metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
            
            # Add speed metrics
            total_batch_size = self.args.eval_batch_size * self.args.world_size
            metrics.update(
                speed_metrics(
                    metric_key_prefix,
                    start_time,
                    num_samples=output.num_samples,
                    num_steps=math.ceil(output.num_samples / total_batch_size),
                )
            )
            
            # Log metrics
            self.log(metrics)
            
        # Log and save evaluation results
        filename = "eval_results.txt"
        eval_result_file = self.name + '_' + filename if self.name else filename
        with open(os.path.join(self.args.output_dir, eval_result_file), "w") as writer:
            logger.info(f"***** Eval results *****")
            writer.write("***** Eval results *****\n")
            writer.write(f"{datetime.now()}")
            for key in sorted(metrics.keys()):
                logger.info(f"  {key} = {metrics[key]}")
                writer.write(f"{key} = {metrics[key]}\n")
            writer.write("\n")
            
        # if DebugOption.TPU_METRICS_DEBUG and is_torch_xla_available():
        #     # Log debug metrics for PyTorch/XLA
        #     xm.master_print(met.metrics_report())
        
        # Call callback handler on evaluate
        self.control = self.callback_handler.on_evaluate(
            self.args, self.state, self.control, metrics
        )
        
        # Stop tracking memory usage and update metrics
        self._memory_tracker.stop_and_update_metrics(metrics)
        
        return metrics
            
    def predict(
        self,
        test_dataset: datasets.Dataset,
        test_examples: Optional[datasets.Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "test",
        mode: str = "predict",
    ) -> PredictionOutput:
        """Predicts on the given test dataset and returns the predictions.

        Args:
            test_dataset (datasets.Dataset): The test dataset.
            test_examples (Optional[datasets.Dataset], optional): The test examples. Defaults to None.
            ignore_keys (Optional[List[str]], optional): Keys to ignore when calculating metrics. Defaults to None.
            metric_key_prefix (str, optional): The prefix for metric keys. Defaults to "test".
            mode (str, optional): The postprocessing mode passed to `postprocess`. Defaults to "predict".

        Returns:
            PredictionOutput: The predictions.
        """
        
        # Start tracking memory usage
        self._memory_tracker.start()
        
        # Get the test dataloader
        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()
        
        # Set compute_metrics to None and store it for later use
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        
        # Get the evaluation loop
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        try:
            # Run the evaluation loop
            output = eval_loop(
                test_dataloader,
                description="Prediction",
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            # Reset compute_metrics to its original value
            self.compute_metrics = compute_metrics
        
        # If the test dataset is a datasets.Dataset, set its format
        if isinstance(test_dataset, datasets.Dataset):
            test_dataset.set_format(
                type=test_dataset.format["type"],
                columns=list(test_dataset.features.keys()),
            )
            
        # Postprocess the output and return the predictions
        predictions = self.postprocess(output, test_examples, test_dataset, mode=mode)
        
        # Stop tracking memory usage and update metrics
        self._memory_tracker.stop_and_update_metrics(output.metrics)
        
        return predictions
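
# --- Usage sketch (illustrative only, not part of the original file) ---------
# A hypothetical subclass showing the contract this base class expects:
# `evaluate`/`predict` call `postprocess(output, examples, dataset, mode=...)`,
# and in "evaluate" mode its return value is fed to `compute_metrics`. The
# names below (`ExtractiveReader`, `extract_answers`, `squad_metric_fn`, the
# model/args/dataset variables) are placeholders, not part of this module.
#
#     class ExtractiveReader(BaseReader):
#         name = "extractive"
#
#         def postprocess(self, output, eval_examples=None, eval_dataset=None, mode="evaluate"):
#             # e.g. turn start/end logits into answer strings with a helper
#             predictions = extract_answers(output.predictions, eval_examples, eval_dataset)
#             if mode == "predict":
#                 return predictions
#             return EvalPrediction(predictions=predictions, label_ids=output.label_ids)
#
#     reader = ExtractiveReader(
#         model=model,
#         args=training_args,
#         train_dataset=train_features,
#         eval_dataset=eval_features,
#         eval_examples=eval_examples,
#         compute_metrics=squad_metric_fn,
#         data_args=data_args,
#     )
#     metrics = reader.evaluate()
#     reader.free_memory()  # offload model/optimizer state before saving or reusing the GPU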