Spaces:
Runtime error
Runtime error
File size: 2,584 Bytes
6d06448 e86736e 6d06448 e86736e 6d06448 e86736e 6d06448 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
# PSEUDOCODE UNTIL WE GET DATA
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder
from datasets import load_dataset
import pickle
def _top_genera(model, features, sample_keys, top_k=3):
    """Map each sample key to its ``top_k`` most probable classes.

    Args:
        model: Fitted classifier exposing ``predict_proba`` and ``classes_``.
        features: Feature rows, one per sample, in the layout ``model`` expects.
        sample_keys: Iterable of dictionary keys aligned row-by-row with
            ``features``.
        top_k: Number of candidate classes to keep per sample.

    Returns:
        Dict mapping each key to ``(classes, probabilities)``. Both arrays
        hold up to ``top_k`` entries ordered from least to most probable.
        When a key occurs more than once, the last row wins.
    """
    if len(features) == 0:
        return {}
    probs = np.asarray(model.predict_proba(features))
    # argsort is ascending, so the last ``top_k`` columns are the best ones.
    top_idx = np.argsort(probs, axis=1)[:, -top_k:]
    top_classes = np.asarray(model.classes_)[top_idx]
    # Report the probabilities themselves, not their column positions.
    top_probs = np.take_along_axis(probs, top_idx, axis=1)
    return dict(zip(sample_keys, zip(top_classes, top_probs)))


def infer_dna(args):
    """Predict the three most probable genera per sample with two models.

    Reads a tab-separated sample table, keeps the COI-5P records, joins the
    BOLD DNA embeddings onto them by ``processid`` and runs one classifier on
    the embedding columns only and one on the embedding plus environmental
    columns.

    Args:
        args: Mapping with the key ``'input_path'``, the path of the TSV file.
            The file must provide ``marker_code``, ``processid``, ``nucraw``
            and the environmental columns selected below.

    Returns:
        Tuple ``(dna_genuses, dna_env_genuses)``. Each dict maps a sample's
        ``nucraw`` string to ``(classes, probabilities)`` for its top three
        genera, ordered from least to most probable. The first dict comes
        from the DNA-only model, the second from the DNA + environment model.
    """
    eco_df = pd.read_csv(args['input_path'], sep='\t')
    # load_dataset returns a datasets.Dataset; the merge and column selection
    # below need a pandas DataFrame.
    dna_embeds = load_dataset("LofiAmazon/BOLD-Embeddings", split='train').to_pandas()

    # load model to calculate new embeddings
    # NOTE(review): PreTrainedTokenizerFast, BertForMaskedLM, `model` and
    # load_checkpoint are neither imported nor defined in this file, and
    # tokenizer / bert_model are not used yet -- wire these up before running.
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model, force_download=True)
    tokenizer.add_special_tokens({"pad_token": "<UNK>"})
    bert_model = BertForMaskedLM.from_pretrained(model, force_download=True)
    # NOTE(review): both calls are identical, so nothing here selects which
    # checkpoint is the DNA-only one and which is the DNA + environment one.
    model_dna = load_checkpoint()
    model_dna_env = load_checkpoint()

    # The comparison has to be applied to the column to build a row mask.
    eco_df = eco_df[eco_df['marker_code'] == 'COI-5P']
    eco_df = eco_df[[
        'processid', 'nucraw', 'coord', 'country', 'depth',
        'WorldClim2_BIO_Temperature_Seasonality',
        'WorldClim2_BIO_Precipitation_Seasonality',
        'WorldClim2_BIO_Annual_Precipitation',
        'EarthEnvTopoMed_Elevation',
        'EsaWorldCover_TreeCover',
        'CHELSA_exBIO_GrowingSeasonLength',
        'WCS_Human_Footprint_2009',
        'GHS_Population_Density',
        'CHELSA_BIO_Annual_Mean_Temperature',
    ]]

    # grab DNA embeddings and merge them onto eco_df by processid; the left
    # join keeps input samples that have no embedding (their embedding
    # columns are NaN).
    merged = pd.merge(eco_df, dna_embeds, on='processid', how='left')
    # NOTE(review): assumes the embedding table shares no column name with
    # eco_df other than 'processid'; otherwise pandas suffixes the clashing
    # columns and the lookups below raise KeyError.
    sample_keys = merged['nucraw']

    # The label column must not be fed to the models; labels themselves are
    # not needed for inference.
    # NOTE(review): the remaining columns are passed through unchanged --
    # confirm they match the feature layout each checkpoint was trained on.
    x_eco = merged.drop(columns='genus')
    # Same columns as the embedding table (minus the label), but restricted
    # to the input samples so rows stay aligned with sample_keys.
    x_dna = merged[dna_embeds.columns.drop('genus')]

    # Each model gets the feature set it was trained on.
    dna_genuses = _top_genera(model_dna, x_dna, sample_keys)
    dna_env_genuses = _top_genera(model_dna_env, x_eco, sample_keys)
    return dna_genuses, dna_env_genuses
# if __name__ == '__main__':
# parser = argparse.ArgumentParser()
# parser.add_argument('--input_path', action='store', type=str)
# # parser.add_argument('--checkpt', action='store', type=bool, default=False)
# args = vars(parser.parse_args()) |