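"""Training script for the arXiv document classifier.

Reads arxivData.json, builds a global tag -> class-id mapping, trains a
DocumentClassifier with cross-entropy loss, checkpoints the best model by
validation loss, plots loss curves, and saves the final weights. All paths
and hyperparameters come from shad_mlops_transformers.config.
"""
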
import ast
import json

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from loguru import logger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from shad_mlops_transformers.config import config
from shad_mlops_transformers.model import DocumentClassifier


class ArxivDataset(Dataset):
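    """In-memory dataset of (summary_text, class_id) pairs built from raw arXiv records."""
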
    def __init__(self, raw_data: list[dict], class_mapper: dict[str, int] | None = None):
        """Read the whole dataset once and keep it in memory."""
        logger.info("reading data")
        self.x = []
        self.y = []
        whitelist_labels = ["math", "cs", "stat"]
        i = 0
        if class_mapper is None:
            self.class_mapper = {}
        else:
            self.class_mapper = class_mapper
        for item in raw_data:
            tmp_y = []
            # May posterity forgive me: the source data stores a Python dict as a string,
            # so parse it with ast.literal_eval (safer than eval, and it handles quoting correctly)
            for tag_desc in ast.literal_eval(item["tag"]):
                real_tag: str = tag_desc["term"]
                # for now, keep only tags from the whitelist
                if not any(real_tag.startswith(x) for x in whitelist_labels):
                    continue
                if real_tag not in self.class_mapper:
                    if class_mapper is not None:
                        # a fixed mapper was supplied: skip tags it does not know
                        continue
                    self.class_mapper[real_tag] = i
                    i += 1
                tmp_y.append(self.class_mapper[real_tag])
                # keep only the first matching tag
                break
            # add the item to the dataset only if it had at least one valid tag
            if len(tmp_y):
                self.x.append(item["summary"])
                self.y.append(tmp_y[0])
        self.classes = sorted(self.class_mapper.keys())
        logger.info("[Done] reading data")

    def __getitem__(self, i):
        return self.x[i], self.y[i]

    def __len__(self):
        return len(self.x)


def make_train_val():
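    """Load the raw arXiv JSON and split it into train/validation lists.

    The split ratio and random seed are taken from the project config.
    """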
    with open(config.raw_data_dir / "arxivData.json", "r") as f:
        _raw_json = json.load(f)
    return train_test_split(_raw_json, test_size=config.test_size, shuffle=True, random_state=config.random_seed)


def run_epoch(model: DocumentClassifier, optimizer: torch.optim.Optimizer, loader: DataLoader, criterion, device):
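    """Run one training epoch; return the model, optimizer, and mean train loss.

    Raw text batches are passed straight to the model, so DocumentClassifier
    is assumed to handle tokenization internally.
    """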
    model.to(device)
    model.train()
    losses_tr = []
    for text, true_label in tqdm(loader):
        true_label = true_label.to(device)
        optimizer.zero_grad()
        pred = model(text)
        loss = criterion(pred, true_label)

        loss.backward()
        optimizer.step()
        losses_tr.append(loss.item())

    return model, optimizer, np.mean(losses_tr)


def val(model, loader, criterion, target_p: float = 0.95, device: torch.device = torch.device("cpu")):
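    """Compute the mean loss over the validation loader.

    Returns a (loss, metrics) pair; metrics is always None for now, and
    target_p is currently unused (both appear to be placeholders for
    future quality metrics).
    """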
    model.eval()
    losses_val = []
    with torch.no_grad():
        for text, true_label in tqdm(loader):
            true_label = true_label.to(device)
            pred = model(text)
            loss = criterion(pred, true_label)
            losses_val.append(loss.item())

    # no quality metrics are computed yet, hence None as the second value
    return np.mean(losses_val), None


def train_loop(
    model: DocumentClassifier,
    optimizer: torch.optim.Optimizer,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion,
    scheduler: torch.optim.lr_scheduler.ReduceLROnPlateau,
    device,
    val_every: int = 1,
):
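    """Train for config.epochs epochs, validating every `val_every` epochs.

    Each validation pass steps the ReduceLROnPlateau scheduler on the
    validation loss, checkpoints the best model seen so far, and redraws
    the loss-curve plot.
    """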
    losses = {"train": [], "val": []}
    best_val_loss = np.inf
    metrics = {}
    for epoch in range(1, config.epochs + 1):
        logger.info(f"#{epoch}/{config.epochs}:")
        model, optimizer, loss = run_epoch(
            model=model, optimizer=optimizer, loader=train_loader, criterion=criterion, device=device
        )
        losses["train"].append(loss)
        if not (epoch % val_every):
            loss, metrics_ = val(model, val_loader, criterion, device=device)
            losses["val"].append(loss)
            if metrics_ is not None:
                for name, value in metrics_.items():
                    # create the per-metric history list on first use to avoid a KeyError
                    metrics.setdefault(name, []).append(value)

            # save the checkpoint with the best validation loss so far
            if loss < best_val_loss:
                config.checkpoints_folders.mkdir(parents=True, exist_ok=True)
                torch.save(
                    {
                        "epoch": epoch,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "losses": losses,
                    },
                    config.checkpoints_folders / f"epoch_{epoch}.pt",
                )
                best_val_loss = loss

            scheduler.step(loss)
            # redraw the loss curves after every validation pass
            fig, ax = plt.subplots(1, 1, figsize=(16, 9))
            ax.plot(losses["train"], "r.-", label="train")
            ax.plot(losses["val"], "g.-", label="val")
            ax.grid(True)
            ax.legend()
            config.plots_dir.mkdir(exist_ok=True, parents=True)
            fig.savefig(config.plots_dir / "train.png")
            plt.close(fig)  # close the figure so repeated epochs do not leak memory


def collator(x):
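    """Unwrap a single-item batch (defined here but unused by the loaders in main())."""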
    return x[0]


def save_model(model: DocumentClassifier):
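    """Persist the model weights (state_dict) to config.weights_path."""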
    config.weights_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), config.weights_path)


def load_mapper():
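    """Load the tag -> class-id mapping that main() writes during training."""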
    path = (config.raw_data_dir / "mapper.json").absolute()
    logger.info(f"opening mapper in path: {path}")
    with open(path, "r") as f:
        return json.load(f)


def main():
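    """End-to-end entry point: split data, export the class mapper, train, save weights."""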
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"using device {device}")
    train_data, val_data = make_train_val()
    dataset_full = ArxivDataset(train_data + val_data)  # only used to compute the global class mapping
    cm = dataset_full.class_mapper
    logger.info("writing global class mapper to json")
    with open(config.raw_data_dir / "mapper.json", "w") as f:
        json.dump(cm, f)
    logger.info("[Done] writing global class mapper to json")
    del dataset_full
    dataset_train = ArxivDataset(train_data, class_mapper=cm)
    dataset_val = ArxivDataset(val_data, class_mapper=cm)
    loader_train = DataLoader(dataset_train, batch_size=config.batch_size, shuffle=True, drop_last=True)
    # evaluate on the full, deterministic validation set
    loader_val = DataLoader(dataset_val, batch_size=config.batch_size, shuffle=False, drop_last=False)

    model = DocumentClassifier(n_classes=len(dataset_train.classes), device=device)
    optimizer = torch.optim.Adam(model.trainable_params)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.25, patience=4, threshold=0.001, verbose=True
    )
    loss = nn.CrossEntropyLoss()

    logger.info("running train loop")
    train_loop(
        model=model,
        optimizer=optimizer,
        train_loader=loader_train,
        val_loader=loader_val,
        criterion=loss,
        scheduler=scheduler,
        device=device,
        val_every=1,
    )
    save_model(model)


if __name__ == "__main__":
    main()