import argparse
import requests
import xml.etree.ElementTree as ET
import pickle
import re
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import spacy
import numpy as np
import streamlit as st
from tensorflow.keras.preprocessing.sequence import pad_sequences
# English stopword list from NLTK. Used to strip stopwords from abstracts before
# classification and to drop stopwords from the 'lenient' search-filter vocabulary.
STOPWORDS = set(stopwords.words('english'))
# Settings passed to pad_sequences when preparing classifier input: token
# sequences are padded/truncated to max_length entries, at the end ('post').
# NOTE(review): presumably these must match the settings used to train the model.
max_length = 300
trunc_type = 'post'
padding_type = 'post'

from typing import (
    Dict,
    List,
    Tuple,
    Set,
    Optional,
    Any,
    Union,
)

def standardizeAbstract(abstract:str, nlp:Any) -> str:
    """Replace selected general-purpose named entities with their entity label.

    Example: "3 patients reported at a clinic in England" becomes
    "CARDINAL patients reported at a clinic in GPE".

    Args:
        abstract: Text to standardize.
        nlp: A loaded spaCy pipeline, expected to be en_core_web_lg. It is
            called on ``abstract`` and the resulting ``ents`` are read.

    Returns:
        The text in which every entity labelled PERCENT, CARDINAL, GPE, LOC,
        DATE, TIME, QUANTITY or ORDINAL is replaced by its label. Entities
        with any other label are left untouched.
    """
    replaceable = {'PERCENT','CARDINAL','GPE','LOC','DATE','TIME','QUANTITY','ORDINAL'}
    result = abstract
    # Splice right-to-left (spaCy yields doc.ents in document order), so a
    # replacement never shifts the offsets of entities still to be handled.
    for ent in reversed(nlp(abstract).ents):
        if ent.label_ not in replaceable:
            continue
        begin = ent.start_char
        finish = begin + len(ent.text)
        result = result[:begin] + ent.label_ + result[finish:]
    return result

def standardizeSciTerms(abstract:str, nlpSci:Any, nlpSci2:Any) -> str:
    """Replace biomedical named entities with their entity label.

    Two scispaCy pipelines are applied one after the other. They are expected
    to be en_ner_bc5cdr_md and then en_ner_bionlp13cg_md. The second pipeline
    runs on the output of the first, so it sees the labels the first one
    already substituted. Every entity found is replaced, whatever its label.

    Args:
        abstract: Text to standardize.
        nlpSci: First pipeline (en_ner_bc5cdr_md).
        nlpSci2: Second pipeline (en_ner_bionlp13cg_md).

    Returns:
        The text with all entities from both passes replaced by their labels.
    """
    text = abstract
    for model in (nlpSci, nlpSci2):
        # Right-to-left so earlier character offsets stay valid after each splice.
        for ent in reversed(model(text).ents):
            lo = ent.start_char
            hi = lo + len(ent.text)
            text = text[:lo] + ent.label_ + text[hi:]
    return text

def init_classify_model(model:str='LSTM_RNN_Model') -> Tuple[Any,Any,Any,Any,Any]:
    """Load the NLP pipelines, tokenizer and Keras classifier used for prediction.

    Usage:
        nlp, nlpSci, nlpSci2, classify_model, classify_tokenizer = init_classify_model()

    Args:
        model: Path handed to ``tf.keras.models.load_model``.

    Returns:
        The 5-tuple (nlp, nlpSci, nlpSci2, classify_model, classify_tokenizer):
        - the spaCy pipelines en_core_web_lg, en_ner_bc5cdr_md and
          en_ner_bionlp13cg_md;
        - the loaded Keras model;
        - the tokenizer unpickled from ``tokenizer.pickle`` in the current
          working directory.
    """
    pipeline_names = ('en_core_web_lg', 'en_ner_bc5cdr_md', 'en_ner_bionlp13cg_md')
    nlp, nlpSci, nlpSci2 = [spacy.load(name) for name in pipeline_names]

    # NOTE: pickle.load can execute arbitrary code; only load a trusted tokenizer file.
    with open('tokenizer.pickle', 'rb') as tokenizer_file:
        classify_tokenizer = pickle.load(tokenizer_file)

    classify_model = tf.keras.models.load_model(model)

    return nlp, nlpSci, nlpSci2, classify_model, classify_tokenizer

