LucaOne / utils.py
Yuanfei's picture
Upload LucaGPLM
96c0ca2 verified
#!/usr/bin/env python
# encoding: utf-8
import math
import os, csv, json
import io, textwrap, itertools
import subprocess
from Bio import SeqIO
import torch
import numpy as np
import sys, random
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import pynvml, requests
from collections import OrderedDict
# Module-wide matplotlib defaults for the plotting helpers in this file:
# a larger base font, and an ASCII hyphen instead of the Unicode minus sign
# (presumably so negative tick labels render with fonts lacking U+2212).
plt.rcParams.update({'font.size': 18, 'axes.unicode_minus': False})
from .file_operator import file_reader
from .multi_label_metrics import prob_2_pred, relevant_indexes, metrics_multi_label
from .metrics import metrics_multi_class, metrics_binary, metrics_regression
# Nucleotide symbols considered common: DNA/RNA bases plus N.
common_nucleotide_set = set("ATCGUN")
# Common amino acids: the 20 standard one-letter residue codes plus X.
# The letters O, U, Z, J and B are intentionally NOT included.
common_amino_acid_set = set("ACDEFGHIKLMNPQRSTVWXY")
def to_device(device, batch):
    '''
    Move every movable value of a (possibly nested) batch dict to ``device``.

    Nested dicts are walked recursively to any depth and rebuilt with the same keys.
    A value is moved with ``value.to(device)`` unless it is None, an int/str/float,
    or has no ``.to`` method; such values are copied through unchanged
    (previously dicts nested deeper than three levels, lists, etc. crashed).
    :param device: target device passed to ``.to``
    :param batch: dict whose values are tensors, scalars, None or nested dicts
    :return: (new_batch, sample_num) where sample_num is ``shape[0]`` of the last
             moved value in iteration order, or 0 if nothing was moved
    '''
    last_moved = None

    def _move(value):
        nonlocal last_moved
        if isinstance(value, dict):
            return {k: _move(v) for k, v in value.items()}
        if value is None or isinstance(value, (int, str, float)) or not hasattr(value, "to"):
            return value
        # Remember the (un-moved) tensor: only its shape is needed for sample_num.
        last_moved = value
        return value.to(device)

    new_batch = {key: _move(value) for key, value in batch.items()}
    sample_num = last_moved.shape[0] if last_moved is not None else 0
    return new_batch, sample_num
def get_parameter_number(model):
    '''
    Summarise parameter and buffer counts/sizes of a torch module.

    Counts are reported in units of 1024*1024 elements ("M") and sizes in units of
    1024*1024 bytes ("MB"). Totals include buffers; trainable figures cover only
    parameters with ``requires_grad``. Trainable values are rounded to 10 digits,
    all others to 2 digits, before ``%f`` formatting.
    :param model: torch.nn.Module
    :return: dict of formatted strings
    '''
    mega = 1024 * 1024

    param_sum = param_size = 0
    trainable_num = trainable_size = 0
    for tensor in model.parameters():
        count = tensor.nelement()
        nbytes = count * tensor.element_size()
        param_sum += count
        param_size += nbytes
        if tensor.requires_grad:
            trainable_num += count
            trainable_size += nbytes

    buffer_sum = buffer_size = 0
    for tensor in model.buffers():
        count = tensor.nelement()
        buffer_sum += count
        buffer_size += count * tensor.element_size()

    def scaled(pattern, value, ndigits=2):
        return pattern % round(value / mega, ndigits)

    return {
        'total_num': scaled("%fM", buffer_sum + param_sum),
        'total_size': scaled("%fMB", buffer_size + param_size),
        'param_sum': scaled("%fM", param_sum),
        'param_size': scaled("%fMB", param_size),
        'buffer_sum': scaled("%fM", buffer_sum),
        'buffer_size': scaled("%fMB", buffer_size),
        'trainable_num': scaled("%fM", trainable_num, 10),
        'trainable_size': scaled("%fMB", trainable_size, 10),
    }
def set_seed(args):
    '''
    Seed the Python, NumPy and PyTorch RNGs with ``args.seed``.
    The CUDA RNGs are seeded too when ``args.n_gpu > 0``.
    :param args: object with ``seed`` and ``n_gpu`` attributes
    '''
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)
def label_id_2_label_name(output_mode, label_list, prob, threshold=0.5):
    '''
    Convert predicted probabilities into label names.

    - multi-class: argmax over axis 1, one name per row.
    - binary-class: a 2-D ``prob`` is flattened first, then thresholded via ``prob_2_pred``.
    - multi-label: thresholded via ``prob_2_pred``; each row yields the list of names
      at the indexes returned by ``relevant_indexes``.
    :param output_mode: task type ("-" or "_" spelling accepted)
    :param label_list: index -> label name
    :param prob: numpy array of probabilities
    :param threshold: decision threshold for binary/multi-label
    :return: list of names (list of lists for multi-label)
    :raises KeyError: for an unknown output_mode
    '''
    if output_mode in ("multi-class", "multi_class"):
        return [label_list[idx] for idx in np.argmax(prob, axis=1)]
    if output_mode in ("binary-class", "binary_class"):
        flat_prob = prob.flatten(order="C") if prob.ndim == 2 else prob
        return [label_list[idx] for idx in prob_2_pred(flat_prob, threshold)]
    if output_mode in ("multi-label", "multi_label"):
        pred_index = relevant_indexes(prob_2_pred(prob, threshold))
        return [[label_list[idx] for idx in pred_index[row]] for row in range(prob.shape[0])]
    raise KeyError(output_mode)
