# PSEUDOCODE UNTIL WE GET DATA
import pickle

import numpy as np
import pandas as pd
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder

# Columns of the ecological table that are kept for inference.
_ECO_COLUMNS = [
    'processid', 'nucraw', 'coord', 'country', 'depth',
    'WorldClim2_BIO_Temperature_Seasonality',
    'WorldClim2_BIO_Precipitation_Seasonality',
    'WorldClim2_BIO_Annual_Precipitation',
    'EarthEnvTopoMed_Elevation',
    'EsaWorldCover_TreeCover',
    'CHELSA_exBIO_GrowingSeasonLength',
    'WCS_Human_Footprint_2009',
    'GHS_Population_Density',
    'CHELSA_BIO_Annual_Mean_Temperature',
]


def _top_k_genera(probs, classes, sample_keys, k=3):
    """Map each sample key to its k highest-probability classes.

    Args:
        probs: 2-D array-like of per-class probabilities, one row per sample.
        classes: 1-D array-like of class labels, aligned with the columns of
            ``probs``.
        sample_keys: Iterable of hashable keys, one per row of ``probs``.
        k: Number of top classes to keep per sample.

    Returns:
        Dict ``{key: (top_classes, top_probs)}``. Both arrays are ordered by
        ascending probability, so the most likely class is the last element.
        If a key occurs more than once, the last occurrence wins.
    """
    probs = np.asarray(probs)
    classes = np.asarray(classes)
    # One vectorised argsort over all rows; the last k columns are the k
    # largest probabilities of each row.
    top_idx = np.argsort(probs, axis=1)[:, -k:]
    top_probs = np.take_along_axis(probs, top_idx, axis=1)
    top_classes = classes[top_idx]
    return {
        key: (row_classes, row_probs)
        for key, row_classes, row_probs in zip(sample_keys, top_classes, top_probs)
    }


def infer_dna(args):
    """Predict the top-3 genera per sample with two models.

    Reads a tab-separated table from ``args['input_path']``, keeps the rows
    whose ``marker_code`` is ``'COI-5P'``, left-joins the
    "LofiAmazon/BOLD-Embeddings" table on ``processid`` and runs
    ``predict_proba`` with two models: one on the combined ecological + DNA
    columns and one on the DNA-embedding columns only.

    Args:
        args: Mapping with the key ``'input_path'`` (path to the TSV file).

    Returns:
        Tuple ``(DNAGenuses, DNAEnvGenuses)``. Each is a dict keyed by the
        sample's ``nucraw`` value whose values are ``(topClasses, topProbs)``
        arrays of up to three entries, ordered by ascending probability.
        ``DNAGenuses`` comes from ``modelDNA`` (embedding columns only) and
        ``DNAEnvGenuses`` from ``modelDNAEnv`` (ecological + embedding
        columns). Both dicts are empty when no row passes the marker filter.
    """
    ecoDf = pd.read_csv(args['input_path'], sep='\t')

    # load_dataset returns a datasets.Dataset; convert it so that pandas can
    # merge and column-select it.
    # NOTE(review): assumes the table has a 'processid' column and shares no
    # other column name with _ECO_COLUMNS - confirm against the dataset card.
    embedDf = load_dataset("LofiAmazon/BOLD-Embeddings", split='train').to_pandas()

    # TODO: computing embeddings for sequences that are missing from the
    # precomputed table is not implemented. The draft loaded a tokenizer and a
    # BERT model here, but referenced names that were never imported or
    # defined and never used the results.

    # Compare the column's values to the marker code (not the two string
    # literals to each other) and keep the filtered frame.
    ecoDf = ecoDf[ecoDf['marker_code'] == 'COI-5P']
    ecoDf = ecoDf[_ECO_COLUMNS]

    # grab DNA embeddings and merge them onto ecoDf by processid
    merged = pd.merge(ecoDf, embedDf, on='processid', how='left')
    if merged.empty:
        return {}, {}

    # NOTE(review): load_checkpoint is not defined in this file as visible
    # here, and both calls are identical - the two models still need to be
    # pointed at their own checkpoints.
    modelDNA = load_checkpoint()
    modelDNAEnv = load_checkpoint()

    sampleKeys = merged['nucraw'].tolist()

    # Inference with the model trained on DNA and Env data. The label column
    # is dropped so it is never fed to the model as a feature.
    # NOTE(review): identifier/text columns (processid, nucraw, coord,
    # country) are still passed through as in the draft - confirm against the
    # feature schema the models were trained on.
    X_eco = merged.drop(columns='genus', errors='ignore')
    DNAEnvGenuses = _top_k_genera(
        modelDNAEnv.predict_proba(X_eco), modelDNAEnv.classes_, sampleKeys
    )

    # Inference with the model only trained on DNA. The columns are taken
    # from the merged frame so that row i matches sampleKeys[i].
    dnaColumns = [col for col in embedDf.columns if col != 'genus']
    X_dna = merged[dnaColumns]
    DNAGenuses = _top_k_genera(
        modelDNA.predict_proba(X_dna), modelDNA.classes_, sampleKeys
    )

    return DNAGenuses, DNAEnvGenuses


# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--input_path', action='store', type=str)
#     # parser.add_argument('--checkpt', action='store', type=bool, default=False)
#     args = vars(parser.parse_args())