def PMID_getAb(PMID:Union[int,str]) -> str:
    """Return "<title> <abstract>" for a PubMed ID, fetched from the Europe PMC REST API.

    Queries the Europe PMC search endpoint with ``EXT_ID:<PMID>`` and
    ``resulttype=core``. It then uses the first ``<result>`` whose abstract is
    longer than 5 characters. Title and abstract always come from that same
    record, and empty or missing elements are treated as ''. Nothing raises on
    a missing title or an empty abstract.

    Args:
        PMID: PubMed identifier (int or numeric string).

    Returns:
        Title and abstract joined by one space, or just the abstract if the
        record has no title. Returns '' when no result carries an abstract
        longer than 5 characters.
    """
    from urllib.parse import quote

    def _elem_text(elem) -> str:
        # All text inside the element, including nested markup; '' if absent or empty.
        return '' if elem is None else ''.join(elem.itertext())

    url = 'https://www.ebi.ac.uk/europepmc/webservices/rest/search?query=EXT_ID:'+quote(str(PMID))+'&resulttype=core'
    r = requests.get(url)
    root = ET.fromstring(r.content)

    for result in root.iter('result'):
        abstract = _elem_text(result.find('.//abstractText'))
        if len(abstract) > 5:
            # Direct <title> child only. Nested <title> elements elsewhere in the
            # record (e.g. journal info) presumably hold other titles.
            title = _elem_text(result.find('title'))
            return title + ' ' + abstract if title else abstract
    return ''

def search_Pubmed_API(searchterm_list:Union[List[str],str], maxResults:int) -> Dict[str,str]:
    """Deprecated alias of search_NCBI_API; returns a {pmid: abstract} dict.

    Prints a deprecation notice, then forwards both arguments unchanged to
    search_NCBI_API, which prints its own deprecation notice as well. Prefer
    search_getAbs for the most comprehensive results.
    """
    deprecation_notice = ('search_Pubmed_API is DEPRECATED. UTILIZE search_NCBI_API for NCBI ENTREZ API results. '
                          'Utilize search_getAbs for most comprehensive results.')
    print(deprecation_notice)
    results = search_NCBI_API(searchterm_list, maxResults)
    return results
    
def search_NCBI_API(searchterm_list:Union[List[str],str], maxResults:int) -> Dict[str,str]:
    """DEPRECATED: search PubMed via NCBI E-utilities esearch; return {pmid: abstract}.

    Use search_getAbs for more comprehensive results.

    For each search term, queries esearch and fetches each returned PMID's
    "title abstract" with PMID_getAb. Only PMIDs whose text is longer than 5
    characters are kept. Collection stops as soon as ``maxResults`` abstracts
    have been gathered; the cap is checked per PMID.

    Args:
        searchterm_list: A single search term, or an iterable of terms.
        maxResults: Maximum number of abstracts to return.

    Returns:
        Dict mapping PMID string -> "title abstract" string.
    """
    from urllib.parse import quote

    print('search_NCBI_API is DEPRECATED. Utilize search_getAbs for most comprehensive results.')
    pmid_to_abs = {}

    # Accept a single string as well as any iterable of terms.
    if isinstance(searchterm_list, str):
        searchterm_list = [searchterm_list]
    else:
        searchterm_list = list(searchterm_list)

    for dz in searchterm_list:
        if len(pmid_to_abs) >= maxResults:
            break
        # Words are joined with %20; quote() also escapes characters such as
        # '&' or '#' that would otherwise break or alter the query string.
        query = quote(' '.join(dz.split()))
        url = 'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=pubmed&term='+query
        r = requests.get(url)
        root = ET.fromstring(r.content)

        for result in root.iter('IdList'):
            for pmid_elem in result.iter('Id'):
                # Checked per PMID (not once per IdList) so the cap is honored
                # and no abstract is fetched beyond it.
                if len(pmid_to_abs) >= maxResults:
                    break
                pmid = pmid_elem.text
                if not pmid or pmid in pmid_to_abs:
                    continue
                abstract = PMID_getAb(pmid)
                if len(abstract) > 5:
                    pmid_to_abs[pmid] = abstract

    return pmid_to_abs

