File size: 3,669 Bytes
933ca80
f07bfd7
933ca80
f07bfd7
933ca80
 
f07bfd7
933ca80
 
 
 
 
f07bfd7
 
 
933ca80
 
 
 
f07bfd7
 
933ca80
f07bfd7
 
933ca80
 
 
f07bfd7
 
 
933ca80
 
 
 
 
 
 
f07bfd7
933ca80
 
f07bfd7
933ca80
f07bfd7
 
 
 
933ca80
f07bfd7
933ca80
 
 
 
f07bfd7
933ca80
 
 
 
 
 
 
f07bfd7
933ca80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f07bfd7
933ca80
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import pandas as pd

from .data import prepare_test_loader
from .imports import *
from .model import GeneformerMultiTask


def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
    """Run the model over the test loader and write per-cell predictions to CSV.

    For every task in ``config["task_names"]`` the predicted class index and
    the softmax probabilities are collected for each sample, then written to
    ``<config["results_dir"]>/test_preds.csv`` with one row per cell.

    Args:
        model: Callable invoked as ``model(input_ids, attention_mask)`` that
            returns a 3-tuple; only the second element (per-task logits,
            indexed by task position) is used. ``model.eval()`` is called
            before inference.
        device: Device the batch tensors are moved to before the forward pass.
        test_loader: Iterable of batches, each a mapping with ``"input_ids"``,
            ``"attention_mask"`` and ``"cell_id"`` entries.
        cell_id_mapping: Mapping from the scalar value stored in
            ``batch["cell_id"]`` to the cell identifier written to the CSV.
        config: Mapping providing ``"task_names"`` and ``"results_dir"``.

    Returns:
        None. The results are written to disk and the output path is printed.
    """
    task_names = config["task_names"]
    task_pred_labels = {task_name: [] for task_name in task_names}
    task_pred_probs = {task_name: [] for task_name in task_names}
    cell_ids = []

    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            _, logits, _ = model(input_ids, attention_mask)

            cell_ids.extend(
                cell_id_mapping[raw_id.item()] for raw_id in batch["cell_id"]
            )

            for i, task_name in enumerate(task_names):
                # Assumes logits[i] is a batch-first tensor (batch, num_classes),
                # so argmax/softmax can run once per batch instead of once per
                # sample; this avoids a device-to-host sync for every sample.
                task_logits = logits[i]
                batch_labels = torch.argmax(task_logits, dim=-1).tolist()
                batch_probs = torch.softmax(task_logits, dim=-1).cpu().numpy()
                task_pred_labels[task_name].extend(batch_labels)
                # Extending with a 2-D array appends one 1-D row per sample.
                task_pred_probs[task_name].extend(batch_probs)

    # Save test predictions with cell IDs and probabilities to CSV
    test_results_dir = config["results_dir"]
    os.makedirs(test_results_dir, exist_ok=True)
    test_preds_file = os.path.join(test_results_dir, "test_preds.csv")

    # Fix the column layout up front so that an empty test set still yields a
    # CSV with a header row rather than an unparseable empty file.
    columns = ["Cell ID"]
    for task_name in task_names:
        columns.append(f"{task_name} Prediction")
        columns.append(f"{task_name} Probabilities")

    rows = []
    for sample_idx, cell_id in enumerate(cell_ids):
        row = {"Cell ID": cell_id}
        for task_name in task_names:
            row[f"{task_name} Prediction"] = task_pred_labels[task_name][sample_idx]
            # Probabilities are stored as one comma-separated string per cell.
            row[f"{task_name} Probabilities"] = ",".join(
                map(str, task_pred_probs[task_name][sample_idx])
            )
        rows.append(row)

    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(test_preds_file, index=False)
    print(f"Test predictions saved to {test_preds_file}")


def load_and_evaluate_test_model(config):
    """Load the saved multi-task model and evaluate it on the test set.

    Reads ``hyperparameters.json`` and ``pytorch_model.bin`` from
    ``<config["model_save_path"]>/GeneformerMultiTask``, rebuilds a
    ``GeneformerMultiTask`` with the stored dropout rate and task weights,
    loads the weights, and passes the model to ``evaluate_test_dataset``.

    Args:
        config: Mapping providing ``"model_save_path"``, ``"pretrained_path"``
            and ``"use_task_weights"``. It is also forwarded unchanged to
            ``prepare_test_loader`` and ``evaluate_test_dataset``.

    Raises:
        KeyError: If the saved hyperparameters contain no ``"dropout_rate"``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
    model_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
    hyperparams_path = os.path.join(model_directory, "hyperparameters.json")

    # Load the saved best hyperparameters
    with open(hyperparams_path, "r", encoding="utf-8") as f:
        best_hyperparams = json.load(f)

    # Use the stored task weights if present and non-empty, otherwise an
    # empty list.
    normalized_task_weights = best_hyperparams.get("task_weights") or []

    # Print the loaded hyperparameters
    print("Loaded hyperparameters:")
    for param, value in best_hyperparams.items():
        if param == "task_weights":
            print(f"normalized_task_weights: {value}")
        else:
            print(f"{param}: {value}")

    best_model_path = os.path.join(model_directory, "pytorch_model.bin")
    best_model = GeneformerMultiTask(
        config["pretrained_path"],
        num_labels_list,
        dropout_rate=best_hyperparams["dropout_rate"],
        use_task_weights=config["use_task_weights"],
        task_weights=normalized_task_weights,
    )
    # map_location remaps stored tensors onto the device selected above, so a
    # checkpoint saved from a GPU run still loads on a CPU-only machine.
    state_dict = torch.load(best_model_path, map_location=device)
    best_model.load_state_dict(state_dict)
    best_model.to(device)

    evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
    print("Evaluation completed.")