def plot_bins(data, xlabel, ylabel, bins, filepath):
    '''
    Draw a histogram of ``data`` on a new 40x20-inch, 100-dpi figure.
    :param data: values to histogram
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    :param bins: number of bins, passed to ``plt.hist``
    :param filepath: image path to save to; when None the figure is shown instead
    :return: None
    '''
    plt.figure(figsize=(40, 20), dpi=100)
    plt.hist(data, bins=bins)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if filepath is not None:
        plt.savefig(filepath)
    else:
        plt.show()
    # Clear and close the current figure so repeated calls do not accumulate open figures.
    plt.clf()
    plt.close()
def plot_confusion_matrix_for_binary_class(targets, preds, cm=None, savepath=None):
    '''
    Plot a binary-classification confusion matrix with per-cell counts.

    :param targets: ground-truth labels (0/1); only used when ``cm`` is None
    :param preds: predicted labels (0/1, i.e. already thresholded -- sklearn's
                  confusion_matrix does not accept probabilities); only used when ``cm`` is None
    :param cm: precomputed confusion matrix; computed with labels=[0, 1] when None
    :param savepath: picture path to save to (dpi=100); shown interactively when falsy
    '''
    if cm is None:
        cm = confusion_matrix(targets, preds, labels=[0, 1])
    # plt.matshow creates its own figure, so no plt.figure() is opened beforehand
    # (a separately created figure would never be drawn into).
    plt.matshow(cm, cmap=plt.cm.Oranges)
    fig = plt.gcf()
    plt.colorbar()
    for x in range(len(cm)):
        for y in range(len(cm)):
            # Row x = true class, column y = predicted class.
            plt.annotate(cm[x, y], xy=(y, x), verticalalignment='center', horizontalalignment='center')
    plt.ylabel('True')
    plt.xlabel('Prediction')
    if savepath:
        plt.savefig(savepath, dpi=100)
    else:
        plt.show()
    # Close only the figure created here; other figures the caller has open are left alone.
    plt.close(fig)
def save_labels(filepath, label_list):
    '''
    Write labels to ``filepath``, one per line, preceded by a "label" header line.
    :param filepath: output path (overwritten)
    :param label_list: iterable of label strings
    :return: None
    '''
    with open(filepath, "w") as out:
        out.write("label\n")
        out.writelines(name + "\n" for name in label_list)
def load_labels(filepath, header=True):
    '''
    Read labels from a file with one label per line (each line stripped).
    The first line is dropped when ``header`` is True or when it is exactly "label".
    :param filepath: label file path
    :param header: whether the file has a header line
    :return: list of labels
    '''
    with open(filepath, "r") as rfp:
        label_list = [line.strip() for line in rfp]
    if label_list and (header or label_list[0] == "label"):
        return label_list[1:]
    return label_list
def load_vocab(vocab_path):
    '''
    Build a token -> id mapping from a file with one token per line.

    Each stripped line is assigned the current size of the mapping as its id.
    A repeated token is re-assigned, so its id becomes the mapping size at the
    moment it reappears.
    :param vocab_path: vocabulary file path
    :return: dict token -> id
    '''
    token_to_id = {}
    with open(vocab_path, "r") as rfp:
        for token in map(str.strip, rfp):
            token_to_id[token] = len(token_to_id)
    return token_to_id
def subprocess_popen(statement):
    '''
    Run a shell command and return its stdout lines.

    NOTE: ``statement`` is executed with ``shell=True``; only pass trusted input.
    :param statement: shell command line
    :return: list of stdout lines (UTF-8 decoded, surrounding CR/LF stripped),
             or False (after printing "fail.") if the command exits non-zero
    '''
    p = subprocess.Popen(statement, shell=True, stdout=subprocess.PIPE)
    # communicate() drains stdout while waiting, so a large output cannot deadlock
    # on a full pipe, and it also works when the process has already exited
    # (a `while p.poll() is None` loop would be skipped and return None).
    stdout, _ = p.communicate()
    if p.returncode != 0:
        print("fail.")
        return False
    # BytesIO.readlines() splits exactly like reading lines from the pipe.
    return [line.decode('utf-8').strip('\r\n') for line in io.BytesIO(stdout).readlines()]
