import ast
import json

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from loguru import logger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from shad_mlops_transformers.config import config
from shad_mlops_transformers.model import DocumentClassifier


class ArxivDataset(Dataset):
    def __init__(self, raw_data: list[dict], class_mapper: dict[str, int] | None = None):
        """Read the entire dataset once and keep it in memory."""
        logger.info("reading data")
        self.x = []
        self.y = []
        whitelist_labels = ["math", "cs", "stat"]
        i = 0
        if class_mapper is None:
            self.class_mapper = {}
        else:
            self.class_mapper = class_mapper
        for item in raw_data:
            tmp_y = []
            # May posterity forgive me, but someone stuffed a Python dict into a string here,
            # so we parse it safely with ast.literal_eval instead of eval().
            for tag_desc in ast.literal_eval(item["tag"]):
                real_tag: str = tag_desc["term"]
                # for now, take only tags from the whitelist
                if not any(real_tag.startswith(x) for x in whitelist_labels):
                    continue
                if class_mapper is None and real_tag not in self.class_mapper:
                    self.class_mapper[real_tag] = i
                    i += 1
                tmp_y.append(self.class_mapper[real_tag])
                # NOTE: we keep only the first matching tag
                break
            # add the example only if it had at least one valid tag
            if len(tmp_y):
                self.x.append(item["summary"])
                self.y.append(tmp_y[0])
        self.classes = sorted(self.class_mapper.keys())
        logger.info("[Done] reading data")

    def __getitem__(self, i):
        return self.x[i], self.y[i]

    def __len__(self):
        return len(self.x)
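

# A minimal usage sketch (the record layout is inferred from __init__: each raw item
# carries a stringified tag list under "tag" and the abstract text under "summary"):
#   raw = [{"tag": "[{'term': 'cs.LG'}]", "summary": "We study ..."}]
#   ds = ArxivDataset(raw)
#   text, label = ds[0]  # -> ("We study ...", 0)
# The default DataLoader collate then batches items into (list[str], LongTensor).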


def make_train_val():
    with open(config.raw_data_dir / "arxivData.json", "r") as f:
        _raw_json = json.load(f)
    return train_test_split(_raw_json, test_size=config.test_size, shuffle=True, random_state=config.random_seed)


def run_epoch(model: DocumentClassifier, optimizer: torch.optim.Optimizer, loader: DataLoader, criterion, device):
    """Run one training epoch; return the model, the optimizer and the mean train loss."""
    model.to(device)
    model.train()
    losses_tr = []
    for text, true_label in tqdm(loader):
        true_label = true_label.to(device)
        optimizer.zero_grad()
        pred = model(text)
        loss = criterion(pred, true_label)
        loss.backward()
        optimizer.step()
        losses_tr.append(loss.item())
    return model, optimizer, np.mean(losses_tr)


def val(model, loader, criterion, target_p: float = 0.95, device: torch.device = torch.device("cpu")):
    """Compute the mean validation loss.

    `target_p` is currently unused; the second return value is a placeholder for metrics.
    """
    model.eval()
    losses_val = []
    with torch.no_grad():
        for text, true_label in tqdm(loader):
            true_label = true_label.to(device)
            pred = model(text)
            loss = criterion(pred, true_label)
            losses_val.append(loss.item())
    return np.mean(losses_val), None
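

# Sketch of how metrics could be plugged in (not implemented here): return a dict as the
# second value and train_loop() will accumulate it per name. For example, inside the loop:
#   correct += (pred.argmax(dim=1) == true_label).sum().item()
# and then: return np.mean(losses_val), {"accuracy": correct / len(loader.dataset)}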


def train_loop(
    model: DocumentClassifier,
    optimizer: torch.optim.Optimizer,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion,
    scheduler: torch.optim.lr_scheduler.ReduceLROnPlateau,
    device,
    val_every: int = 1,
):
    losses = {"train": [], "val": []}
    best_val_loss = np.inf
    metrics = {}
    for epoch in range(1, config.epochs + 1):
        logger.info(f"#{epoch}/{config.epochs}:")
        model, optimizer, loss = run_epoch(
            model=model, optimizer=optimizer, loader=train_loader, criterion=criterion, device=device
        )
        losses["train"].append(loss)
        if not (epoch % val_every):
            loss, metrics_ = val(model, val_loader, criterion, device=device)
            losses["val"].append(loss)
            if metrics_ is not None:
                for name, value in metrics_.items():
                    metrics.setdefault(name, []).append(value)
            # keep the checkpoint that is best on validation loss
            if loss < best_val_loss:
                config.checkpoints_folders.mkdir(parents=True, exist_ok=True)
                torch.save(
                    {
                        "epoch": epoch,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "losses": losses,
                    },
                    config.checkpoints_folders / f"epoch_{epoch}.pt",
                )
                best_val_loss = loss
        scheduler.step(loss)
    fig, ax = plt.subplots(1, 1, figsize=(16, 9))
    ax.plot(losses["train"], "r.-", label="train")
    ax.plot(losses["val"], "g.-", label="val")
    ax.grid(True)
    ax.legend()
    config.plots_dir.mkdir(exist_ok=True, parents=True)
    fig.savefig(config.plots_dir / "train.png")
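

# Resuming is not handled by this script; a minimal sketch, assuming you pick one of the
# checkpoint files saved above:
#   ckpt = torch.load(config.checkpoints_folders / "epoch_3.pt", map_location="cpu")
#   model.load_state_dict(ckpt["model_state_dict"])
#   optimizer.load_state_dict(ckpt["optimizer_state_dict"])
#   scheduler.load_state_dict(ckpt["scheduler_state_dict"])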


def collator(x):
    return x[0]
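

# Note: collator is unused in this script; it would unwrap the singleton batch produced
# by a batch_size=1 DataLoader. Hypothetical usage:
#   DataLoader(dataset, batch_size=1, collate_fn=collator)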


def save_model(model: DocumentClassifier):
    config.weights_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), config.weights_path)


def load_mapper():
    path = (config.raw_data_dir / "mapper.json").absolute()
    logger.info(f"opening mapper in path: {path}")
    with open(path, "r") as f:
        return json.load(f)
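

# Inference-side sketch (an assumption about how the mapper is consumed elsewhere): the
# saved mapper lets another process rebuild the same class ordering as ArxivDataset.classes:
#   cm = load_mapper()
#   classes = sorted(cm.keys())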


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"using device {device}")
    # named train_data/val_data to avoid shadowing the module-level val() function
    train_data, val_data = make_train_val()
    dataset_full = ArxivDataset(train_data + val_data)  # only used to compute the class mapping
    cm = dataset_full.class_mapper
    logger.info("writing global class mapper to json")
    with open(config.raw_data_dir / "mapper.json", "w") as f:
        json.dump(cm, f)
    logger.info("[Done] writing global class mapper to json")
    del dataset_full
    dataset_train = ArxivDataset(train_data, class_mapper=cm)
    dataset_val = ArxivDataset(val_data, class_mapper=cm)
    loader_train = DataLoader(dataset_train, batch_size=config.batch_size, shuffle=True, drop_last=True)
    loader_val = DataLoader(dataset_val, batch_size=config.batch_size, shuffle=True, drop_last=True)
    model = DocumentClassifier(n_classes=len(dataset_train.classes), device=device)
    optimizer = torch.optim.Adam(model.trainable_params)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.25, patience=4, threshold=0.001, verbose=True
    )
    criterion = nn.CrossEntropyLoss()
    logger.info("running train loop")
    train_loop(
        model=model,
        optimizer=optimizer,
        train_loader=loader_train,
        val_loader=loader_val,
        criterion=criterion,
        scheduler=scheduler,
        device=device,
        val_every=1,
    )
    save_model(model)


if __name__ == "__main__":
    main()