import logging
import random
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
from scipy.stats import chisquare, ranksums
from sklearn.metrics import accuracy_score, f1_score

from . import perturber_utils as pu

logger = logging.getLogger(__name__)


def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
    data = data.shuffle(seed=42)
    num_cells = len(data)

    # if max number of cells is defined, then subsample to this max number
    if max_ncells is not None:
        if num_cells > max_ncells:
            data = data.select([i for i in range(max_ncells)])

    # if max number of cells per class is defined, then subsample each class to this max
    if max_ncells_per_class is not None:
        class_labels = data[cell_state_dict["state_key"]]
        random.seed(42)
        subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)
        data = data.select(subsample_indices)
    return data


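# Note: the dataset objects handled throughout this module are assumed to be
# Hugging Face `datasets.Dataset` instances (supporting shuffle/select/filter/
# map/rename_column/remove_columns), which is what the calls above rely on.

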
def subsample_by_class(labels, N):
    label_indices = defaultdict(list)

    # gather indices for each label
    for idx, label in enumerate(labels):
        label_indices[label].append(idx)
    selected_indices = []

    # select up to N indices per label, keeping all indices for smaller classes
    for label, indices in label_indices.items():
        if len(indices) > N:
            selected_indices.extend(random.sample(indices, N))
        else:
            selected_indices.extend(indices)
    return selected_indices


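# Illustrative example (hypothetical values, not part of the module):
# with labels = ["T cell", "T cell", "T cell", "B cell"] and N = 2,
# subsample_by_class returns two randomly sampled indices from the "T cell"
# group plus the single "B cell" index, e.g. [0, 2, 3].

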
def rename_cols(data, state_key):
    data = data.rename_column(state_key, "label")
    return data


def validate_and_clean_cols(train_data, eval_data, classifier):
    # label column name differs between the cell and gene classifiers
    if classifier == "cell":
        label_col = "label"
    elif classifier == "gene":
        label_col = "labels"

    cols_to_keep = [label_col] + ["input_ids", "length"]
    if label_col not in train_data.column_names:
        error_msg = f"train_data must contain column {label_col} with class labels."
        logger.error(error_msg)
        raise ValueError(error_msg)
    else:
        train_data = remove_cols(train_data, cols_to_keep)

    if eval_data is not None:
        if label_col not in eval_data.column_names:
            error_msg = f"eval_data must contain column {label_col} with class labels."
            logger.error(error_msg)
            raise ValueError(error_msg)
        else:
            eval_data = remove_cols(eval_data, cols_to_keep)
    return train_data, eval_data


def remove_cols(data, cols_to_keep):
    other_cols = list(data.features.keys())
    other_cols = [ele for ele in other_cols if ele not in cols_to_keep]
    data = data.remove_columns(other_cols)
    return data


def remove_rare(data, rare_threshold, label, nproc):
    if rare_threshold > 0:
        total_cells = len(data)
        label_counter = Counter(data[label])
        # keep only labels whose frequency exceeds the rare_threshold fraction
        nonrare_label_dict = {
            label: [
                k
                for k, v in label_counter.items()
                if (v / total_cells) > rare_threshold
            ]
        }
        data = pu.filter_by_dict(data, nonrare_label_dict, nproc)
    return data


def label_classes(classifier, data, gene_class_dict, nproc):
    if classifier == "cell":
        label_set = set(data["label"])
    elif classifier == "gene":

        def if_contains_label(example):
            a = pu.flatten_list(gene_class_dict.values())
            b = example["input_ids"]
            return not set(a).isdisjoint(b)

        data = data.filter(if_contains_label, num_proc=nproc)
        label_set = gene_class_dict.keys()

        if len(data) == 0:
            error_msg = (
                "No cells remain after filtering for target genes. "
                "Check target gene list."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

    class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
    id_class_dict = {v: k for k, v in class_id_dict.items()}

    def classes_to_ids(example):
        if classifier == "cell":
            example["label"] = class_id_dict[example["label"]]
        elif classifier == "gene":
            example["labels"] = label_gene_classes(
                example, class_id_dict, gene_class_dict
            )
        return example

    data = data.map(classes_to_ids, num_proc=nproc)
    return data, id_class_dict


def label_gene_classes(example, class_id_dict, gene_class_dict):
    # tokens whose genes are not in gene_class_dict get -100, the conventional
    # ignore index for the token classification loss
    return [
        class_id_dict.get(gene_class_dict.get(token_id, -100), -100)
        for token_id in example["input_ids"]
    ]


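# Illustrative example (hypothetical values): if gene_class_dict maps token 7
# to class "TF" and class_id_dict maps "TF" to 0, then an example with
# input_ids [7, 99] yields labels [0, -100], since token 99 has no class.

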
def prep_gene_classifier_split(
    data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
):
    # subset targets and labels to this cross-validation split
    targets = np.array(targets)
    labels = np.array(labels)
    targets_train, targets_eval = targets[train_index], targets[eval_index]
    labels_train, labels_eval = labels[train_index], labels[eval_index]
    label_dict_train = dict(zip(targets_train, labels_train))
    label_dict_eval = dict(zip(targets_eval, labels_eval))

    # functions to filter by whether examples contain train or eval genes
    def if_contains_train_label(example):
        a = targets_train
        b = example["input_ids"]
        return not set(a).isdisjoint(b)

    def if_contains_eval_label(example):
        a = targets_eval
        b = example["input_ids"]
        return not set(a).isdisjoint(b)

    # filter dataset for examples containing genes for this split
    logger.info(f"Filtering training data for genes in split {iteration_num}")
    train_data = data.filter(if_contains_train_label, num_proc=num_proc)
    logger.info(
        f"Filtered {round((1 - len(train_data) / len(data)) * 100)}%; {len(train_data)} remain\n"
    )
    logger.info(f"Filtering evaluation data for genes in split {iteration_num}")
    eval_data = data.filter(if_contains_eval_label, num_proc=num_proc)
    logger.info(
        f"Filtered {round((1 - len(eval_data) / len(data)) * 100)}%; {len(eval_data)} remain\n"
    )

    # subsample to max_ncells
    train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
    eval_data = downsample_and_shuffle(eval_data, max_ncells, None, None)

    # relabel genes for this split
    def train_classes_to_ids(example):
        example["labels"] = [
            label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
        ]
        return example

    def eval_classes_to_ids(example):
        example["labels"] = [
            label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
        ]
        return example

    train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
    eval_data = eval_data.map(eval_classes_to_ids, num_proc=num_proc)

    return train_data, eval_data


def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
    targets = np.array(targets)
    labels = np.array(labels)
    label_dict_train = dict(zip(targets, labels))

    # function to filter by whether examples contain train genes
    def if_contains_train_label(example):
        a = targets
        b = example["input_ids"]
        return not set(a).isdisjoint(b)

    # filter dataset for examples containing genes to classify
    logger.info("Filtering training data for genes to classify.")
    train_data = data.filter(if_contains_train_label, num_proc=num_proc)
    logger.info(
        f"Filtered {round((1 - len(train_data) / len(data)) * 100)}%; {len(train_data)} remain\n"
    )

    # subsample to max_ncells
    train_data = downsample_and_shuffle(train_data, max_ncells, None, None)

    # relabel genes
    def train_classes_to_ids(example):
        example["labels"] = [
            label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
        ]
        return example

    train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)

    return train_data


def balance_attr_splits(
    data,
    attr_to_split,
    attr_to_balance,
    eval_size,
    max_trials,
    pval_threshold,
    state_key,
    nproc,
):
    metadata_df = pd.DataFrame({"split_attr_ids": data[attr_to_split]})
    for attr in attr_to_balance:
        if attr == state_key:
            metadata_df[attr] = data["label"]
        else:
            metadata_df[attr] = data[attr]
    metadata_df = metadata_df.drop_duplicates()

    split_attr_ids = list(metadata_df["split_attr_ids"])
    assert len(split_attr_ids) == len(set(split_attr_ids))
    eval_num = round(len(split_attr_ids) * eval_size)
    colnames = (
        ["trial_num", "train_ids", "eval_ids"]
        + pu.flatten_list(
            [
                [
                    f"{attr}_train_mean_or_counts",
                    f"{attr}_eval_mean_or_counts",
                    f"{attr}_pval",
                ]
                for attr in attr_to_balance
            ]
        )
        + ["mean_pval"]
    )
    balance_df = pd.DataFrame(columns=colnames)
    data_dict = dict()
    trial_num = 1
    for i in range(max_trials):
        if not all(
            count > 1 for count in list(Counter(metadata_df[state_key]).values())
        ):
            error_msg = (
                f"Cannot balance by {attr_to_split} while retaining at least 1 "
                f"occurrence of each {state_key} class in both data splits."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)
        # ensure at least one instance of each state_key class lands in the eval split
        eval_base = []
        for state in set(metadata_df[state_key]):
            eval_base += list(
                metadata_df.loc[
                    metadata_df[state_key][metadata_df[state_key].eq(state)]
                    .sample(1, random_state=i)
                    .index
                ]["split_attr_ids"]
            )
        non_eval_base = [idx for idx in split_attr_ids if idx not in eval_base]
        random.seed(i)
        eval_ids = random.sample(non_eval_base, eval_num - len(eval_base)) + eval_base
        train_ids = [idx for idx in split_attr_ids if idx not in eval_ids]
        df_vals = [trial_num, train_ids, eval_ids]
        pvals = []
        for attr in attr_to_balance:
            train_attr = list(
                metadata_df[metadata_df["split_attr_ids"].isin(train_ids)][attr]
            )
            eval_attr = list(
                metadata_df[metadata_df["split_attr_ids"].isin(eval_ids)][attr]
            )
            if attr == state_key:
                # state_key classes are always treated as categorical
                train_attr = [str(item) for item in train_attr]
                eval_attr = [str(item) for item in eval_attr]
            if all(isinstance(item, (int, float)) for item in train_attr + eval_attr):
                # continuous attribute: compare distributions with rank-sum test
                train_attr_mean = np.nanmean(train_attr)
                eval_attr_mean = np.nanmean(eval_attr)
                pval = ranksums(train_attr, eval_attr, nan_policy="omit").pvalue
                df_vals += [train_attr_mean, eval_attr_mean, pval]
            elif all(isinstance(item, str) for item in train_attr + eval_attr):
                # categorical attribute: compare category counts with chi-square test,
                # rescaling eval counts to the same total as the train counts
                obs_counts = Counter(train_attr)
                exp_counts = Counter(eval_attr)
                all_categ = set(obs_counts.keys()).union(set(exp_counts.keys()))
                obs = [obs_counts[cat] for cat in all_categ]
                exp = [
                    exp_counts[cat] * sum(obs) / sum(exp_counts.values())
                    for cat in all_categ
                ]
                pval = chisquare(f_obs=obs, f_exp=exp).pvalue
                train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
                eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
                df_vals += [train_attr_counts, eval_attr_counts, pval]
            else:
                error_msg = (
                    f"Inconsistent data types in attribute {attr}. "
                    "Cannot infer if continuous or categorical. "
                    "Must be all numeric (continuous) or all strings (categorical) to balance."
                )
                logger.error(error_msg)
                raise ValueError(error_msg)
            pvals += [pval]

        df_vals += [np.nanmean(pvals)]
        balance_df_i = pd.DataFrame(df_vals, index=colnames).T
        balance_df = pd.concat([balance_df, balance_df_i], ignore_index=True)
        valid_pvals = [
            pval_i
            for pval_i in pvals
            if isinstance(pval_i, (int, float)) and not np.isnan(pval_i)
        ]
        # accept the first trial in which no balanced attribute differs significantly
        if all(i >= pval_threshold for i in valid_pvals):
            data_dict["train"] = pu.filter_by_dict(
                data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
            )
            data_dict["test"] = pu.filter_by_dict(
                data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
            )
            return data_dict, balance_df
        trial_num = trial_num + 1
    # no acceptable split found; fall back to the trial with the highest mean p-value
    balance_max_df = balance_df.iloc[balance_df["mean_pval"].astype(float).idxmax(), :]
    data_dict["train"] = pu.filter_by_dict(
        data, {attr_to_split: balance_max_df["train_ids"]}, nproc
    )
    data_dict["test"] = pu.filter_by_dict(
        data, {attr_to_split: balance_max_df["eval_ids"]}, nproc
    )
    logger.warning(
        f"No splits found without significant difference in attr_to_balance among {max_trials} trials. "
        f"Selecting optimal split (trial #{balance_max_df['trial_num']}) from completed trials."
    )
    return data_dict, balance_df


def get_num_classes(id_class_dict):
    return len(set(id_class_dict.values()))


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    # calculate accuracy and macro f1 using sklearn's functions
    if len(labels.shape) == 1:
        # cell classification: one label per example
        acc = accuracy_score(labels, preds)
        macro_f1 = f1_score(labels, preds, average="macro")
    else:
        # gene (token) classification: flatten and drop positions labeled -100,
        # which mark tokens excluded from the loss
        flat_labels = labels.flatten().tolist()
        flat_preds = preds.flatten().tolist()
        logit_label_paired = [
            item for item in list(zip(flat_preds, flat_labels)) if item[1] != -100
        ]
        y_pred = [item[0] for item in logit_label_paired]
        y_true = [item[1] for item in logit_label_paired]

        acc = accuracy_score(y_true, y_pred)
        macro_f1 = f1_score(y_true, y_pred, average="macro")

    return {"accuracy": acc, "macro_f1": macro_f1}


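# Note: `pred` is expected to expose `label_ids` and `predictions`, matching
# the evaluation-prediction object that a Hugging Face `Trainer` passes to its
# `compute_metrics` callback; that intended usage is an assumption here.

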
def get_default_train_args(model, classifier, data, output_dir):
    num_layers = pu.quant_layers(model)
    freeze_layers = 0
    batch_size = 12
    if classifier == "cell":
        epochs = 10
        evaluation_strategy = "epoch"
        load_best_model_at_end = True
    else:
        epochs = 1
        evaluation_strategy = "no"
        load_best_model_at_end = False

    if num_layers == 6:
        default_training_args = {
            "learning_rate": 5e-5,
            "lr_scheduler_type": "linear",
            "warmup_steps": 500,
            "per_device_train_batch_size": batch_size,
            "per_device_eval_batch_size": batch_size,
        }
    else:
        default_training_args = {
            "per_device_train_batch_size": batch_size,
            "per_device_eval_batch_size": batch_size,
        }

    training_args = {
        "num_train_epochs": epochs,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": evaluation_strategy,
        "logging_steps": np.floor(len(data) / batch_size / 8),
        "save_strategy": "epoch",
        "group_by_length": False,
        "length_column_name": "length",
        "disable_tqdm": False,
        "weight_decay": 0.001,
        "load_best_model_at_end": load_best_model_at_end,
    }
    training_args.update(default_training_args)

    return training_args, freeze_layers
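

# Illustrative downstream use (a sketch, assuming the Hugging Face
# `transformers` library; the variable names below are hypothetical):
#
#     from transformers import TrainingArguments
#
#     training_args, freeze_layers = get_default_train_args(
#         model, "cell", train_data, output_dir
#     )
#     hf_args = TrainingArguments(output_dir=output_dir, **training_args)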