def search_EBI_API(searchterm_list:Union[List[str],str], maxResults:int) -> Dict[str,str]:
    """DEPRECATED: search Europe PMC and return {pmid: "title abstract"}.

    Use search_getAbs for more comprehensive results.

    For each term, queries the Europe PMC search API (``resulttype=core``) and
    reads title and abstract directly from each result. A result is kept when
    its first ``<id>`` starts with a digit and it carries a non-empty abstract.
    Stops once ``maxResults`` distinct PMIDs have been collected.

    Args:
        searchterm_list: A single search term, or an iterable of terms.
        maxResults: Maximum number of entries in the returned dict.

    Returns:
        Dict mapping PMID string -> "title abstract" string, or just the
        abstract when the record has no title.
    """
    from urllib.parse import quote

    def _elem_text(elem) -> str:
        # All text inside the element, including nested markup; '' if absent or empty.
        return '' if elem is None else ''.join(elem.itertext())

    print('DEPRECATED. Utilize search_getAbs for most comprehensive results.')
    pmids_abs = {}

    # Accept a single string as well as any iterable of terms.
    if isinstance(searchterm_list, str):
        searchterm_list = [searchterm_list]
    else:
        searchterm_list = list(searchterm_list)

    for dz in searchterm_list:
        if len(pmids_abs) >= maxResults:
            break
        # Words are joined with %20; quote() also escapes '&', '#', etc.
        query = quote(' '.join(dz.split()))
        url = 'https://www.ebi.ac.uk/europepmc/webservices/rest/search?query='+query+'&resulttype=core'
        r = requests.get(url)
        root = ET.fromstring(r.content)

        for result in root.iter('result'):
            # len(dict) instead of a separate counter, so a PMID returned for
            # two different terms is not counted twice against maxResults.
            if len(pmids_abs) >= maxResults:
                break
            pmid = result.findtext('.//id')
            if not pmid or not pmid[0].isdigit():
                continue
            abstract = _elem_text(result.find('.//abstractText'))
            if not abstract:
                continue
            title = _elem_text(result.find('title'))
            pmids_abs[pmid] = title + ' ' + abstract if title else abstract

    return pmids_abs

