import ast
import json

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from loguru import logger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from shad_mlops_transformers.config import config
from shad_mlops_transformers.model import DocumentClassifier


class ArxivDataset(Dataset):
    def __init__(self, raw_data: list[dict], class_mapper: dict[str, int] | None = None):
        """Read and cache the whole dataset once."""
        logger.info("reading data")
        self.x = []
        self.y = []
        # self.data = []
        whitelist_labels = ["math", "cs", "stat"]
        i = 0
        if class_mapper is None:
            self.class_mapper = {}
        else:
            self.class_mapper = class_mapper
        for item in raw_data:
            tmp_y = []
            # may posterity forgive me, but someone stored a Python list of dicts as a string here!
            # ast.literal_eval parses that literal safely, without the eval() + quote-replace hack
            for tag_desc in ast.literal_eval(item["tag"]):
                real_tag: str = tag_desc["term"]
                # for now, only keep tags from the whitelist
                if not any(real_tag.startswith(x) for x in whitelist_labels):
                    continue
                if class_mapper is None and real_tag not in self.class_mapper:
                    self.class_mapper[real_tag] = i
                    i += 1
                tmp_y.append(self.class_mapper[real_tag])
                # keep only one tag
                break
            # if there was at least one valid tag, add the item to the dataset
            if len(tmp_y):
                # NOTE: keep only one tag
                # self.data.append({"label": tmp_y[0], "text": item["summary"]})
                self.x.append(item["summary"])
                self.y.append(tmp_y[0])
        self.classes = sorted(list(self.class_mapper.keys()))
        logger.info("[Done] reading data")

    def __getitem__(self, i):
        # return self.data[i]
        return self.x[i], self.y[i]

    def __len__(self):
        # return len(self.data)
        return len(self.x)


def make_train_val():
    with open(config.raw_data_dir / "arxivData.json", "r") as f:
        _raw_json = json.load(f)
    return train_test_split(_raw_json, test_size=config.test_size, shuffle=True, random_state=config.random_seed)


def run_epoch(model: DocumentClassifier, optimizer: torch.optim.Optimizer, loader: DataLoader, criterion, device):
    model.to(device)
    model.train()
    losses_tr = []
    for text, true_label in tqdm(loader):
        true_label = true_label.to(device)
        optimizer.zero_grad()
        pred = model(text)
        loss = criterion(pred, true_label)
        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        # logger.debug(f"current loss: {current_loss}")
        losses_tr.append(current_loss)
        # break
    return model, optimizer, np.mean(losses_tr)


def val(model, loader, criterion, target_p: float = 0.95, device: torch.device = torch.device("cpu")):
    model.eval()
    losses_val = []
    with torch.no_grad():
        for text, true_label in tqdm(loader):
            true_label = true_label.to(device)
            pred = model(text)
            loss = criterion(pred, true_label)
            losses_val.append(loss.item())
            # break
    return np.mean(losses_val), None


def train_loop(
    model: DocumentClassifier,
    optimizer: torch.optim.Optimizer,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion,
    scheduler: torch.optim.lr_scheduler.ReduceLROnPlateau,
    device,
    val_every: int = 1,
):
    losses = {"train": [], "val": []}
    best_val_loss = np.inf
    metrics = {}
    for epoch in range(1, config.epochs + 1):
        logger.info(f"#{epoch}/{config.epochs}:")
        model, optimizer, loss = run_epoch(
            model=model, optimizer=optimizer, loader=train_loader, criterion=criterion, device=device
        )
        losses["train"].append(loss)
        if not (epoch % val_every):
            loss, metrics_ = val(model, val_loader, criterion, device=device)
            losses["val"].append(loss)
            if metrics_ is not None:
                for name, value in metrics_.items():
                    metrics.setdefault(name, []).append(value)
            # keep the checkpoint that is best on validation loss
            if loss < best_val_loss:
                config.checkpoints_folders.mkdir(parents=True, exist_ok=True)
                torch.save(
                    {
                        "epoch": epoch,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "losses": losses,
                    },
                    config.checkpoints_folders / f"epoch_{epoch}.pt",
                )
                best_val_loss = loss
            scheduler.step(loss)

    fig, ax = plt.subplots(1, 1, figsize=(16, 9))
    ax.plot(losses["train"], "r.-", label="train")
    ax.plot(losses["val"], "g.-", label="val")
    ax.grid(True)
    ax.legend()
    config.plots_dir.mkdir(exist_ok=True, parents=True)
    fig.savefig(config.plots_dir / "train.png")


def collator(x):
    return x[0]


def save_model(model: DocumentClassifier):
    config.weights_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), config.weights_path)


def load_mapper():
    path = (config.raw_data_dir / "mapper.json").absolute()
    logger.info(f"opening mapper in path: {path}")
    with open(path, "r") as f:
        return json.load(f)


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"using device {device}")
    # renamed from `train, val` to avoid shadowing the module-level val() function
    train_data, val_data = make_train_val()
    dataset_full = ArxivDataset(train_data + val_data)  # only needed to compute the class mapping
    cm = dataset_full.class_mapper
    logger.info("writing global class mapper to json")
    with open(config.raw_data_dir / "mapper.json", "w") as f:
        json.dump(cm, f)
    logger.info("[Done] writing global class mapper to json")
    del dataset_full
    dataset_train = ArxivDataset(train_data, class_mapper=cm)
    dataset_val = ArxivDataset(val_data, class_mapper=cm)
    loader_train = DataLoader(dataset_train, batch_size=config.batch_size, shuffle=True, drop_last=True)
    loader_val = DataLoader(dataset_val, batch_size=config.batch_size, shuffle=True, drop_last=True)
    model = DocumentClassifier(n_classes=len(dataset_train.classes), device=device)
    optimizer = torch.optim.Adam(model.trainable_params)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.25, patience=4, threshold=0.001, verbose=True
    )
    loss = nn.CrossEntropyLoss()
    logger.info("running train loop")
    train_loop(
        model=model,
        optimizer=optimizer,
        train_loader=loader_train,
        val_loader=loader_val,
        criterion=loss,
        scheduler=scheduler,
        device=device,
        val_every=1,
    )
    save_model(model)


if __name__ == "__main__":
    main()