In [None]:
import json
import math
from pathlib import Path

import numpy as np
import pandas as pd
from datasets import Dataset
from sklearn.metrics import f1_score, accuracy_score, log_loss
from tqdm import tqdm

from models.models import language_to_models

In [None]:
# Language codes used as keys throughout the notebook.
en = "en"
ru = "ru"

datasets_dir = Path("datasets")
test_filename = "arxiv_test"

# Path stem (no extension) of each language's test split; the loader below
# picks the .csv or .json file for this stem.
test_dataset_filename = {lang: datasets_dir / lang / test_filename for lang in (en, ru)}

In [None]:
# Load each language's test split, preferring the CSV export and falling back
# to the JSON-lines file with the same stem.
test_datasets = {}
for lang in (en, ru):
    stem = test_dataset_filename[lang]
    csv_path = f"{stem}.csv"
    if Path(csv_path).exists():
        test_datasets[lang] = pd.read_csv(csv_path)
        continue
    json_path = f"{stem}.json"
    test_datasets[lang] = pd.read_json(json_path, lines=True)

In [None]:
test_results_filename = Path("test_results.json")

# Resume from a previous run's results (nested dict lang -> metric -> model -> value)
# so that already-evaluated models are skipped; start empty on a fresh run.
test_results = (
    json.loads(test_results_filename.read_text())
    if test_results_filename.exists()
    else {}
)

In [None]:
def pred_to_1d(pred):
    """Collapse a per-label score DataFrame to the highest-scoring label of each row.

    Args:
        pred: DataFrame whose columns are labels and whose values are scores.

    Returns:
        Series (same index as ``pred``) holding the column name with the maximum
        score in each row.
    """
    best_label_per_row = pred.idxmax(axis="columns")
    return best_label_per_row


def true_to_nd(true, columns):
    """One-hot encode gold labels against a fixed column order.

    Args:
        true: Series (or other sized iterable) of gold labels.
        columns: iterable giving the output column order; every label in ``true``
            must appear in it.

    Returns:
        DataFrame of shape ``(len(true), len(columns))`` with 1.0 in the column of
        each row's label and 0.0 elsewhere, using a default RangeIndex.

    Raises:
        ValueError: if a label in ``true`` is not one of ``columns``.
    """
    columns = list(columns)
    # First occurrence wins, matching list.index semantics, with O(1) lookups.
    position_of = {}
    for position, column in enumerate(columns):
        position_of.setdefault(column, position)
    try:
        column_numbers = np.fromiter(
            (position_of[label] for label in true), dtype=np.intp, count=len(true)
        )
    except KeyError as err:
        raise ValueError(
            f"label {err.args[0]!r} is not one of the prediction columns"
        ) from err
    true_arr = np.zeros((len(true), len(columns)))
    # Index with a (row_indices, column_indices) pair so exactly one cell per row
    # is set. Passing a single (n, 2) array would index axis 0 only and fill
    # whole rows with ones.
    true_arr[np.arange(len(true)), column_numbers] = 1
    return pd.DataFrame(true_arr, columns=columns)


def accuracy(pred, true):
    """Fraction of rows whose highest-scoring label in ``pred`` equals ``true``.

    Args:
        pred: DataFrame of per-label scores (columns are labels).
        true: gold labels, one per row of ``pred``.
    """
    predicted_labels = pred_to_1d(pred)
    return accuracy_score(true, predicted_labels)


def f1(pred, true):
    """Macro-averaged F1 of the highest-scoring label in ``pred`` against ``true``.

    Args:
        pred: DataFrame of per-label scores (columns are labels).
        true: gold labels, one per row of ``pred``.
    """
    predicted_labels = pred_to_1d(pred)
    return f1_score(true, predicted_labels, average="macro")


def cross_entropy(pred, true):
    """Log loss of row-normalised ``pred`` scores against one-hot encoded ``true``.

    Each row of ``pred`` is divided by its sum so it forms a distribution over
    the label columns before scoring.

    Args:
        pred: DataFrame of non-negative per-label scores (columns are labels).
        true: gold labels, one per row of ``pred``; each must be a column of ``pred``.
    """
    row_totals = pred.sum(axis=1).to_numpy().reshape(-1, 1)
    probabilities = pd.DataFrame(pred.to_numpy() / row_totals, columns=pred.columns)
    one_hot_true = true_to_nd(true, probabilities.columns)
    return log_loss(one_hot_true, probabilities)

In [None]:
# Display name -> scorer taking (pred, true); the names become keys in test_results.json.
metrics = dict(
    [
        ("Macro F1", f1),
        ("Accuracy", accuracy),
        ("Cross-entropy loss", cross_entropy),
    ]
)

In [None]:
# Per-model prediction CSVs are cached here so re-runs can skip inference.
predications_dir = Path("pred")
if not predications_dir.is_dir():
    predications_dir.mkdir(exist_ok=True)