def search_getAbs(searchterm_list:Union[List[str],List[int],str], maxResults:int, filtering:str) -> Dict[str,str]:
    """Search PubMed and Europe PMC for one or more terms; return {pmid: abstract}.

    This is the main, most comprehensive search function.

    For each search term it collects PMIDs from both the NCBI E-utilities
    esearch API and the Europe PMC search API, which return different results.
    Collection stops once ``maxResults`` distinct PMIDs have been gathered in
    total. It then fetches "title abstract" for every PMID with PMID_getAb and
    keeps the ones longer than 5 characters that pass ``filtering``.

    Cost is two search requests per term plus one request per PMID, which is
    slow. Europe PMC search results already contain abstracts, so a future
    version could skip the per-PMID requests.

    Filtering modes:
        'strict'  - keep the abstract if at least one search term/phrase occurs
                    in it (case-insensitive substring match).
        'none'    - keep every abstract.
        otherwise - 'lenient', the behavior for any other value. Keep the
                    abstract if one of its word_tokenize tokens equals, ignoring
                    case, either a whole search term or a non-stopword word
                    taken from the search terms (commas removed).

    Args:
        searchterm_list: A single search term or an iterable of terms.
            Non-string terms (e.g. ints) are converted with str().
        maxResults: Maximum number of PMIDs to collect, and so of abstracts returned.
        filtering: 'strict', 'lenient' or 'none'.

    Returns:
        Dict mapping PMID string -> "title abstract" string.
    """
    from urllib.parse import quote

    # Accept a single string as well as any iterable of terms. str() lets
    # integer terms through, as the type hint advertises.
    if isinstance(searchterm_list, str):
        searchterm_list = [searchterm_list]
    searchterm_list = [str(term) for term in searchterm_list]

    pmids = set()
    pmid_abs = {}

    # ---- 1. Gather PMIDs from both search APIs --------------------------------
    for dz in searchterm_list:
        if len(pmids) >= maxResults:
            break
        # Words are joined with %20; quote() also escapes '&', '#', etc.
        query = quote(' '.join(dz.split()))

        # PubMed (NCBI E-utilities) search.
        url = 'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=pubmed&term='+query
        root = ET.fromstring(requests.get(url).content)
        for id_list in root.iter('IdList'):
            for pmid_elem in id_list.iter('Id'):
                # Checked per PMID, not once per IdList, so the cap is honored.
                if len(pmids) >= maxResults:
                    break
                if pmid_elem.text:
                    pmids.add(pmid_elem.text)

        # Europe PMC search.
        url = 'https://www.ebi.ac.uk/europepmc/webservices/rest/search?query='+query+'&resulttype=core'
        root = ET.fromstring(requests.get(url).content)
        for result in root.iter('result'):
            if len(pmids) >= maxResults:
                break
            # First <id> of the record; only ids starting with a digit are kept as PMIDs.
            pmid = result.findtext('.//id')
            if pmid and pmid[0].isdigit():
                pmids.add(pmid)

    # ---- 2. Build the lenient-mode vocabulary ----------------------------------
    # Everything is lowercased so lenient matching is case-insensitive like
    # 'strict'. Otherwise lenient could reject abstracts that strict accepts.
    filter_terms = set()
    if filtering not in ('none', 'strict'):
        phrases = {term.lower() for term in searchterm_list}
        words = set(re.sub(',', '', ' '.join(searchterm_list)).lower().split())
        filter_terms = phrases | (words - STOPWORDS)

    # ---- 3. Fetch abstracts and filter them ------------------------------------
    for pmid in pmids:
        abstract = PMID_getAb(pmid)
        if len(abstract) <= 5:
            continue
        if filtering == 'strict':
            uncased_ab = abstract.lower()
            keep = any(term.lower() in uncased_ab for term in searchterm_list)
        elif filtering == 'none':
            keep = True
        else:
            tokens = {token.lower() for token in word_tokenize(abstract)}
            keep = not filter_terms.isdisjoint(tokens)
        if keep:
            pmid_abs[pmid] = abstract

    print('Found',len(pmids),'PMIDs. Gathered',len(pmid_abs),'Relevant Abstracts.')

    return pmid_abs