def prepare_inputs(input_type, embedding_type, batch):
    '''
    Map a positional batch to the keyword inputs expected by the model.

    Each input type has a fixed (name, position) layout; a position of None means
    the entry is set to None (mask-free "vector"/"bos" embeddings). ``labels`` is
    always the last element of ``batch`` and is added last.
    :param input_type: "sequence", "embedding", "structure", "sefn" or "ssfn"
    :param embedding_type: embedding kind; "vector"/"bos" have no attention mask
    :param batch: indexable batch
    :return: dict of model inputs, or None for an unknown input_type
    '''
    if input_type == "sequence":
        layout = (("input_ids_a", 0), ("attention_mask_a", 1), ("token_type_ids_a", 2),
                  ("input_ids_b", 4), ("attention_mask_b", 5), ("token_type_ids_b", 6))
    elif input_type == "embedding":
        if embedding_type not in ["vector", "bos"]:
            layout = (("embedding_info_a", 0), ("embedding_attention_mask_a", 1),
                      ("embedding_info_b", 2), ("embedding_attention_mask_b", 3))
        else:
            layout = (("embedding_info_a", 0), ("embedding_attention_mask_a", None),
                      ("embedding_info_b", 1), ("embedding_attention_mask_b", None))
    elif input_type == "structure":
        layout = (("struct_input_ids_a", 0), ("struct_contact_map_a", 1),
                  ("struct_input_ids_b", 2), ("struct_contact_map_b", 3))
    elif input_type == "sefn":
        if embedding_type not in ["vector", "bos"]:
            layout = (("input_ids_a", 0), ("attention_mask_a", 1), ("token_type_ids_a", 2),
                      ("embedding_info_a", 4), ("embedding_attention_mask_a", 5),
                      ("input_ids_b", 6), ("attention_mask_b", 7), ("token_type_ids_b", 8),
                      ("embedding_info_b", 10), ("embedding_attention_mask_b", 11))
        else:
            layout = (("input_ids_a", 0), ("attention_mask_a", 1), ("token_type_ids_a", 2),
                      ("embedding_info_a", 4), ("embedding_attention_mask_a", None),
                      ("input_ids_b", 5), ("attention_mask_b", 6), ("token_type_ids_b", 7),
                      ("embedding_info_b", 9), ("embedding_attention_mask_b", None))
    elif input_type == "ssfn":
        layout = (("input_ids_a", 0), ("attention_mask_a", 1), ("token_type_ids_a", 2),
                  ("struct_input_ids_a", 4), ("struct_contact_map_a", 5),
                  ("input_ids_b", 6), ("attention_mask_b", 7), ("token_type_ids_b", 8),
                  ("struct_input_ids_b", 10), ("struct_contact_map_b", 11))
    else:
        return None
    inputs = {name: (None if pos is None else batch[pos]) for name, pos in layout}
    inputs["labels"] = batch[-1]
    return inputs
def gene_seq_replace_re(seq):
    '''
    Restore a nucleotide sequence from the digit encoding produced by gene_seq_replace.
    1->A, 2->T, 3->C, 4->G; any other character becomes N (so U is restored as T).
    :param seq: digit-encoded sequence
    :return: nucleotide string
    '''
    digit_to_base = {'1': "A", '2': "T", '3': "C", '4': "G"}
    return "".join(digit_to_base.get(ch, "N") for ch in seq)
def gene_seq_replace(seq):
    '''
    Encode a nucleotide sequence as digits (case-insensitive):
    A->1, T/U->2, C->3, G->4, anything else (e.g. N)->5.
    :param seq: nucleotide sequence
    :return: digit-encoded string
    '''
    base_to_digit = {
        "A": "1", "a": "1",
        "T": "2", "U": "2", "t": "2", "u": "2",
        "C": "3", "c": "3",
        "G": "4", "g": "4",
    }
    return "".join(base_to_digit.get(ch, "5") for ch in seq)
def get_labels(label_filepath, header=True):
    '''
    Read label names from a text file, one per line (each line stripped).

    The first line is skipped when ``header`` is True or the line is exactly "label".
    If that skipped line contains a comma after position 0, the file is treated as
    multi-column: each label is the (stripped) text after the first comma of its
    line, or the whole line when it has no such comma.
    :param label_filepath: label file path
    :param header: whether the file has a header line
    :return: list of label names
    '''
    labels = []
    multi_cols = False
    with open(label_filepath, "r") as fp:
        for line_no, raw_line in enumerate(fp, start=1):
            text = raw_line.strip()
            if line_no == 1 and (header or text == "label"):
                multi_cols = text.find(",") > 0
                continue
            comma_pos = text.find(",") if multi_cols else -1
            labels.append(text[comma_pos + 1:].strip() if comma_pos > 0 else text)
    return labels
def available_gpu_id():
    '''
    Compute the id of an available GPU.

    Among GPUs whose current utilisation is below 10%, the least utilised one is
    chosen (ties keep the lower index).
    :return: GPU index, or -1 when CUDA is unavailable or no GPU qualifies
    '''
    # Check CUDA before touching NVML: on CPU-only hosts (possibly without an NVIDIA
    # driver) nvmlInit may fail, and returning early must not leave NVML initialised.
    if not torch.cuda.is_available():
        print("GPU not available")
        return -1
    pynvml.nvmlInit()
    try:
        device_count = pynvml.nvmlDeviceGetCount()
        max_available_gpu = -1
        max_available_rate = 0
        for i in range(device_count):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
            free_rate = 100 - utilization.gpu
            # A GPU below 10% utilisation is considered idle; keep the most idle one.
            if utilization.gpu < 10 and max_available_rate < free_rate:
                max_available_rate = free_rate
                max_available_gpu = i
        if max_available_gpu > -1:
            print("Available GPU ID: %d, Free Rate: %0.2f%%" % (max_available_gpu, max_available_rate))
        else:
            print("No Available GPU!")
    finally:
        # Always shut NVML down, even if a query above raises.
        pynvml.nvmlShutdown()
    return max_available_gpu
