File size: 2,584 Bytes
6d06448
 
 
 
 
 
 
 
e86736e
6d06448
 
 
 
e86736e
 
 
 
 
6d06448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e86736e
6d06448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# PSEUDOCODE UNTIL WE GET DATA
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder
from datasets import load_dataset
import pickle


# Columns kept from the input TSV: sample identifiers plus the environmental
# covariates fed to the DNA+Env model.
_ECO_COLUMNS = [
    'processid', 'nucraw', 'coord', 'country', 'depth',
    'WorldClim2_BIO_Temperature_Seasonality',
    'WorldClim2_BIO_Precipitation_Seasonality',
    'WorldClim2_BIO_Annual_Precipitation',
    'EarthEnvTopoMed_Elevation',
    'EsaWorldCover_TreeCover',
    'CHELSA_exBIO_GrowingSeasonLength',
    'WCS_Human_Footprint_2009',
    'GHS_Population_Density',
    'CHELSA_BIO_Annual_Mean_Temperature',
]

# Identifier / label columns that must never be passed to predict_proba.
_NON_FEATURE_COLUMNS = ['processid', 'nucraw', 'genus']


def _load_model(args, key):
    """Return the classifier given as ``args[key]``, else unpickle ``args[key + '_path']``."""
    model = args.get(key)
    if model is not None:
        return model
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted checkpoints.
    with open(args[key + '_path'], 'rb') as fh:
        return pickle.load(fh)


def _top_k_predictions(model, features, keys, k=3):
    """Map each key to the model's top-``k`` ``(classes, probabilities)``, most likely first."""
    if len(features) == 0:
        return {}
    probs = np.asarray(model.predict_proba(features))
    # argsort is ascending along the class axis: keep the last k columns and
    # reverse them so position 0 holds the most probable class.
    top_idx = np.argsort(probs, axis=1)[:, -k:][:, ::-1]
    top_classes = np.asarray(model.classes_)[top_idx]
    # Report the actual probabilities, not the argsort indices.
    top_probs = np.take_along_axis(probs, top_idx, axis=1)
    # NOTE: samples sharing the same key (identical sequence) overwrite each other.
    return {
        key: (classes, row_probs)
        for key, classes, row_probs in zip(keys, top_classes, top_probs)
    }


def infer_dna(args):
    """Predict the most likely genera for each COI-5P sample in a tab-separated file.

    Two classifiers are applied to the same input samples:
      * the DNA-only model sees only the precomputed DNA embedding columns;
      * the DNA+Env model sees the embeddings plus the environmental covariates.

    Args:
        args: mapping of options.
            'input_path' (required): path or buffer of the tab-separated input;
                it must contain 'marker_code' and every column in ``_ECO_COLUMNS``.
            'dna_embeddings' (optional): DataFrame of embeddings joined on
                'processid'; defaults to the "LofiAmazon/BOLD-Embeddings" train split.
            'dna_model' / 'dna_env_model' (optional): fitted classifiers exposing
                ``predict_proba`` and ``classes_``; when absent they are unpickled
                from 'dna_model_path' / 'dna_env_model_path'.
            'top_k' (optional, default 3): number of genera kept per sample.

    Returns:
        ``(dna_genuses, dna_env_genuses)``: dicts mapping each sample's 'nucraw'
        sequence to ``(top_classes, top_probs)`` ordered most-likely first.

    Raises:
        KeyError: if a required ``args`` key or input column is missing.
    """
    eco_df = pd.read_csv(args['input_path'], sep='\t')
    # Keep only COI-5P barcodes and the columns the models need.
    eco_df = eco_df.loc[eco_df['marker_code'] == 'COI-5P', _ECO_COLUMNS]

    dna_embeds = args.get('dna_embeddings')
    if dna_embeds is None:
        # A Hugging Face Dataset has no pandas API (merge/drop), so convert it.
        dna_embeds = load_dataset(
            "LofiAmazon/BOLD-Embeddings", split='train'
        ).to_pandas()
    # TODO: sequences absent from the precomputed embeddings would need the
    # tokenizer/BERT model to be embedded; that path is not implemented here.

    model_dna = _load_model(args, 'dna_model')
    model_dna_env = _load_model(args, 'dna_env_model')

    # Embedding feature columns = everything except identifiers and the label.
    dna_cols = [c for c in dna_embeds.columns if c not in _NON_FEATURE_COLUMNS]
    # Attach embeddings to each input sample. Selecting only processid + features
    # avoids _x/_y suffix collisions on shared columns such as 'genus'.
    merged = pd.merge(
        eco_df, dna_embeds[['processid', *dna_cols]], on='processid', how='left'
    )
    # NOTE(review): the left join leaves NaN features for samples without an
    # embedding; confirm the models tolerate NaN or switch to how='inner'.
    # NOTE(review): 'coord' and 'country' reach the DNA+Env model unchanged;
    # confirm they match the features/encoding used in training.

    keys = merged['nucraw']
    top_k = args.get('top_k', 3)

    # DNA-only model: embedding columns only.
    dna_genuses = _top_k_predictions(model_dna, merged[dna_cols], keys, top_k)

    # DNA+Env model: embeddings plus environmental covariates.
    x_eco = merged.drop(columns=_NON_FEATURE_COLUMNS, errors='ignore')
    dna_env_genuses = _top_k_predictions(model_dna_env, x_eco, keys, top_k)

    return dna_genuses, dna_env_genuses


# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--input_path', action='store', type=str)
#     # parser.add_argument('--checkpt', action='store', type=bool, default=False)

#     args = vars(parser.parse_args())