ctheodoris's picture
precommit formatting
f07bfd7
raw
history blame
5.48 kB
import os
from .collators import DataCollatorForMultitaskCellClassification
from .imports import *
def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
    """
    Load a dataset from disk and turn it into a list of per-record dicts.

    Args:
        dataset_path: Path handed to ``load_from_disk``.
        config: Settings dict. ``task_columns`` and ``results_dir`` are read;
            ``task_names`` (``task1``, ``task2``, ...) is written back into it.
        is_test: When False, label mappings are built from the dataset's task
            columns and pickled to ``<results_dir>/task_label_mappings.pkl``.
            When True, the mappings are loaded from that file instead and
            every record gets a dummy label of -1 for each task.
        dataset_type: Prefix used only when printing the label mappings.

    Returns:
        A tuple ``(transformed_dataset, cell_id_mapping, num_labels_list)``:
        a list of dicts with keys ``input_ids`` (long tensor), ``cell_id``
        (record index) and ``label`` (task name -> label index); a dict mapping
        record index to the record's ``unique_cell_id`` (or the index itself
        when that field is absent); and the number of labels per task.
        If anything goes wrong the error is printed and
        ``(None, None, None)`` is returned instead of raising.
    """
    try:
        dataset = load_from_disk(dataset_path)

        task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
        task_to_column = dict(zip(task_names, config["task_columns"]))
        config["task_names"] = task_names

        if not is_test:
            # Fail early, with a readable message, if a task column is absent.
            available_columns = set(dataset.column_names)
            for column in task_to_column.values():
                if column not in available_columns:
                    raise KeyError(
                        f"Column {column} not found in the dataset. Available columns: {list(available_columns)}"
                    )

        task_label_mappings = {}
        num_labels_list = []

        if not is_test:
            # NOTE(review): every non-test call rebuilds the mappings from its
            # own dataset and overwrites the same pickle file below -- confirm
            # that this is intended when several labelled splits are processed.
            for task, column in task_to_column.items():
                # Sorting makes the label -> index assignment deterministic.
                unique_values = sorted(set(dataset[column]))
                task_label_mappings[task] = {
                    label: idx for idx, label in enumerate(unique_values)
                }
                num_labels_list.append(len(unique_values))

            # Print the mappings for each task with dataset type prefix
            for task, mapping in task_label_mappings.items():
                print(
                    f"{dataset_type.capitalize()} mapping for {task}: {mapping}"
                )  # sanity check, for train/validation splits

            # Save the task label mappings as a pickle file
            with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
                pickle.dump(task_label_mappings, f)
        else:
            # Load task label mappings from pickle file for test data.
            # NOTE: unpickling can execute arbitrary code; results_dir must
            # only ever contain files written by a trusted run.
            with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
                task_label_mappings = pickle.load(f)

            # Infer num_labels_list from task_label_mappings
            num_labels_list = [
                len(mapping) for mapping in task_label_mappings.values()
            ]

        # Single pass over the dataset: record the external cell ID and build
        # the transformed record together.
        cell_id_mapping = {}
        transformed_dataset = []
        for idx, record in enumerate(dataset):
            cell_id_mapping[idx] = record.get("unique_cell_id", idx)

            input_ids = torch.tensor(record["input_ids"], dtype=torch.long)

            if is_test:
                # Create dummy labels for test data
                label_dict = {task: -1 for task in task_names}
            else:
                label_dict = {
                    task: task_label_mappings[task][record[column]]
                    for task, column in task_to_column.items()
                }

            transformed_dataset.append(
                {
                    "input_ids": input_ids,
                    # Use index-based cell ID for internal tracking
                    "cell_id": idx,
                    "label": label_dict,
                }
            )

        return transformed_dataset, cell_id_mapping, num_labels_list
    except KeyError as e:
        print(f"Missing configuration or dataset key: {e}")
    except Exception as e:
        print(f"An error occurred while loading or preprocessing data: {e}")

    # Reached only after a handled error; keeps the 3-tuple shape for callers.
    return None, None, None
def preload_and_process_data(config):
    """
    Load and preprocess the train and validation datasets once.

    Reads ``config["train_path"]`` and ``config["val_path"]`` and runs
    ``load_and_preprocess_data`` on each (train first, then validation).

    Returns:
        ``(train_dataset, train_cell_id_mapping, val_dataset,
        val_cell_id_mapping, num_labels_list)`` where ``num_labels_list`` is
        the one produced for the train dataset; the validation one is dropped.
    """
    train_records, train_ids, label_counts = load_and_preprocess_data(
        config["train_path"],
        config,
        dataset_type="train",
    )

    # The third value returned for the validation split is not used.
    val_records, val_ids, _ = load_and_preprocess_data(
        config["val_path"],
        config,
        dataset_type="validation",
    )

    return train_records, train_ids, val_records, val_ids, label_counts
def get_data_loader(preprocessed_dataset, batch_size):
    """
    Wrap a preprocessed dataset in a shuffling ``DataLoader``.

    Args:
        preprocessed_dataset: Dataset object passed straight to ``DataLoader``.
        batch_size: Number of records per batch.

    Returns:
        A ``DataLoader`` with ``shuffle=True``, ``pin_memory=True``, batches
        collated by ``DataCollatorForMultitaskCellClassification`` and one
        worker per CPU reported by ``os.cpu_count()``.
    """
    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to 0 workers (loading in the main process) instead of handing None
    # to DataLoader.
    nproc = os.cpu_count() or 0
    data_collator = DataCollatorForMultitaskCellClassification()
    loader = DataLoader(
        preprocessed_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=data_collator,
        num_workers=nproc,
        pin_memory=True,
    )
    return loader
def preload_data(config):
    """
    Build the train and validation loaders before the Optuna trials start.

    The datasets are preprocessed via ``preload_and_process_data`` and each is
    wrapped with ``get_data_loader`` using ``config["batch_size"]``.

    Returns:
        ``(train_loader, val_loader)``.
    """
    # get_data_loader expects (dataset, batch_size); passing a split name and
    # the whole config dict, as this function used to, cannot work.
    train_dataset, _, val_dataset, _, _ = preload_and_process_data(config)
    train_loader = get_data_loader(train_dataset, config["batch_size"])
    val_loader = get_data_loader(val_dataset, config["batch_size"])
    return train_loader, val_loader
def load_and_preprocess_test_data(config):
    """
    Preprocess the dataset at ``config["test_path"]`` as unlabeled test data.

    Delegates to ``load_and_preprocess_data`` with ``is_test=True`` and
    returns its result unchanged.
    """
    test_path = config["test_path"]
    return load_and_preprocess_data(test_path, config, is_test=True)
def prepare_test_loader(config):
    """
    Build the DataLoader for the test dataset.

    Returns:
        ``(test_loader, cell_id_mapping, num_labels_list)``: the loader from
        ``get_data_loader`` (batch size ``config["batch_size"]``) plus the
        cell ID mapping and per-task label counts from
        ``load_and_preprocess_test_data``.
    """
    test_bundle = load_and_preprocess_test_data(config)
    test_dataset, cell_id_mapping, num_labels_list = test_bundle

    batch_size = config["batch_size"]
    return (
        get_data_loader(test_dataset, batch_size),
        cell_id_mapping,
        num_labels_list,
    )