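"""Training script for the ArXiv document classifier.

Reads the raw ArXiv dump, builds a tag -> class-index mapping restricted to the
math/cs/stat categories, trains a DocumentClassifier and keeps the checkpoint
with the best validation loss.
"""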
import ast
import json
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from loguru import logger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from shad_mlops_transformers.config import config
from shad_mlops_transformers.model import DocumentClassifier
class ArxivDataset(Dataset):
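    """In-memory dataset of ArXiv abstracts labelled by their first whitelisted tag."""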
def __init__(self, raw_data: list[dict], class_mapper: dict[str, int] | None = None):
"""Разово вычитываем и сохраняем весь датасет."""
logger.info("reading data")
self.x = []
self.y = []
whitelist_labels = ["math", "cs", "stat"]
i = 0
if class_mapper is None:
self.class_mapper = {}
else:
self.class_mapper = class_mapper
for item in raw_data:
tmp_y = []
            # someone stored a stringified Python dict in the "tag" field;
            # ast.literal_eval parses it safely (unlike eval() over a lossy quote swap)
            for tag_desc in ast.literal_eval(item["tag"]):
real_tag: str = tag_desc["term"]
                # for now, keep only tags from the whitelist
                if not any(real_tag.startswith(x) for x in whitelist_labels):
continue
if class_mapper is None and real_tag not in self.class_mapper:
self.class_mapper[real_tag] = i
i += 1
tmp_y.append(self.class_mapper[real_tag])
                    # take only one tag per paper
                    break
            # if there was at least one valid tag, add the sample to the dataset
            if tmp_y:
                # NOTE: only the first matched tag is used as the label
self.x.append(item["summary"])
self.y.append(tmp_y[0])
self.classes = sorted(list(self.class_mapper.keys()))
logger.info("[Done] reading data")
    def __getitem__(self, i):
        return self.x[i], self.y[i]

    def __len__(self):
        return len(self.x)
def make_train_val():
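    """Load the raw ArXiv JSON dump and split it into train/validation lists."""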
with open(config.raw_data_dir / "arxivData.json", "r") as f:
_raw_json = json.load(f)
return train_test_split(_raw_json, test_size=config.test_size, shuffle=True, random_state=config.random_seed)
def run_epoch(model: DocumentClassifier, optimizer: torch.optim.Optimizer, loader: DataLoader, criterion, device):
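    """Run a single training epoch; return the model, optimizer and mean batch loss."""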
model.to(device)
model.train()
losses_tr = []
for text, true_label in tqdm(loader):
true_label = true_label.to(device)
optimizer.zero_grad()
pred = model(text)
loss = criterion(pred, true_label)
loss.backward()
optimizer.step()
        current_loss = loss.item()
        losses_tr.append(current_loss)
return model, optimizer, np.mean(losses_tr)
def val(model, loader, criterion, target_p: float = 0.95, device: torch.device = torch.device("cpu")):
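    """Compute the mean loss over the validation loader.

    `target_p` is currently unused, and the metrics slot in the return value is
    a placeholder that is always None.
    """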
model.eval()
losses_val = []
with torch.no_grad():
for text, true_label in tqdm(loader):
true_label = true_label.to(device)
pred = model(text)
loss = criterion(pred, true_label)
losses_val.append(loss.item())
return np.mean(losses_val), None
def train_loop(
model: DocumentClassifier,
optimizer: torch.optim.Optimizer,
train_loader: DataLoader,
val_loader: DataLoader,
criterion,
scheduler: torch.optim.lr_scheduler.ReduceLROnPlateau,
device,
val_every: int = 1,
):
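    """Train for config.epochs epochs, validating every `val_every` epochs.

    Checkpoints the best model by validation loss and saves a plot of the
    train/val loss curves to config.plots_dir.
    """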
losses = {"train": [], "val": []}
    best_val_loss = np.inf
metrics = {}
for epoch in range(1, config.epochs + 1):
logger.info(f"#{epoch}/{config.epochs}:")
model, optimizer, loss = run_epoch(
model=model, optimizer=optimizer, loader=train_loader, criterion=criterion, device=device
)
losses["train"].append(loss)
if not (epoch % val_every):
loss, metrics_ = val(model, val_loader, criterion, device=device)
losses["val"].append(loss)
            if metrics_ is not None:
                for name, value in metrics_.items():
                    metrics.setdefault(name, []).append(value)
            # keep the checkpoint with the best validation loss
            if loss < best_val_loss:
config.checkpoints_folders.mkdir(parents=True, exist_ok=True)
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"losses": losses,
},
config.checkpoints_folders / f"epoch_{epoch}.pt",
)
best_val_loss = loss
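        # ReduceLROnPlateau steps on the latest loss (the validation loss on epochs where validation ran)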
scheduler.step(loss)
fig, ax = plt.subplots(1, 1, figsize=(16, 9))
ax.plot(losses["train"], "r.-", label="train")
ax.plot(losses["val"], "g.-", label="val")
ax.grid(True)
ax.legend()
config.plots_dir.mkdir(exist_ok=True, parents=True)
fig.savefig(config.plots_dir / "train.png")
def collator(x):
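    """Unwrap a single-sample batch; currently not wired into the DataLoaders below."""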
return x[0]
def save_model(model: DocumentClassifier):
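    """Serialize the model weights to config.weights_path."""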
config.weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), config.weights_path)
def load_mapper():
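    """Read the tag -> class-index mapping saved during training."""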
path = (config.raw_data_dir / "mapper.json").absolute()
logger.info(f"opening mapper in path: {path}")
with open(path, "r") as f:
return json.load(f)
def main():
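    """Entry point: build the datasets, train the classifier and save its weights."""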
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"using device {device}")
    train_data, val_data = make_train_val()
    dataset_full = ArxivDataset(train_data + val_data)  # only used to build the global class mapping
cm = dataset_full.class_mapper
logger.info("writing global class mapper to json")
with open(config.raw_data_dir / "mapper.json", "w") as f:
json.dump(cm, f)
logger.info("[Done] writing global class mapper to json")
del dataset_full
    dataset_train = ArxivDataset(train_data, class_mapper=cm)
    dataset_val = ArxivDataset(val_data, class_mapper=cm)
loader_train = DataLoader(dataset_train, batch_size=config.batch_size, shuffle=True, drop_last=True)
    loader_val = DataLoader(dataset_val, batch_size=config.batch_size, shuffle=False, drop_last=False)  # keep every validation sample; no shuffling needed at eval time
model = DocumentClassifier(n_classes=len(dataset_train.classes), device=device)
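    # Adam with torch defaults (lr=1e-3) over the model's trainable parameters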
optimizer = torch.optim.Adam(model.trainable_params)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.25, patience=4, threshold=0.001
    )
    criterion = nn.CrossEntropyLoss()
logger.info("running train loop")
train_loop(
model=model,
optimizer=optimizer,
train_loader=loader_train,
val_loader=loader_val,
        criterion=criterion,
scheduler=scheduler,
device=device,
val_every=1,
)
save_model(model)
if __name__ == "__main__":
main()