In [None]:
def canonicalize_label(label):
    """Return the part of ``label`` before its first ".", or ``label`` if it has no dot.

    For example ``"cs.AI"`` becomes ``"cs"`` and ``"math"`` is returned unchanged.
    """
    head, _separator, _tail = label.partition(".")
    return head


def _predict_batches(model_name, model, dataset, batch_size, known_labels):
    """Run ``model`` over ``dataset`` in batches and return one ``{label: score}`` dict per row.

    Returns ``None`` as soon as a canonicalised label outside ``known_labels`` appears.
    """
    rows = []
    for start in tqdm(
        range(0, len(dataset), batch_size),
        desc=f"Predicting using {model_name}",
        total=math.ceil(len(dataset) / batch_size),
        unit="batch",
    ):
        batch = Dataset.from_pandas(dataset.iloc[start : start + batch_size])
        for paper_pred in model(batch):
            scores = {}
            for label_score in paper_pred:
                label = canonicalize_label(label_score["label"])
                if label not in known_labels:
                    return None
                # Several model labels (e.g. "cs.AI" and "cs.LG") can collapse onto
                # the same canonical label; accumulate their scores rather than
                # keeping only the last one seen.
                scores[label] = scores.get(label, 0) + label_score["score"]
            rows.append(scores)
    return rows


def predict(model_name, model, dataset: pd.DataFrame, batch_size=32, first: int = 3000):
    """Score ``dataset`` with ``model`` and return per-label scores and gold labels.

    Predictions are cached as ``predications_dir / (model_name + ".csv")``. When that
    file already exists it is loaded instead of calling ``model``. Otherwise it is
    written after a successful run.

    Args:
        model_name: name shown in the progress bar and used for the cache file name.
        model: callable taking a ``datasets.Dataset`` batch and returning, for each
            row, an iterable of ``{"label": ..., "score": ...}`` dicts.
        dataset: DataFrame with a ``"category"`` column holding the gold labels.
        batch_size: number of rows passed to ``model`` per call.
        first: evaluate only the first ``first`` rows; ``None`` evaluates all rows.

    Returns:
        ``(preds, true)``. ``preds`` is a DataFrame with one column per label,
        sorted by name, with missing scores set to 0. ``true`` is the gold-label
        Series for the evaluated rows. Returns ``(None, None)`` when the model emits
        a canonicalised label that never occurs in ``dataset``.
    """
    label_column = "category"
    # Labels are collected before truncation so every class in the dataset gets a column.
    all_labels = list(dataset[label_column].unique())
    if first is not None:
        dataset = dataset.iloc[:first]
    true = dataset[label_column]
    prediction_file_path = predications_dir / (model_name + ".csv")
    is_cached = prediction_file_path.exists()
    if is_cached:
        preds = pd.read_csv(prediction_file_path, index_col=0)
    else:
        preds = _predict_batches(model_name, model, dataset, batch_size, set(all_labels))
        if preds is None:
            return None, None
    preds = pd.DataFrame(preds).fillna(0)
    for label in all_labels:
        if label not in preds.columns:
            preds[label] = 0
    preds = preds.reindex(sorted(preds.columns), axis=1)
    if not is_cached:
        preds.to_csv(prediction_file_path)
    return preds, true


# Evaluate every model on every metric that is not yet recorded in test_results.
# The model is the outer loop so each model is loaded and run at most once per
# language, rather than once per metric.
# NOTE(review): the prediction cache in predict() is keyed on model_name only. If
# the same model appears under both languages, the cached predictions of the
# first language would be reused for the second -- confirm this cannot happen.
for lang, name_get_model in language_to_models.items():
    lang_results = test_results.setdefault(lang, {})
    for metric_name in metrics:
        lang_results.setdefault(metric_name, {})
    for raw_model_name, get_model in name_get_model.items():
        model_name = raw_model_name.replace("/", ".")
        missing_metrics = {
            metric_name: metric
            for metric_name, metric in metrics.items()
            if model_name not in lang_results[metric_name]
        }
        if not missing_metrics:
            continue
        # English is evaluated on more papers than the other test splits.
        test_size = 3000 if en == lang else 500
        pred, true = predict(model_name, get_model(), test_datasets[lang], first=test_size)
        if pred is None:
            print(f"{model_name} does not produce labels that we can estimate")
            continue
        for metric_name, metric in missing_metrics.items():
            lang_results[metric_name][model_name] = metric(pred, true)
            print(f"{metric_name} for {model_name} = {lang_results[metric_name][model_name]}")

In [None]:
# Serialise before opening the file. open(..., "w") truncates immediately, so a
# serialisation error inside json.dump would otherwise wipe every stored result.
serialised_results = json.dumps(test_results)
with open(test_results_filename, "w") as f:
    f.write(serialised_results)