def eval_metrics(output_mode, truths, preds, threshold=0.5):
    '''
    Print the shapes of ``truths``/``preds`` and compute metrics for the task type.
    :param output_mode: "regression", binary/multi-class or multi-label ("-" or "_" spelling)
    :param truths: ground truth array
    :param preds: prediction array
    :param threshold: decision threshold for binary and multi-label tasks
    :return: metrics as returned by the matching metrics_* helper
    :raises Exception: for an unsupported output_mode
    '''
    print("\ntruths size: ", truths.shape)
    print("\npreds size: ", preds.shape)
    if output_mode == "regression":
        return metrics_regression(truths, preds)
    if output_mode in ("binary-class", "binary_class"):
        return metrics_binary(truths, preds, threshold=threshold)
    if output_mode in ("multi-class", "multi_class"):
        return metrics_multi_class(truths, preds)
    if output_mode in ("multi-label", "multi_label"):
        return metrics_multi_label(truths, preds, threshold=threshold)
    raise Exception("Not Support this output mode: %s" % output_mode)
def load_trained_model(model_config, args, model_class, model_dirpath):
    '''
    Load a trained model from ``model_dirpath``.

    First tries ``model_class.from_pretrained(model_dirpath, args=args)``. If that
    raises, a fresh ``model_class(model_config, args=args)`` is built and the state
    dict in ``<model_dirpath>/pytorch.pth`` is loaded (on CPU). A leading "module."
    prefix is stripped from checkpoint keys (presumably from a DataParallel/DDP-wrapped
    save), and keys the model does not have are ignored.
    :param model_config: config passed to the model constructor in the fallback path
    :param args: run arguments passed to the model
    :param model_class: model class providing ``from_pretrained`` and a constructor
    :param model_dirpath: checkpoint directory; when None, ``args.model_dirpath`` is used
                          for the fallback state-dict load
    :return: the loaded model
    '''
    print("load pretrained model: %s" % model_dirpath)
    try:
        model = model_class.from_pretrained(model_dirpath, args=args)
    except Exception:
        model = model_class(model_config, args=args)
        # Load from the directory we were asked for (not an unrelated args field).
        checkpoint_dir = model_dirpath if model_dirpath is not None else args.model_dirpath
        pretrained_net_dict = torch.load(os.path.join(checkpoint_dir, "pytorch.pth"),
                                         map_location=torch.device("cpu"))
        model_state_dict_keys = set(model.state_dict().keys())
        new_state_dict = OrderedDict()
        for k, v in pretrained_net_dict.items():
            name = k[len("module."):] if k.startswith("module.") else k
            if name in model_state_dict_keys:
                new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    return model
def clean_seq(protein_id, seq, return_rm_index=False):
    '''
    Upper-case ``seq`` and drop every character that is not an ASCII letter A-Z,
    as well as the letter 'J'. Prints a diagnostic when anything was removed.
    :param protein_id: identifier, only used in the diagnostic print
    :param seq: raw sequence
    :param return_rm_index: also return the set of removed positions
                            (indexes into the upper-cased sequence)
    :return: cleaned sequence, or (cleaned sequence, removed index set)
    '''
    upper_seq = seq.upper()
    kept = []
    removed_chars = set()
    removed_positions = set()
    for pos, ch in enumerate(upper_seq):
        if ch == 'J' or not ('A' <= ch <= 'Z'):
            removed_chars.add(ch)
            removed_positions.add(pos)
        else:
            kept.append(ch)
    if removed_positions:
        print("id: %s. Seq: %s" % (protein_id, upper_seq))
        print("invalid char set:", removed_chars)
        print("return_rm_index:", removed_positions)
    cleaned = "".join(kept)
    if return_rm_index:
        return cleaned, removed_positions
    return cleaned
def sample_size(data_dirpath):
    '''
    Count the records yielded by ``file_reader`` for a file, or for every
    non-hidden entry of a directory.
    Files ending in .tsv or .csv are read with ``header=True``; all files are read
    with ``header_filter=True`` (semantics defined by file_reader).
    :param data_dirpath: data file or directory
    :return: total record count
    '''
    if os.path.isdir(data_dirpath):
        filepaths = [
            os.path.join(data_dirpath, name)
            for name in os.listdir(data_dirpath)
            if not name.startswith(".")
        ]
    else:
        filepaths = [data_dirpath]
    total = 0
    for path in filepaths:
        has_header = path.endswith((".tsv", ".csv"))
        print("sample_size filepath: %s" % path)
        total += sum(1 for _ in file_reader(path, header=has_header, header_filter=True))
    return total