def streamlit_getAbs(searchterm_list:Union[List[str],List[int],str], maxResults:int, filtering:str) -> Tuple[Dict[str,str],Tuple[int,int]]:
    """Streamlit version of the PubMed/Europe PMC abstract search, with progress UI.

    For each search term, collects PMIDs from the NCBI E-utilities esearch API
    and the Europe PMC search API until ``maxResults`` distinct PMIDs have been
    gathered in total. It then fetches "title abstract" for each PMID with
    PMID_getAb and keeps those longer than 5 characters that pass
    ``filtering``. Spinners and progress bars are shown while it runs, and a
    success message at the end.

    Filtering modes:
        'strict'  - keep the abstract if at least one search term/phrase occurs
                    in it (case-insensitive substring match).
        'none'    - keep every abstract.
        otherwise - 'lenient'. Keep the abstract if one of its word_tokenize
                    tokens equals, ignoring case, a whole search term or a
                    non-stopword word from the search terms (commas removed).

    Args:
        searchterm_list: A single search term or an iterable of terms.
            Non-string terms (e.g. ints) are converted with str().
        maxResults: Maximum number of PMIDs to collect.
        filtering: 'strict', 'lenient' or 'none'.

    Returns:
        (pmid_abs, (found, relevant)). pmid_abs maps PMID -> "title abstract";
        found is the number of PMIDs collected; relevant is len(pmid_abs).
    """
    from urllib.parse import quote

    # Accept a single string as well as any iterable of terms.
    if isinstance(searchterm_list, str):
        searchterm_list = [searchterm_list]
    searchterm_list = [str(term) for term in searchterm_list]

    pmids = set()
    pmid_abs = {}

    # PMID progress is a fraction of maxResults; max(..., 1) avoids a
    # ZeroDivisionError when maxResults is 0.
    percent_by_step = 1/max(maxResults, 1)
    with st.spinner("Gathering PubMed IDs..."):
        PMIDs_bar = st.progress(0)
        for dz in searchterm_list:
            if len(pmids) >= maxResults:
                break
            # Words are joined with %20; quote() also escapes '&', '#', etc. in user input.
            query = quote(' '.join(dz.split()))

            # PubMed (NCBI E-utilities) search.
            url = 'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=pubmed&term='+query
            root = ET.fromstring(requests.get(url).content)
            for result in root.iter('IdList'):
                for pmid in result.iter('Id'):
                    if len(pmids) >= maxResults:
                        break
                    if pmid.text:
                        pmids.add(pmid.text)
                        PMIDs_bar.progress(min(round(len(pmids)*percent_by_step,1),1.0))

            # Europe PMC search.
            url = 'https://www.ebi.ac.uk/europepmc/webservices/rest/search?query='+query+'&resulttype=core'
            root = ET.fromstring(requests.get(url).content)
            for result in root.iter('result'):
                if len(pmids) >= maxResults:
                    break
                # First <id> of the record; only ids starting with a digit are kept as PMIDs.
                pmid = result.findtext('.//id')
                if pmid and pmid[0].isdigit():
                    pmids.add(pmid)
                    PMIDs_bar.progress(min(round(len(pmids)*percent_by_step,1),1.0))
        PMIDs_bar.empty()

    with st.spinner("Found "+str(len(pmids))+" PMIDs. Gathering Abstracts and Filtering..."):
        abstracts_bar = st.progress(0)

        # Lenient-mode vocabulary, lowercased so matching is case-insensitive
        # like 'strict'. Otherwise lenient could reject what strict accepts.
        filter_terms = set()
        if filtering not in ('none', 'strict'):
            phrases = {term.lower() for term in searchterm_list}
            words = set(re.sub(',', '', ' '.join(searchterm_list)).lower().split())
            filter_terms = phrases | (words - STOPWORDS)

        total = max(len(pmids), 1)
        for done, pmid in enumerate(pmids, start=1):
            abstract = PMID_getAb(pmid)
            if len(abstract) > 5:
                if filtering == 'strict':
                    uncased_ab = abstract.lower()
                    keep = any(term.lower() in uncased_ab for term in searchterm_list)
                elif filtering == 'none':
                    keep = True
                else:
                    tokens = {token.lower() for token in word_tokenize(abstract)}
                    keep = not filter_terms.isdisjoint(tokens)
                if keep:
                    pmid_abs[pmid] = abstract
            # Progress tracks PMIDs processed, so the bar completes even when
            # filtering drops abstracts.
            abstracts_bar.progress(min(round(done/total,1),1.0))
        abstracts_bar.empty()

    found = len(pmids)
    relevant = len(pmid_abs)
    st.success('Found '+str(found)+' PMIDs. Gathered '+str(relevant)+' Relevant Abstracts. Classifying and extracting epidemiology information...')

    return pmid_abs, (found, relevant)

