File size: 3,711 Bytes
933ca80
f07bfd7
eab1878
 
933ca80
 
f07bfd7
933ca80
 
 
 
 
f07bfd7
 
 
933ca80
 
 
 
f07bfd7
 
933ca80
f07bfd7
 
933ca80
 
 
f07bfd7
 
 
933ca80
 
 
 
 
 
 
f07bfd7
933ca80
 
f07bfd7
933ca80
f07bfd7
 
 
 
933ca80
f07bfd7
933ca80
 
 
 
f07bfd7
933ca80
 
 
 
 
 
 
f07bfd7
933ca80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f07bfd7
933ca80
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import pandas as pd

from .imports import *  # noqa # isort:skip
from .data import prepare_test_loader  # noqa # isort:skip
from .model import GeneformerMultiTask


def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
    """Run ``model`` over ``test_loader`` and save per-cell predictions to CSV.

    For every sample and every task in ``config["task_names"]``, the predicted
    class (argmax of the logits) and the softmax probabilities are recorded.
    Results are written to ``<config["results_dir"]>/test_preds.csv`` with one
    row per cell. Columns:

    - ``Cell ID``
    - ``<task> Prediction`` for each task
    - ``<task> Probabilities`` (comma-joined) for each task

    Args:
        model: Multi-task model called as ``model(input_ids, attention_mask)``.
            It must return a 3-tuple whose second element holds one logits
            tensor per task, in ``config["task_names"]`` order. Each tensor is
            indexed by sample first.
        device: Device the input tensors are moved to.
        test_loader: Iterable of batches (dicts) with ``input_ids``,
            ``attention_mask`` and ``cell_id`` tensors.
        cell_id_mapping: Maps the integer ``cell_id`` values in a batch to the
            cell identifiers written to the CSV.
        config: Keys read here are ``task_names`` and ``results_dir``.
    """
    task_names = config["task_names"]
    task_pred_labels = {task_name: [] for task_name in task_names}
    task_pred_probs = {task_name: [] for task_name in task_names}
    cell_ids = []

    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            _, logits, _ = model(input_ids, attention_mask)
            cell_ids.extend(
                cell_id_mapping[raw_id] for raw_id in batch["cell_id"].tolist()
            )
            # Compute argmax/softmax once per task for the whole batch and copy
            # each result to the host in a single transfer. This avoids
            # launching kernels and forcing a device sync for every sample.
            for i, task_name in enumerate(task_names):
                task_logits = logits[i]
                task_pred_labels[task_name].extend(
                    torch.argmax(task_logits, dim=-1).tolist()
                )
                task_pred_probs[task_name].extend(
                    torch.softmax(task_logits, dim=-1).cpu().numpy()
                )

    # Save test predictions with cell IDs and probabilities to CSV
    test_results_dir = config["results_dir"]
    os.makedirs(test_results_dir, exist_ok=True)
    test_preds_file = os.path.join(test_results_dir, "test_preds.csv")

    columns = ["Cell ID"]
    for task_name in task_names:
        columns.extend([f"{task_name} Prediction", f"{task_name} Probabilities"])

    rows = []
    for sample_idx, cell_id in enumerate(cell_ids):
        row = {"Cell ID": cell_id}
        for task_name in task_names:
            row[f"{task_name} Prediction"] = task_pred_labels[task_name][sample_idx]
            row[f"{task_name} Probabilities"] = ",".join(
                map(str, task_pred_probs[task_name][sample_idx])
            )
        rows.append(row)

    # Explicit columns keep the header (and column order) even when the test
    # set is empty.
    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(test_preds_file, index=False)
    print(f"Test predictions saved to {test_preds_file}")


def load_and_evaluate_test_model(config):
    """Load the saved best GeneformerMultiTask checkpoint and evaluate it.

    Reads ``hyperparameters.json`` and ``pytorch_model.bin`` from
    ``<config["model_save_path"]>/GeneformerMultiTask``. It then rebuilds the
    model with the saved dropout rate and task weights, and hands it to
    ``evaluate_test_dataset``, which writes the test predictions.

    Args:
        config: Run configuration. Keys read directly here are
            ``model_save_path``, ``pretrained_path`` and ``use_task_weights``.
            The whole dict is also passed on to ``prepare_test_loader`` and
            ``evaluate_test_dataset``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
    model_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
    hyperparams_path = os.path.join(model_directory, "hyperparameters.json")

    # Load the saved best hyperparameters
    with open(hyperparams_path, "r", encoding="utf-8") as f:
        best_hyperparams = json.load(f)

    # Missing (or empty) task weights are passed to the model as an empty list.
    normalized_task_weights = best_hyperparams.get("task_weights") or []

    print("Loaded hyperparameters:")
    for param, value in best_hyperparams.items():
        # task_weights is reported under the name used when building the model.
        label = "normalized_task_weights" if param == "task_weights" else param
        print(f"{label}: {value}")

    best_model_path = os.path.join(model_directory, "pytorch_model.bin")
    best_model = GeneformerMultiTask(
        config["pretrained_path"],
        num_labels_list,
        dropout_rate=best_hyperparams["dropout_rate"],
        use_task_weights=config["use_task_weights"],
        task_weights=normalized_task_weights,
    )
    # map_location remaps tensors onto the selected device. Without it, a
    # checkpoint saved from a CUDA run fails to deserialize on a CPU-only host.
    state_dict = torch.load(best_model_path, map_location=device)
    best_model.load_state_dict(state_dict)
    best_model.to(device)

    evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
    print("Evaluation completed.")