def writer_info_tb(tb_writer, logs, global_step, prefix=None):
    '''
    Write scalar metrics to TensorBoard.

    Nested dicts are flattened into tags joined by "_" along the full key path
    (e.g. {"dev": {"x": {"f1": v}}} -> "dev_x_f1"), so equally named inner keys
    under different parents do not collide. NaN/Inf values are skipped with a message.
    :param tb_writer: writer with an ``add_scalar(tag, value, step)`` method
    :param logs: dict of numeric values or nested dicts
    :param global_step: step passed to add_scalar
    :param prefix: optional tag prefix
    :return: None
    '''
    for key, value in logs.items():
        tag = prefix + "_" + key if prefix else key
        if isinstance(value, dict):
            # Carry the whole path down instead of only the current key.
            writer_info_tb(tb_writer, value, global_step, prefix=tag)
        elif not math.isnan(value) and not math.isinf(value):
            tb_writer.add_scalar(tag, value, global_step)
        else:
            print("writer_info_tb NaN or Inf, Key-Value: %s=%s" % (key, value))
def get_lr(optimizer):
    '''
    Return the learning rate of the first optimizer param group that defines "lr".
    :param optimizer: object with a ``param_groups`` list of dicts
    :return: the learning rate, or None when no group has "lr"
    '''
    return next((group["lr"] for group in optimizer.param_groups if "lr" in group), None)
def metrics_merge(results, all_results):
    '''
    Accumulate a three-level metrics dict into ``all_results`` in place.
    Leaf values are summed with ``+=``; missing keys are created.
    :param results: dict[k1][k2][k3] -> value
    :param all_results: accumulator with the same layout (mutated)
    :return: all_results
    '''
    for k1, level2 in results.items():
        dst2 = all_results.setdefault(k1, {})
        for k2, level3 in level2.items():
            dst3 = dst2.setdefault(k2, {})
            for k3, value in level3.items():
                if k3 in dst3:
                    dst3[k3] += value
                else:
                    dst3[k3] = value
    return all_results
def print_shape(item):
    '''
    Recursively print the ``.shape`` of every leaf in nested dicts/lists.
    Dict keys are printed as "<key>:" (keys must be str), list positions as "idx: <i>".
    :param item: dict, list, or object with a ``shape`` attribute
    :return: None
    '''
    if isinstance(item, list):
        for position, element in enumerate(item):
            print("idx: %d" % position)
            print_shape(element)
        return
    if isinstance(item, dict):
        for name, element in item.items():
            print(name + ":")
            print_shape(element)
        return
    print("shape:", item.shape)
def process_outputs(output_mode, truth, pred, output_truth, output_pred, ignore_index, keep_seq=False):
    '''
    Flatten one batch of truths/predictions, drop ignored positions, and append
    them (as numpy arrays) to the running outputs.

    - multi-class: truth flattened to 1-D, pred to (N, C); rows with truth == ignore_index dropped.
    - multi-label: truth/pred reshaped to (N, C); nothing is masked.
    - binary-class / regression: both flattened to 1-D; positions with truth == ignore_index dropped.
    :param output_mode: task type ("-" or "_" spelling accepted)
    :param truth: torch tensor of ground truth for this batch
    :param pred: torch tensor of predictions for this batch
    :param output_truth: accumulated numpy truths so far, or None
    :param output_pred: accumulated numpy predictions so far, or None
    :param ignore_index: truth value marking positions to skip
    :param keep_seq: not implemented; returns (None, None)
    :return: (output_truth, output_pred) including this batch; unchanged when the
             batch contributes nothing
    :raises Exception: for an unknown output_mode
    '''
    if keep_seq:
        # to do: sequence-preserving accumulation is not implemented.
        return None, None
    if output_mode in ["multi_class", "multi-class"]:
        cur_truth = truth.view(-1)
        cur_mask = cur_truth != ignore_index
        cur_pred = pred.view(-1, pred.shape[-1])
        cur_truth = cur_truth[cur_mask]
        cur_pred = cur_pred[cur_mask, :]
        sum_v = cur_mask.sum().item()
    elif output_mode in ["multi_label", "multi-label"]:
        cur_truth = truth.view(-1, truth.shape[-1])
        cur_pred = pred.view(-1, pred.shape[-1])
        sum_v = pred.shape[0]
    elif output_mode in ["binary_class", "binary-class", "regression"]:
        cur_truth = truth.view(-1)
        cur_mask = cur_truth != ignore_index
        cur_truth = cur_truth[cur_mask]
        cur_pred = pred.view(-1)[cur_mask]
        sum_v = cur_mask.sum().item()
    else:
        raise Exception("not output mode: %s" % output_mode)
    if sum_v <= 0:
        # Nothing usable in this batch: keep what was accumulated so far. Returning the
        # raw batch tensors here would discard the accumulation and mix in ignored positions.
        return output_truth, output_pred
    cur_truth = cur_truth.detach().cpu().numpy()
    cur_pred = cur_pred.detach().cpu().numpy()
    if output_truth is None or output_pred is None:
        return cur_truth, cur_pred
    output_truth = np.append(output_truth, cur_truth, axis=0)
    output_pred = np.append(output_pred, cur_pred, axis=0)
    return output_truth, output_pred