def getPMIDPredictions(pmid:Union[str,int], classify_model_vars:Tuple[Any,Any,Any,Any,Any]) -> Tuple[str,float,bool]:
    """Classify whether a PubMed article's abstract is epidemiological.

    Steps:
        1. Fetch "title abstract" with PMID_getAb.
        2. Strip stopwords.
        3. Standardize biomedical entities (standardizeSciTerms), then general
           entities (standardizeAbstract).
        4. Tokenize and pad to ``max_length``.
        5. Run the Keras classifier.

    (Originally named getPredictions.)

    Args:
        pmid: PubMed ID.
        classify_model_vars: The 5-tuple returned by init_classify_model:
            (nlp, nlpSci, nlpSci2, classify_model, classify_tokenizer), i.e.
            en_core_web_lg, en_ner_bc5cdr_md, en_ner_bionlp13cg_md, a tf.keras
            model and its tokenizer. Any tf.keras model with the same
            input/output layout can be supplied.

    Returns:
        (abstract, prob, isEpi):
        - abstract: the fetched text AFTER stopword removal, not the raw text.
        - prob: the model output at class index 1, as a Python float.
        - isEpi: True when index 1 is the argmax.
        When the fetched text is 5 characters or fewer, returns the fetched
        text unchanged with 0.0 and False.
    """
    nlp, nlpSci, nlpSci2, classify_model, classify_tokenizer = classify_model_vars
    abstract = PMID_getAb(pmid)

    if len(abstract) <= 5:
        return abstract, 0.0, False

    # Remove stopwords that appear as space-delimited words. The matching is
    # case-sensitive and non-overlapping; it presumably mirrors the training-time
    # preprocessing, so the rule itself is kept as is.
    # Sorted iteration makes the output reproducible: a set's order varies
    # between processes (string hash randomization), and with overlapping
    # matches such as ' in a ' the order changes the result.
    for word in sorted(STOPWORDS):
        abstract = abstract.replace(' ' + word + ' ', ' ')

    abstract_standard = [standardizeAbstract(standardizeSciTerms(abstract, nlpSci, nlpSci2), nlp)]
    sequence = classify_tokenizer.texts_to_sequences(abstract_standard)
    padded = pad_sequences(sequence, maxlen=max_length, padding=padding_type, truncating=trunc_type)

    y_pred1 = classify_model.predict(padded)  # one row of per-class scores for this abstract
    prob = float(y_pred1[0][1])
    isEpi = bool(np.argmax(y_pred1, axis=1)[0] == 1)

    return abstract, prob, isEpi


def getTextPredictions(abstract:str, classify_model_vars:Tuple[Any,Any,Any,Any,Any]) -> Tuple[float,bool]:
    """Classify whether a piece of abstract text is epidemiological.

    Steps:
        1. Strip stopwords.
        2. Standardize biomedical entities (standardizeSciTerms), then general
           entities (standardizeAbstract).
        3. Tokenize and pad to ``max_length``.
        4. Run the Keras classifier.

    Args:
        abstract: Raw abstract text, e.g. "title abstract".
        classify_model_vars: The 5-tuple returned by init_classify_model:
            (nlp, nlpSci, nlpSci2, classify_model, classify_tokenizer).

    Returns:
        (prob, isEpi):
        - prob: the model output at class index 1, as a Python float.
        - isEpi: True when index 1 is the argmax.
        Returns (0.0, False) when the text is 5 characters or fewer.
    """
    nlp, nlpSci, nlpSci2, classify_model, classify_tokenizer = classify_model_vars

    if len(abstract) <= 5:
        return 0.0, False

    # Remove stopwords that appear as space-delimited words. The matching is
    # case-sensitive and non-overlapping; it presumably mirrors the training-time
    # preprocessing, so the rule itself is kept as is.
    # Sorted iteration makes the output reproducible: a set's order varies
    # between processes (string hash randomization), and with overlapping
    # matches such as ' in a ' the order changes the result.
    for word in sorted(STOPWORDS):
        abstract = abstract.replace(' ' + word + ' ', ' ')

    abstract_standard = [standardizeAbstract(standardizeSciTerms(abstract, nlpSci, nlpSci2), nlp)]
    sequence = classify_tokenizer.texts_to_sequences(abstract_standard)
    padded = pad_sequences(sequence, maxlen=max_length, padding=padding_type, truncating=trunc_type)

    y_pred1 = classify_model.predict(padded)  # one row of per-class scores for this abstract
    prob = float(y_pred1[0][1])
    isEpi = bool(np.argmax(y_pred1, axis=1)[0] == 1)

    return prob, isEpi