import math
import os, csv, json
import io, textwrap, itertools
import subprocess
from Bio import SeqIO
import torch
import numpy as np
import sys, random
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import pynvml, requests
from collections import OrderedDict

plt.rcParams.update({'font.size': 18})
plt.rcParams['axes.unicode_minus'] = False

from .file_operator import file_reader
from .multi_label_metrics import prob_2_pred, relevant_indexes, metrics_multi_label
from .metrics import metrics_multi_class, metrics_binary, metrics_regression

common_nucleotide_set = {'A', 'T', 'C', 'G', 'U', 'N'}

common_amino_acid_set = {'R', 'X', 'S', 'G', 'W', 'I', 'Q', 'A', 'T', 'V', 'K', 'Y', 'C', 'N', 'L', 'F', 'D', 'M', 'P', 'H', 'E'}


def to_device(device, batch):
    '''
    move a (possibly nested) batch dict to the target device
    :param device: target device
    :param batch: dict of tensors, or dict of dicts (up to two levels of nesting)
    :return: (new_batch, sample_num); sample_num is taken from the first
        dimension of the last tensor encountered (0 if the batch holds no tensor)
    '''
    def _movable(v):
        # only tensors are moved; None/int/str/float values are kept as-is
        return v is not None and not isinstance(v, (int, str, float))

    new_batch = {}
    sample_num = 0
    tens = None
    for key1, val1 in batch.items():
        if isinstance(val1, dict):
            new_batch[key1] = {}
            for key2, val2 in val1.items():
                if isinstance(val2, dict):
                    new_batch[key1][key2] = {}
                    for key3, val3 in val2.items():
                        if _movable(val3):
                            new_batch[key1][key2][key3] = val3.to(device)
                            tens = val3
                        else:
                            new_batch[key1][key2][key3] = val3
                elif _movable(val2):
                    new_batch[key1][key2] = val2.to(device)
                    tens = val2
                else:
                    new_batch[key1][key2] = val2
        elif _movable(val1):
            new_batch[key1] = val1.to(device)
            tens = val1
        else:
            new_batch[key1] = val1
    if tens is not None:
        sample_num = tens.shape[0]
    return new_batch, sample_num
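

# Illustrative usage sketch (hedged): the nested-batch layout and names below
# are made up for demonstration; real batches come from the data collator.
def _example_to_device():
    batch = {
        "input": {"input_ids": torch.zeros((4, 8), dtype=torch.long)},
        "labels": torch.zeros((4,), dtype=torch.long),
        "meta": "not-a-tensor",  # non-tensor values are left untouched
    }
    new_batch, sample_num = to_device(torch.device("cpu"), batch)
    print(sample_num)  # -> 4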


def get_parameter_number(model):
    '''
    calc the parameter number/size of the model
    :param model: torch.nn.Module
    :return: dict with total/trainable parameter and buffer counts (M) and sizes (MB)
    '''
    param_size = 0
    param_sum = 0
    trainable_size = 0
    trainable_num = 0
    for param in model.parameters():
        cur_size = param.nelement() * param.element_size()
        cur_num = param.nelement()
        param_size += cur_size
        param_sum += cur_num
        if param.requires_grad:
            trainable_size += cur_size
            trainable_num += cur_num
    buffer_size = 0
    buffer_sum = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
        buffer_sum += buffer.nelement()
    return {
        'total_num': "%.2fM" % ((buffer_sum + param_sum) / (1024 * 1024)),
        'total_size': "%.2fMB" % ((buffer_size + param_size) / (1024 * 1024)),
        'param_sum': "%.2fM" % (param_sum / (1024 * 1024)),
        'param_size': "%.2fMB" % (param_size / (1024 * 1024)),
        'buffer_sum': "%.2fM" % (buffer_sum / (1024 * 1024)),
        'buffer_size': "%.2fMB" % (buffer_size / (1024 * 1024)),
        'trainable_num': "%.2fM" % (trainable_num / (1024 * 1024)),
        'trainable_size': "%.2fMB" % (trainable_size / (1024 * 1024))
    }
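

# Illustrative usage sketch (hedged): a toy model just to show the returned dict.
def _example_get_parameter_number():
    model = torch.nn.Linear(128, 64)
    stats = get_parameter_number(model)
    print(stats["total_num"], stats["trainable_size"])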


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)


def label_id_2_label_name(output_mode, label_list, prob, threshold=0.5):
    '''
    convert label id(s) to label name(s)
    :param output_mode: multi-label/multi-class/binary-class
    :param label_list: list of label names, indexed by label id
    :param prob: predicted probabilities
    :param threshold: decision threshold for multi-label and binary-class
    :return: label name(s)
    '''
    if output_mode in ["multi-label", "multi_label"]:
        res = []
        pred = prob_2_pred(prob, threshold)
        pred_index = relevant_indexes(pred)
        for row in range(prob.shape[0]):
            label_names = [label_list[idx] for idx in pred_index[row]]
            res.append(label_names)
        return res
    elif output_mode in ["multi-class", "multi_class"]:
        pred = np.argmax(prob, axis=1)
        label_names = [label_list[idx] for idx in pred]
        return label_names
    elif output_mode in ["binary-class", "binary_class"]:
        if prob.ndim == 2:
            prob = prob.flatten(order="C")
        pred = prob_2_pred(prob, threshold)
        label_names = [label_list[idx] for idx in pred]
        return label_names
    else:
        raise KeyError(output_mode)
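

# Illustrative usage sketch (hedged): dummy probabilities for a 3-class head.
def _example_label_id_2_label_name():
    prob = np.array([[0.1, 0.7, 0.2],
                     [0.8, 0.1, 0.1]])
    names = label_id_2_label_name("multi_class", ["cat", "dog", "bird"], prob)
    print(names)  # -> ['dog', 'cat']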


def plot_bins(data, xlabel, ylabel, bins, filepath):
    '''
    plot a histogram
    :param data: values to bin
    :param xlabel:
    :param ylabel:
    :param bins: number of bins
    :param filepath: png save filepath (None shows the figure instead)
    :return:
    '''
    plt.figure(figsize=(40, 20), dpi=100)
    plt.hist(data, bins=bins)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if filepath is None:
        plt.show()
    else:
        plt.savefig(filepath)
    plt.clf()
    plt.close()


def plot_confusion_matrix_for_binary_class(targets, preds, cm=None, savepath=None):
    '''
    plot a confusion matrix for binary classification
    :param targets: ground-truth labels
    :param preds: predicted labels
    :param cm: precomputed confusion matrix (computed from targets/preds if None)
    :param savepath: confusion matrix picture savepath (None shows the figure instead)
    '''
    plt.figure(figsize=(40, 20), dpi=100)
    if cm is None:
        cm = confusion_matrix(targets, preds, labels=[0, 1])

    plt.matshow(cm, cmap=plt.cm.Oranges)
    plt.colorbar()

    for x in range(len(cm)):
        for y in range(len(cm)):
            plt.annotate(cm[x, y], xy=(y, x), verticalalignment='center', horizontalalignment='center')
    plt.ylabel('True')
    plt.xlabel('Prediction')
    if savepath:
        plt.savefig(savepath, dpi=100)
    else:
        plt.show()
    plt.close("all")


def save_labels(filepath, label_list):
    '''
    save labels
    :param filepath:
    :param label_list:
    :return:
    '''
    with open(filepath, "w") as wfp:
        wfp.write("label" + "\n")
        for label in label_list:
            wfp.write(label + "\n")


def load_labels(filepath, header=True):
    '''
    load labels
    :param filepath:
    :param header: whether the file has a header line or not
    :return:
    '''
    label_list = []
    with open(filepath, "r") as rfp:
        for label in rfp:
            label_list.append(label.strip())
    if len(label_list) > 0 and (header or label_list[0] == "label"):
        return label_list[1:]
    return label_list
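

# Illustrative round-trip sketch (hedged): the temp filename is made up.
def _example_labels_roundtrip():
    import tempfile
    tmp_path = os.path.join(tempfile.gettempdir(), "labels_demo.txt")
    save_labels(tmp_path, ["positive", "negative"])
    print(load_labels(tmp_path, header=True))  # -> ['positive', 'negative']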


def load_vocab(vocab_path):
    '''
    load vocab
    :param vocab_path:
    :return:
    '''
    vocab = {}
    with open(vocab_path, "r") as rfp:
        for line in rfp:
            v = line.strip()
            vocab[v] = len(vocab)
    return vocab


def subprocess_popen(statement):
    '''
    execute a shell cmd
    :param statement: shell command string
    :return: list of stdout lines on success, False on failure
    '''
    p = subprocess.Popen(statement, shell=True, stdout=subprocess.PIPE)
    # wait() blocks until the process exits, so the exit code is always valid
    # (the original poll()-loop could skip a process that exited immediately)
    if p.wait() != 0:
        print("fail.")
        return False
    result = []
    for line in p.stdout.readlines():
        result.append(line.decode('utf-8').strip('\r\n'))
    return result
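

# Illustrative usage sketch (hedged): assumes a POSIX shell with echo on PATH.
def _example_subprocess_popen():
    lines = subprocess_popen("echo hello")
    print(lines)  # -> ['hello']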


def prepare_inputs(input_type, embedding_type, batch):
    '''
    assemble the model inputs dict from a batch tuple/list
    :param input_type: sequence/embedding/structure/sefn/ssfn
    :param embedding_type: embedding type ("vector"/"bos" means no attention mask)
    :param batch: tuple of tensors produced by the collator (labels last)
    :return: inputs dict, or None for an unknown input_type
    '''
    if input_type == "sequence":
        inputs = {
            "input_ids_a": batch[0],
            "attention_mask_a": batch[1],
            "token_type_ids_a": batch[2],
            "input_ids_b": batch[4],
            "attention_mask_b": batch[5],
            "token_type_ids_b": batch[6],
            "labels": batch[-1]
        }
    elif input_type == "embedding":
        if embedding_type not in ["vector", "bos"]:
            inputs = {
                "embedding_info_a": batch[0],
                "embedding_attention_mask_a": batch[1],
                "embedding_info_b": batch[2],
                "embedding_attention_mask_b": batch[3],
                "labels": batch[-1]
            }
        else:
            inputs = {
                "embedding_info_a": batch[0],
                "embedding_attention_mask_a": None,
                "embedding_info_b": batch[1],
                "embedding_attention_mask_b": None,
                "labels": batch[-1]
            }
    elif input_type == "structure":
        inputs = {
            "struct_input_ids_a": batch[0],
            "struct_contact_map_a": batch[1],
            "struct_input_ids_b": batch[2],
            "struct_contact_map_b": batch[3],
            "labels": batch[-1]
        }
    elif input_type == "sefn":
        if embedding_type not in ["vector", "bos"]:
            inputs = {
                "input_ids_a": batch[0],
                "attention_mask_a": batch[1],
                "token_type_ids_a": batch[2],
                "embedding_info_a": batch[4],
                "embedding_attention_mask_a": batch[5],
                "input_ids_b": batch[6],
                "attention_mask_b": batch[7],
                "token_type_ids_b": batch[8],
                "embedding_info_b": batch[10],
                "embedding_attention_mask_b": batch[11],
                "labels": batch[-1],
            }
        else:
            inputs = {
                "input_ids_a": batch[0],
                "attention_mask_a": batch[1],
                "token_type_ids_a": batch[2],
                "embedding_info_a": batch[4],
                "embedding_attention_mask_a": None,
                "input_ids_b": batch[5],
                "attention_mask_b": batch[6],
                "token_type_ids_b": batch[7],
                "embedding_info_b": batch[9],
                "embedding_attention_mask_b": None,
                "labels": batch[-1],
            }
    elif input_type == "ssfn":
        inputs = {
            "input_ids_a": batch[0],
            "attention_mask_a": batch[1],
            "token_type_ids_a": batch[2],
            "struct_input_ids_a": batch[4],
            "struct_contact_map_a": batch[5],
            "input_ids_b": batch[6],
            "attention_mask_b": batch[7],
            "token_type_ids_b": batch[8],
            "struct_input_ids_b": batch[10],
            "struct_contact_map_b": batch[11],
            "labels": batch[-1]
        }
    else:
        inputs = None
    return inputs
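

# Illustrative usage sketch (hedged): the batch layout below (8 tensors,
# labels last) mirrors the "sequence" indexing above; shapes are made up.
def _example_prepare_inputs():
    t = lambda: torch.zeros((2, 8), dtype=torch.long)
    batch = (t(), t(), t(), t(), t(), t(), t(), torch.zeros((2,), dtype=torch.long))
    inputs = prepare_inputs("sequence", None, batch)
    print(sorted(inputs.keys()))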


def gene_seq_replace_re(seq):
    '''
    restore a nucleic acid sequence from its digit encoding
    (1->A, 2->T, 3->C, 4->G, others->N)
    :param seq:
    :return:
    '''
    new_seq = ""
    for ch in seq:
        if ch == '1':
            new_seq += "A"
        elif ch == '2':
            new_seq += "T"
        elif ch == '3':
            new_seq += "C"
        elif ch == '4':
            new_seq += "G"
        else:
            new_seq += "N"
    return new_seq


def gene_seq_replace(seq):
    '''
    encode a nucleic acid sequence as digits
    (A->1, U/T->2, C->3, G->4, others->5)
    :param seq:
    :return:
    '''
    new_seq = ""
    for ch in seq:
        if ch in ["A", "a"]:
            new_seq += "1"
        elif ch in ["T", "U", "t", "u"]:
            new_seq += "2"
        elif ch in ["C", "c"]:
            new_seq += "3"
        elif ch in ["G", "g"]:
            new_seq += "4"
        else:
            new_seq += "5"
    return new_seq
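

# Illustrative round-trip sketch (hedged): note that U maps back to T and
# unknown characters collapse to 5/N, so the round trip is lossy in general.
def _example_gene_seq_replace_roundtrip():
    encoded = gene_seq_replace("ATCGN")
    print(encoded)                       # -> "12345"
    print(gene_seq_replace_re(encoded))  # -> "ATCGN"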


def get_labels(label_filepath, header=True):
    '''
    get labels from a file that may have a header and may have multiple columns
    (for multi-column files, the label name is everything after the first comma)
    :param label_filepath:
    :param header:
    :return:
    '''
    with open(label_filepath, "r") as fp:
        labels = []
        multi_cols = False
        cnt = 0
        for line in fp:
            line = line.strip()
            cnt += 1
            if cnt == 1 and (header or line == "label"):
                if line.find(",") > 0:
                    multi_cols = True
                continue
            if multi_cols:
                idx = line.find(",")
                if idx > 0:
                    label_name = line[idx + 1:].strip()
                else:
                    label_name = line
            else:
                label_name = line
            labels.append(label_name)
        return labels


def available_gpu_id():
    '''
    find the most available GPU id (lowest utilization); -1 if none available
    :return:
    '''
    pynvml.nvmlInit()
    if not torch.cuda.is_available():
        print("GPU not available")
        pynvml.nvmlShutdown()
        return -1

    device_count = pynvml.nvmlDeviceGetCount()
    max_available_gpu = -1
    max_available_rate = 0

    for i in range(device_count):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
        if utilization.gpu < 10 and max_available_rate < 100 - utilization.gpu:
            max_available_rate = 100 - utilization.gpu
            max_available_gpu = i

    if max_available_gpu > -1:
        print("Available GPU ID: %d, Free Rate: %0.2f%%" % (max_available_gpu, max_available_rate))
    else:
        print("No Available GPU!")

    pynvml.nvmlShutdown()
    return max_available_gpu


def eval_metrics(output_mode, truths, preds, threshold=0.5):
    '''
    eval metrics
    :param output_mode:
    :param truths:
    :param preds:
    :param threshold:
    :return:
    '''
    print("\ntruths size: ", truths.shape)
    print("\npreds size: ", preds.shape)
    if output_mode in ["multi-label", "multi_label"]:
        return metrics_multi_label(truths, preds, threshold=threshold)
    elif output_mode in ["multi-class", "multi_class"]:
        return metrics_multi_class(truths, preds)
    elif output_mode == "regression":
        return metrics_regression(truths, preds)
    elif output_mode in ["binary-class", "binary_class"]:
        return metrics_binary(truths, preds, threshold=threshold)
    else:
        raise Exception("Not support this output mode: %s" % output_mode)
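

# Illustrative usage sketch (hedged): assumes metrics_binary accepts 1-D numpy
# arrays of truths and predicted probabilities, as the import above suggests.
def _example_eval_metrics():
    truths = np.array([1, 0, 1, 1])
    preds = np.array([0.9, 0.2, 0.7, 0.4])
    print(eval_metrics("binary_class", truths, preds, threshold=0.5))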


def load_trained_model(model_config, args, model_class, model_dirpath):
    '''
    load a trained model, first via from_pretrained, falling back to a raw
    state-dict checkpoint (pytorch.pth) with any "module." prefix stripped
    '''
    print("load pretrained model: %s" % model_dirpath)
    try:
        model = model_class.from_pretrained(model_dirpath, args=args)
    except Exception as e:
        model = model_class(model_config, args=args)
        pretrained_net_dict = torch.load(os.path.join(model_dirpath, "pytorch.pth"),
                                         map_location=torch.device("cpu"))
        model_state_dict_keys = set()
        for key in model.state_dict():
            model_state_dict_keys.add(key)
        new_state_dict = OrderedDict()
        for k, v in pretrained_net_dict.items():
            if k.startswith("module."):
                # checkpoint saved from DataParallel/DDP: strip "module."
                name = k[7:]
            else:
                name = k
            if name in model_state_dict_keys:
                new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    return model


def clean_seq(protein_id, seq, return_rm_index=False):
    '''
    uppercase a protein sequence and remove invalid characters
    (anything outside A-Z, plus 'J'); optionally return the removed indexes
    '''
    seq = seq.upper()
    new_seq = ""
    has_invalid_char = False
    invalid_char_set = set()
    return_rm_index_set = set()
    for idx, ch in enumerate(seq):
        if 'A' <= ch <= 'Z' and ch not in ['J']:
            new_seq += ch
        else:
            invalid_char_set.add(ch)
            return_rm_index_set.add(idx)
            has_invalid_char = True
    if has_invalid_char:
        print("id: %s. Seq: %s" % (protein_id, seq))
        print("invalid char set:", invalid_char_set)
        print("return_rm_index:", return_rm_index_set)
    if return_rm_index:
        return new_seq, return_rm_index_set
    return new_seq
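

# Illustrative usage sketch (hedged): '*' and 'J' are dropped; their indexes
# are reported when return_rm_index=True.
def _example_clean_seq():
    new_seq, rm_idx = clean_seq("P_demo", "MKV*JLL", return_rm_index=True)
    print(new_seq)         # -> "MKVLL"
    print(sorted(rm_idx))  # -> [3, 4]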


def sample_size(data_dirpath):
    '''
    count samples in a file, or in all non-hidden files of a directory
    (.tsv/.csv files are assumed to carry a header, which is skipped)
    '''
    if os.path.isdir(data_dirpath):
        new_filepaths = []
        for filename in os.listdir(data_dirpath):
            if not filename.startswith("."):
                new_filepaths.append(os.path.join(data_dirpath, filename))
        filepaths = new_filepaths
    else:
        filepaths = [data_dirpath]
    total = 0
    for filepath in filepaths:
        header = filepath.endswith(".tsv") or filepath.endswith(".csv")
        print("sample_size filepath: %s" % filepath)
        for _ in file_reader(filepath, header=header, header_filter=True):
            total += 1
    return total


def writer_info_tb(tb_writer, logs, global_step, prefix=None):
    '''
    write info to tensorboard
    :param tb_writer:
    :param logs:
    :param global_step:
    :param prefix:
    :return:
    '''
    for key, value in logs.items():
        if isinstance(value, dict):
            # recurse into nested dicts, accumulating the key path as prefix
            writer_info_tb(tb_writer, value, global_step, prefix=prefix + "_" + key if prefix else key)
        elif not math.isnan(value) and not math.isinf(value):
            tb_writer.add_scalar(prefix + "_" + key if prefix else key, value, global_step)
        else:
            print("writer_info_tb NaN or Inf, Key-Value: %s=%s" % (key, value))


def get_lr(optimizer):
    '''
    get learning rate
    :param optimizer:
    :return:
    '''
    for p in optimizer.param_groups:
        if "lr" in p:
            return p["lr"]
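

# Illustrative usage sketch (hedged): any torch optimizer works.
def _example_get_lr():
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    print(get_lr(optimizer))  # -> 0.01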


def metrics_merge(results, all_results):
    '''
    merge (sum) a three-level nested metrics dict into the accumulator
    :param results:
    :param all_results:
    :return:
    '''
    for key1, val1 in results.items():
        if key1 not in all_results:
            all_results[key1] = {}
        for key2, val2 in val1.items():
            if key2 not in all_results[key1]:
                all_results[key1][key2] = {}
            for key3, val3 in val2.items():
                if key3 not in all_results[key1][key2]:
                    all_results[key1][key2][key3] = val3
                else:
                    all_results[key1][key2][key3] += val3
    return all_results
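

# Illustrative usage sketch (hedged): merging two nested metric dicts.
def _example_metrics_merge():
    acc = {}
    metrics_merge({"task": {"split": {"tp": 3}}}, acc)
    metrics_merge({"task": {"split": {"tp": 2}}}, acc)
    print(acc)  # -> {'task': {'split': {'tp': 5}}}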


def print_shape(item):
    '''
    print shape
    :param item:
    :return:
    '''
    if isinstance(item, dict):
        for key, value in item.items():
            print(key + ":")
            print_shape(value)
    elif isinstance(item, list):
        for idx, value in enumerate(item):
            print("idx: %d" % idx)
            print_shape(value)
    else:
        print("shape:", item.shape)


def process_outputs(output_mode, truth, pred, output_truth, output_pred, ignore_index, keep_seq=False):
    '''
    flatten a batch of truths/preds (masking ignore_index positions) and
    append them to the running outputs
    :return: (output_truth, output_pred) numpy arrays, or (None, None) if keep_seq
    '''
    if keep_seq:
        return None, None
    else:
        if output_mode in ["multi_class", "multi-class"]:
            cur_truth = truth.view(-1)
            cur_mask = cur_truth != ignore_index
            cur_pred = pred.view(-1, pred.shape[-1])
            cur_truth = cur_truth[cur_mask]
            cur_pred = cur_pred[cur_mask, :]
            sum_v = cur_mask.sum().item()
        elif output_mode in ["multi_label", "multi-label"]:
            cur_truth = truth.view(-1, truth.shape[-1])
            cur_pred = pred.view(-1, pred.shape[-1])
            sum_v = pred.shape[0]
        elif output_mode in ["binary_class", "binary-class"]:
            cur_truth = truth.view(-1)
            cur_mask = cur_truth != ignore_index
            cur_pred = pred.view(-1)
            cur_truth = cur_truth[cur_mask]
            cur_pred = cur_pred[cur_mask]
            sum_v = cur_mask.sum().item()
        elif output_mode in ["regression"]:
            cur_truth = truth.view(-1)
            cur_mask = cur_truth != ignore_index
            cur_pred = pred.view(-1)
            cur_truth = cur_truth[cur_mask]
            cur_pred = cur_pred[cur_mask]
            sum_v = cur_mask.sum().item()
        else:
            raise Exception("Not support output mode: %s" % output_mode)
        if sum_v > 0:
            cur_truth = cur_truth.detach().cpu().numpy()
            cur_pred = cur_pred.detach().cpu().numpy()
            if output_truth is None or output_pred is None:
                return cur_truth, cur_pred
            else:
                output_truth = np.append(output_truth, cur_truth, axis=0)
                output_pred = np.append(output_pred, cur_pred, axis=0)
                return output_truth, output_pred
        return truth, pred


def print_batch(value, key=None, debug_path=None, wfp=None, local_rank=-1):
    '''
    print a batch (and optionally dump its tensors to txt files)
    :param value: tensor, list of tensors, or (nested) dict of tensors
    :param key:
    :param debug_path: directory to dump txt files into (current dir if None)
    :param wfp: optional file handle to write the summary to instead of stdout
    :param local_rank:
    :return:
    '''
    if isinstance(value, list):
        for idx, v in enumerate(value):
            if wfp is not None:
                if v is not None:
                    wfp.write(str([torch.min(v), torch.min(torch.where(v == -100, 10000, v)), torch.max(v)]) + "\n")
                    wfp.write(str(v.shape) + "\n")
                else:
                    wfp.write("None\n")
                wfp.write("-" * 10 + "\n")
            else:
                if v is not None:
                    print([torch.min(v), torch.min(torch.where(v == -100, 10000, v)), torch.max(v)])
                    print(v.shape)
                else:
                    print("None")
                print("-" * 50)
            if v is not None:
                try:
                    arr = v.detach().cpu().numpy().astype(int)
                    if debug_path is not None:
                        if arr.ndim == 3:
                            for dim_1_idx in range(arr.shape[0]):
                                np.savetxt(os.path.join(debug_path, "%s_batch_%d.txt" % (key, dim_1_idx)), arr[dim_1_idx, :, :], fmt='%i', delimiter=",")
                        else:
                            np.savetxt(os.path.join(debug_path, "%d.txt" % idx), arr, fmt='%i', delimiter=",")
                    else:
                        # debug_path is None: write into the current directory
                        # (the original joined with debug_path here, which crashed)
                        if arr.ndim == 3:
                            for dim_1_idx in range(arr.shape[0]):
                                np.savetxt("%s_batch_%d.txt" % (key, dim_1_idx), arr[dim_1_idx, :, :], fmt='%i', delimiter=",")
                        else:
                            np.savetxt("%d.txt" % idx, arr, fmt='%i', delimiter=",")
                except Exception as e:
                    print(e)
    elif isinstance(value, dict):
        for k, v in value.items():
            if wfp is not None:
                wfp.write(str(k) + ":\n")
            else:
                print(str(k) + ':')
            print_batch(v, k, debug_path, wfp, local_rank)
    else:
        if wfp is not None:
            if value is not None:
                wfp.write(str([torch.min(value), torch.min(torch.where(value == -100, 10000, value)), torch.max(value)]) + "\n")
                wfp.write(str(value.shape) + "\n")
            else:
                wfp.write("None\n")
            wfp.write("-" * 10 + "\n")
        else:
            if value is not None:
                print([torch.min(value), torch.min(torch.where(value == -100, 10000, value)), torch.max(value)])
                print(value.shape)
            else:
                print("None")
            print("-" * 10)
        if value is not None:
            # prot_structure holds float coordinates; everything else is integral
            if key != "prot_structure":
                fmt = '%i'
                d_type = int
            else:
                fmt = '%0.4f'
                d_type = float
            try:
                arr = value.detach().cpu().numpy().astype(d_type)
                if debug_path is not None:
                    if arr.ndim == 3:
                        for dim_1_idx in range(arr.shape[0]):
                            np.savetxt(os.path.join(debug_path, "%s_batch_%d.txt" % (key, dim_1_idx)), arr[dim_1_idx, :, :], fmt=fmt, delimiter=",")
                    else:
                        np.savetxt(os.path.join(debug_path, "%s.txt" % key), arr, fmt=fmt, delimiter=",")
                else:
                    if arr.ndim == 3:
                        for dim_1_idx in range(arr.shape[0]):
                            np.savetxt("%s_batch_%d.txt" % (key, dim_1_idx), arr[dim_1_idx, :, :], fmt=fmt, delimiter=",")
                    else:
                        np.savetxt("%s.txt" % key, arr, fmt=fmt, delimiter=",")
            except Exception as e:
                print(e)


def gcd(x, y):
    '''
    greatest common divisor (Euclidean algorithm)
    :param x:
    :param y:
    :return:
    '''
    m = max(x, y)
    n = min(x, y)
    while m % n:
        m, n = n, m % n
    return n


def lcm(x, y):
    '''
    least common multiple, via x * y // gcd(x, y)
    :param x:
    :param y:
    :return:
    '''
    m = max(x, y)
    n = min(x, y)
    while m % n:
        m, n = n, m % n
    return x * y // n
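

# Illustrative sketch (hedged): the identity gcd(x, y) * lcm(x, y) == x * y.
def _example_gcd_lcm():
    x, y = 12, 18
    print(gcd(x, y), lcm(x, y))            # -> 6 36
    print(gcd(x, y) * lcm(x, y) == x * y)  # -> True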


def device_memory(gpu_id):
    if gpu_id is None or gpu_id < 0:
        return
    pynvml.nvmlInit()
    device_cnt = pynvml.nvmlDeviceGetCount()
    for idx in range(device_cnt):
        if gpu_id is not None and gpu_id != idx:
            continue
        handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        print(f"Device {idx}: {pynvml.nvmlDeviceGetName(handle)}")
        print(f"Total memory: {info.total / 1024**3:.8f} GB")
        print(f"Used memory: {info.used / 1024**3:.8f} GB")
        print(f"Free memory: {info.free / 1024**3:.8f} GB")
    pynvml.nvmlShutdown()


def calc_emb_filename_by_seq_id(seq_id, embedding_type):
    """
    derive the embedding filename from the seq_id
    (for UniProt-style ids "xx|ACCESSION|xx", the accession is used)
    :param seq_id:
    :param embedding_type:
    :return:
    """
    if seq_id[0] == ">":
        seq_id = seq_id[1:]
    if "|" in seq_id:
        strs = seq_id.split("|")
        if len(strs) > 1:
            emb_filename = embedding_type + "_" + strs[1].strip() + ".pt"
        else:
            emb_filename = embedding_type + "_" + seq_id.replace(" ", "").replace("/", "_") + ".pt"
    else:
        emb_filename = embedding_type + "_" + seq_id.replace(" ", "").replace("/", "_") + ".pt"
    return emb_filename
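

# Illustrative usage sketch (hedged): a made-up UniProt-style FASTA header.
def _example_calc_emb_filename():
    print(calc_emb_filename_by_seq_id(">sp|P12345|DEMO_HUMAN", "matrix"))
    # -> "matrix_P12345.pt"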


def download_file(url, local_filename):
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        dir_name = os.path.dirname(local_filename)
        # dir_name is empty for bare filenames; makedirs("") would raise
        if dir_name and not os.path.exists(dir_name):
            os.makedirs(dir_name)
        with open(local_filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
    return local_filename


def download_folder(base_url, file_names, local_dir):
    if not os.path.exists(local_dir):
        os.makedirs(local_dir)

    for file_name in file_names:
        file_url = f"{base_url}/{file_name}"
        local_filename = os.path.join(local_dir, file_name)
        download_file(file_url, local_filename)
        print(f"Downloaded {file_name}")


def download_trained_checkpoint_lucaone(
        llm_dir,
        llm_type="lucaone_gplm",
        llm_version="v2.0",
        llm_task_level="token_level,span_level,seq_level,structure_level",
        llm_time_str="20231125113045",
        llm_step="5600000",
        base_url="http://47.93.21.181/lucaone/TrainedCheckPoint"
):
    """
    download the trained checkpoint of LucaOne
    :param llm_dir:
    :param llm_type:
    :param llm_version:
    :param llm_task_level:
    :param llm_time_str:
    :param llm_step:
    :param base_url:
    :return:
    """
    print("------Download Trained LLM(LucaOne)------")
    try:
        logs_file_names = ["logs.txt"]
        models_file_names = ["config.json", "pytorch.pth", "training_args.bin", "tokenizer/alphabet.pkl"]
        logs_path = "logs/lucagplm/%s/%s/%s/%s" % (llm_version, llm_task_level, llm_type, llm_time_str)
        models_path = "models/lucagplm/%s/%s/%s/%s/checkpoint-step%s" % (llm_version, llm_task_level, llm_type, llm_time_str, llm_step)
        logs_local_dir = os.path.join(llm_dir, logs_path)
        exists = True
        for logs_file_name in logs_file_names:
            if not os.path.exists(os.path.join(logs_local_dir, logs_file_name)):
                exists = False
                break
        models_local_dir = os.path.join(llm_dir, models_path)
        if exists:
            for models_file_name in models_file_names:
                if not os.path.exists(os.path.join(models_local_dir, models_file_name)):
                    exists = False
                    break
        if not exists:
            print("*" * 20 + "Downloading" + "*" * 20)
            print("Downloading LucaOne TrainedCheckPoint: LucaOne-%s-%s-%s ..." % (llm_version, llm_time_str, llm_step))
            print("Wait a moment, please.")

            if not os.path.exists(logs_local_dir):
                os.makedirs(logs_local_dir)
            # build remote URLs with "/" (os.path.join is for local paths only)
            logs_base_url = "%s/%s" % (base_url, logs_path)
            download_folder(logs_base_url, logs_file_names, logs_local_dir)

            if not os.path.exists(models_local_dir):
                os.makedirs(models_local_dir)
            models_base_url = "%s/%s" % (base_url, models_path)
            download_folder(models_base_url, models_file_names, models_local_dir)
            print("LucaOne Download Succeed.")
            print("*" * 50)
    except Exception as e:
        print(e)
        print("Download automatically LucaOne Trained CheckPoint failed!")
        print("You can manually download 'logs/' and 'models/' into local directory: %s/ from %s" % (os.path.abspath(llm_dir), base_url))
        raise Exception(e)


def download_trained_checkpoint_downstream_tasks(
        save_dir="../",
        dataset_name=["CentralDogma", "GenusTax", "InfA", "ncRNAFam", "ncRPI", "PPI", "ProtLoc", "ProtStab", "SpeciesTax", "SupKTax"],
        dataset_type=["gene_protein", "gene", "gene_gene", "gene", "gene_protein", "protein", "protein", "protein", "gene", "gene"],
        task_type=["binary_class", "multi_class", "binary_class", "multi_class", "binary_class", "binary_class", "multi_class", "regression", "multi_class", "multi_class"],
        model_type=["lucappi2", "luca_base", "lucappi", "luca_base", "lucappi2", "lucappi", "luca_base", "luca_base", "luca_base", "luca_base"],
        input_type=["matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix"],
        time_str=["20240406173806", "20240412100337", "20240214105653", "20240414155526", "20240404105148", "20240216205421", "20240412140824", "20240404104215", "20240411144916", "20240212202328"],
        step=[64000, 24500, 9603, 1958484, 716380, 52304, 466005, 70371, 24000, 37000],
        base_url="http://47.93.21.181/lucaone/DownstreamTasksTrainedModels"
):
    """
    download the trained downstream-task models
    :param save_dir: local save directory
    :param dataset_name:
    :param dataset_type:
    :param task_type:
    :param model_type:
    :param input_type:
    :param time_str:
    :param step:
    :param base_url:
    :return:
    """
    assert len(dataset_name) == len(dataset_type) == len(task_type) == \
           len(model_type) == len(input_type) == len(time_str) == len(step)
    assert isinstance(dataset_name, list)
    assert isinstance(dataset_type, list)
    assert isinstance(task_type, list)
    assert isinstance(model_type, list)
    assert isinstance(input_type, list)
    assert isinstance(time_str, list)
    assert isinstance(step, list)
    download_succeed_task_num = 0
    print("------Download Trained Models------")
    for idx in range(len(dataset_name)):
        try:
            logs_file_names = ["logs.txt", "label.txt"]
            models_file_names = ["config.json", "pytorch_model.bin", "training_args.bin", "tokenizer/alphabet.pkl"]
            logs_path = "logs/%s/%s/%s/%s/%s/%s" % (dataset_name[idx], dataset_type[idx], task_type[idx], model_type[idx], input_type[idx], time_str[idx])
            models_path = "models/%s/%s/%s/%s/%s/%s/checkpoint-%s" % (dataset_name[idx], dataset_type[idx], task_type[idx], model_type[idx], input_type[idx], time_str[idx], str(step[idx]))
            logs_local_dir = os.path.join(save_dir, logs_path)
            exists = True
            for logs_file_name in logs_file_names:
                if not os.path.exists(os.path.join(logs_local_dir, logs_file_name)):
                    exists = False
                    break
            models_local_dir = os.path.join(save_dir, models_path)
            if exists:
                for models_file_name in models_file_names:
                    if not os.path.exists(os.path.join(models_local_dir, models_file_name)):
                        exists = False
                        break
            if not exists:
                print("*" * 20 + "Downloading" + "*" * 20)
                print("Downloading Downstream Task: %s TrainedCheckPoint: %s-%s-%s ..." % (dataset_name[idx], dataset_name[idx], time_str[idx], str(step[idx])))
                print("Wait a moment, please.")

                if not os.path.exists(logs_local_dir):
                    os.makedirs(logs_local_dir)
                # build remote URLs with "/" (os.path.join is for local paths only)
                logs_base_url = "%s/%s/%s" % (base_url, dataset_name[idx], logs_path)
                download_folder(logs_base_url, logs_file_names, logs_local_dir)

                if not os.path.exists(models_local_dir):
                    os.makedirs(models_local_dir)
                models_base_url = "%s/%s/%s" % (base_url, dataset_name[idx], models_path)
                download_folder(models_base_url, models_file_names, models_local_dir)
                print("Downstream Task: %s Trained Model Download Succeed." % dataset_name[idx])
                print("*" * 50)
                download_succeed_task_num += 1
        except Exception as e:
            print(e)
            print("Download automatically LucaDownstream Task: %s Trained CheckPoint failed!" % dataset_name[idx])
            print("You can manually download 'logs/' and 'models/' into local directory: %s/ from %s" % (os.path.abspath(save_dir), "%s/%s" % (base_url, dataset_name[idx])))
            raise Exception(e)
    print("%d Downstream Task Trained Model Download Succeed." % download_succeed_task_num)