def _print_batch_stats(tensor, wfp, stdout_sep_len):
    '''
    Report [min, min after replacing -100 by 10000, max] and the shape of ``tensor``
    (or "None"), to ``wfp`` when given, else to stdout. -100 is presumably the
    ignore/padding label, hence the second "min" that excludes it.
    '''
    if wfp is not None:
        if tensor is not None:
            wfp.write(str([torch.min(tensor), torch.min(torch.where(tensor == -100, 10000, tensor)), torch.max(tensor)]) + "\n")
            wfp.write(str(tensor.shape) + "\n")
        else:
            wfp.write("None\n")
        wfp.write("-" * 10 + "\n")
    else:
        if tensor is not None:
            print([torch.min(tensor), torch.min(torch.where(tensor == -100, 10000, tensor)), torch.max(tensor)])
            print(tensor.shape)
        else:
            print("None")
        print("-" * stdout_sep_len)


def _dump_batch_tensor(tensor, key, flat_filename, debug_path, fmt, d_type):
    '''
    Save ``tensor`` as comma-separated text with np.savetxt. A 3-D tensor is written
    one file per leading index ("<key>_batch_<i>.txt"); anything else goes to
    ``flat_filename``. Files go under ``debug_path`` when given, else the working directory.
    '''
    array = tensor.detach().cpu().numpy().astype(d_type)

    def _target(filename):
        return filename if debug_path is None else os.path.join(debug_path, filename)

    if array.ndim == 3:
        for dim_1_idx in range(array.shape[0]):
            np.savetxt(_target("%s_batch_%d.txt" % (key, dim_1_idx)), array[dim_1_idx, :, :], fmt=fmt, delimiter=",")
    else:
        np.savetxt(_target(flat_filename), array, fmt=fmt, delimiter=",")


def print_batch(value, key=None, debug_path=None, wfp=None, local_rank=-1):
    '''
    Debug helper: print (or write to ``wfp``) summary stats of every tensor in a
    nested batch, and dump each tensor to text files.

    - dict: prints "<key>:" then recurses with that key.
    - list: stats for each element; elements are dumped as ints, 2-D ones to "<idx>.txt".
    - tensor: stats, then dumped to "<key>.txt" (as float "%0.4f" when key is
      "prot_structure", otherwise as int).
    Dump errors are printed and swallowed.
    :param value: tensor, list of tensors/None, or dict of those
    :param key: name of the current value (used in filenames)
    :param debug_path: directory for dumped files; working directory when None
    :param wfp: optional writable text file for the stats instead of stdout
    :param local_rank: unused
    :return: None
    '''
    if isinstance(value, list):
        for idx, v in enumerate(value):
            _print_batch_stats(v, wfp, 50)
            if v is not None:
                try:
                    # 3-D elements without debug_path are now written to the working
                    # directory (os.path.join(None, ...) used to raise here).
                    _dump_batch_tensor(v, key, "%d.txt" % idx, debug_path, '%i', int)
                except Exception as e:
                    print(e)
    elif isinstance(value, dict):
        for item_key, item_value in value.items():
            if wfp is not None:
                wfp.write(str(item_key) + ":\n")
            else:
                print(str(item_key) + ':')
            print_batch(item_value, item_key, debug_path, wfp, local_rank)
    else:
        _print_batch_stats(value, wfp, 10)
        if value is not None:
            if key != "prot_structure":
                fmt, d_type = '%i', int
            else:
                fmt, d_type = '%0.4f', float
            try:
                _dump_batch_tensor(value, key, "%s.txt" % key, debug_path, fmt, d_type)
            except Exception as e:
                print(e)
def gcd(x, y):
    '''
    Greatest common divisor via the Euclidean algorithm.
    A zero argument is handled: gcd(x, 0) == x and gcd(0, 0) == 0
    (the previous loop divided by zero here). For nonzero inputs the result is
    unchanged, including its sign for negative inputs.
    :param x: number
    :param y: number
    :return: the greatest common divisor
    '''
    m = max(x, y)
    n = min(x, y)
    while n:
        m, n = n, m % n
    return m
def lcm(x, y):
    '''
    Least common multiple, computed as x*y // gcd(x, y) with an inline Euclidean gcd.
    Returns 0 when either argument is 0 (the previous loop divided by zero here).
    :param x: number
    :param y: number
    :return: the least common multiple
    '''
    m = max(x, y)
    n = min(x, y)
    while n:
        m, n = n, m % n
    if m == 0:
        # Both arguments are zero.
        return 0
    return x * y // m
def device_memory(gpu_id):
    '''
    Print the name and total/used/free memory (bytes / 1024**3) of GPU ``gpu_id`` via NVML.
    :param gpu_id: device index; None or a negative value prints nothing
    :return: None
    '''
    if gpu_id is None or gpu_id < 0:
        return
    pynvml.nvmlInit()
    try:
        for idx in range(pynvml.nvmlDeviceGetCount()):
            if idx != gpu_id:
                continue
            handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            print(f"Device {idx}: {pynvml.nvmlDeviceGetName(handle)}")
            print(f"Total memory: {info.total / 1024**3:.8f} GB")
            print(f"Used memory: {info.used / 1024**3:.8f} GB")
            print(f"Free memory: {info.free / 1024**3:.8f} GB")
    finally:
        # Always release NVML, even if a query above raises.
        pynvml.nvmlShutdown()
def calc_emb_filename_by_seq_id(seq_id, embedding_type):
    """
    Derive the embedding filename from a sequence id.

    A leading ">" is dropped. If the id contains "|", its second "|"-separated
    field (stripped) is used; otherwise spaces are removed and "/" becomes "_".
    :param seq_id: sequence id, optionally starting with ">"
    :param embedding_type: filename prefix
    :return: "<embedding_type>_<id>.pt"
    """
    # startswith() also copes with an empty id, where seq_id[0] raised IndexError.
    if seq_id.startswith(">"):
        seq_id = seq_id[1:]
    if "|" in seq_id:
        # split("|") always yields at least two fields here, so index 1 exists.
        # NOTE(review): this field is only stripped, not sanitised like the branch below,
        # so a "/" in it yields a sub-directory path; kept so existing cached filenames still match.
        return embedding_type + "_" + seq_id.split("|")[1].strip() + ".pt"
    return embedding_type + "_" + seq_id.replace(" ", "").replace("/", "_") + ".pt"
def download_file(url, local_filename):
    '''
    Stream ``url`` into ``local_filename`` in 8192-byte chunks, creating the parent
    directory when needed.
    :param url: URL to fetch
    :param local_filename: destination path
    :return: local_filename
    :raises requests.HTTPError: on an error HTTP status (via raise_for_status)
    '''
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        dir_name = os.path.dirname(local_filename)
        # A bare filename has dirname "" (os.makedirs("") raises); exist_ok avoids
        # failing when another process creates the directory concurrently.
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        with open(local_filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
    return local_filename
def download_folder(base_url, file_names, local_dir):
    '''
    Download each ``base_url/<file_name>`` into ``local_dir/<file_name>``.
    :param base_url: URL prefix (no trailing slash expected)
    :param file_names: relative file names (may contain sub-directories)
    :param local_dir: destination directory; "" means the working directory
    :return: None
    '''
    # os.makedirs("") raises, and exist_ok avoids an exists-then-create race.
    if local_dir:
        os.makedirs(local_dir, exist_ok=True)
    for file_name in file_names:
        file_url = f"{base_url}/{file_name}"
        local_filename = os.path.join(local_dir, file_name)
        download_file(file_url, local_filename)
        print(f"Downloaded {file_name}")
def download_trained_checkpoint_lucaone(
        llm_dir,
        llm_type="lucaone_gplm",
        llm_version="v2.0",
        llm_task_level="token_level,span_level,seq_level,structure_level",
        llm_time_str="20231125113045",
        llm_step="5600000",
        base_url="http://47.93.21.181/lucaone/TrainedCheckPoint"
):
    """
    Download the trained LucaOne checkpoint (logs + model files) into ``llm_dir``,
    unless every expected file already exists locally.
    :param llm_dir: local root directory
    :param llm_type: model type path component
    :param llm_version: version path component
    :param llm_task_level: task-level path component
    :param llm_time_str: run timestamp path component
    :param llm_step: checkpoint step
    :param base_url: remote root URL of the checkpoints
    :return: None
    :raises Exception: wrapping any failure, after printing manual-download hints
    """
    print("------Download Trained LLM(LucaOne)------")
    # URLs always use "/"; os.path.join would insert backslashes on Windows.
    url_root = base_url.rstrip("/")
    try:
        logs_file_names = ["logs.txt"]
        models_file_names = ["config.json", "pytorch.pth", "training_args.bin", "tokenizer/alphabet.pkl"]
        logs_path = "logs/lucagplm/%s/%s/%s/%s" % (llm_version, llm_task_level, llm_type, llm_time_str)
        models_path = "models/lucagplm/%s/%s/%s/%s/checkpoint-step%s" % (llm_version, llm_task_level, llm_type, llm_time_str, llm_step)
        logs_local_dir = os.path.join(llm_dir, logs_path)
        models_local_dir = os.path.join(llm_dir, models_path)
        exists = all(os.path.exists(os.path.join(logs_local_dir, name)) for name in logs_file_names) \
            and all(os.path.exists(os.path.join(models_local_dir, name)) for name in models_file_names)
        if not exists:
            print("*" * 20 + "Downloading" + "*" * 20)
            print("Downloading LucaOne TrainedCheckPoint: LucaOne-%s-%s-%s ..." % (llm_version, llm_time_str, llm_step))
            print("Wait a moment, please.")
            # download logs
            os.makedirs(logs_local_dir, exist_ok=True)
            download_folder(url_root + "/" + logs_path, logs_file_names, logs_local_dir)
            # download models
            os.makedirs(models_local_dir, exist_ok=True)
            download_folder(url_root + "/" + models_path, models_file_names, models_local_dir)
            print("LucaOne Download Succeed.")
            print("*" * 50)
    except Exception as e:
        print(e)
        print("Download automatically LucaOne Trained CheckPoint failed!")
        # base_url already ends in "TrainedCheckPoint"; the logs/ and models/ trees live directly under it.
        print("You can manually download 'logs/' and 'models/' into local directory: %s/ from %s" % (os.path.abspath(llm_dir), url_root + "/"))
        raise Exception(e) from e
def download_trained_checkpoint_downstream_tasks(
        save_dir="../",
        dataset_name=None,
        dataset_type=None,
        task_type=None,
        model_type=None,
        input_type=None,
        time_str=None,
        step=None,
        base_url="http://47.93.21.181/lucaone/DownstreamTasksTrainedModels"
):
    """
    Download trained downstream-task models (logs + model files), skipping tasks
    whose files all exist locally. The i-th entries of the list arguments describe task i.
    Omitted (None) list arguments default to the ten published LucaOne downstream tasks;
    defaults are built per call, so no mutable default list is shared between calls.
    :param save_dir: local save directory
    :param dataset_name: dataset names
    :param dataset_type: dataset types
    :param task_type: task types
    :param model_type: model types
    :param input_type: input types
    :param time_str: run timestamps
    :param step: checkpoint steps
    :param base_url: remote root URL
    :return: None
    :raises AssertionError: when the lists have different lengths or are not lists
    :raises Exception: wrapping a download failure, after printing manual-download hints
    """
    if dataset_name is None:
        dataset_name = ["CentralDogma", "GenusTax", "InfA", "ncRNAFam", "ncRPI", "PPI", "ProtLoc", "ProtStab", "SpeciesTax", "SupKTax"]
    if dataset_type is None:
        dataset_type = ["gene_protein", "gene", "gene_gene", "gene", "gene_protein", "protein", "protein", "protein", "gene", "gene"]
    if task_type is None:
        task_type = ["binary_class", "multi_class", "binary_class", "multi_class", "binary_class", "binary_class", "multi_class", "regression", "multi_class", "multi_class"]
    if model_type is None:
        model_type = ["lucappi2", "luca_base", "lucappi", "luca_base", "lucappi2", "lucappi", "luca_base", "luca_base", "luca_base", "luca_base"]
    if input_type is None:
        input_type = ["matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix"]
    if time_str is None:
        time_str = ["20240406173806", "20240412100337", "20240214105653", "20240414155526", "20240404105148", "20240216205421", "20240412140824", "20240404104215", "20240411144916", "20240212202328"]
    if step is None:
        step = [64000, 24500, 9603, 1958484, 716380, 52304, 466005, 70371, 24000, 37000]
    assert len(dataset_name) == len(dataset_type) == len(task_type) == \
        len(model_type) == len(input_type) == len(time_str) == len(step)
    assert isinstance(dataset_name, list)
    assert isinstance(dataset_type, list)
    assert isinstance(task_type, list)
    assert isinstance(model_type, list)
    assert isinstance(input_type, list)
    assert isinstance(time_str, list)
    assert isinstance(step, list)
    download_succeed_task_num = 0
    print("------Download Trained Models------")
    # URLs always use "/"; os.path.join would insert backslashes on Windows.
    url_root = base_url.rstrip("/")
    logs_file_names = ["logs.txt", "label.txt"]
    models_file_names = ["config.json", "pytorch_model.bin", "training_args.bin", "tokenizer/alphabet.pkl"]
    for idx in range(len(dataset_name)):
        task_url = url_root + "/" + dataset_name[idx]
        try:
            logs_path = "logs/%s/%s/%s/%s/%s/%s" % (dataset_name[idx], dataset_type[idx], task_type[idx], model_type[idx], input_type[idx], time_str[idx])
            models_path = "models/%s/%s/%s/%s/%s/%s/checkpoint-%s" % (dataset_name[idx], dataset_type[idx], task_type[idx], model_type[idx], input_type[idx], time_str[idx], str(step[idx]))
            logs_local_dir = os.path.join(save_dir, logs_path)
            models_local_dir = os.path.join(save_dir, models_path)
            exists = all(os.path.exists(os.path.join(logs_local_dir, name)) for name in logs_file_names) \
                and all(os.path.exists(os.path.join(models_local_dir, name)) for name in models_file_names)
            if not exists:
                print("*" * 20 + "Downloading" + "*" * 20)
                print("Downloading Downstream Task: %s TrainedCheckPoint: %s-%s-%s ..." % (dataset_name[idx], dataset_name[idx], time_str[idx], str(step[idx])))
                print("Wait a moment, please.")
                # download logs
                os.makedirs(logs_local_dir, exist_ok=True)
                download_folder(task_url + "/" + logs_path, logs_file_names, logs_local_dir)
                # download models
                os.makedirs(models_local_dir, exist_ok=True)
                download_folder(task_url + "/" + models_path, models_file_names, models_local_dir)
                print("Downstream Task: %s Trained Model Download Succeed." % dataset_name[idx])
                print("*" * 50)
                download_succeed_task_num += 1
        except Exception as e:
            print(e)
            print("Download automatically LucaDownstream Task: %s Trained CheckPoint failed!" % dataset_name[idx])
            print("You can manually download 'logs/' and 'models/' into local directory: %s/ from %s" % (os.path.abspath(save_dir), task_url))
            raise Exception(e) from e
    print("%d Downstream Task Trained Model Download Succeed." % download_succeed_task_num)