wzkariampuzha commited on
Commit
aa32937
·
1 Parent(s): f21365b

Upload classify_abs.py

Browse files
Files changed (1) hide show
  1. classify_abs.py +356 -0
classify_abs.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import requests
3
+ import xml.etree.ElementTree as ET
4
+ import pickle
5
+ import re
6
+ import os
7
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
8
+ import tensorflow as tf
9
+ from nltk.corpus import stopwords
10
+ from nltk.tokenize import word_tokenize
11
+ import spacy
12
+ import numpy as np
13
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
14
+ STOPWORDS = set(stopwords.words('english'))
15
+ max_length = 300
16
+ trunc_type = 'post'
17
+ padding_type = 'post'
18
+
19
+ from typing import (
20
+ Dict,
21
+ List,
22
+ Tuple,
23
+ Set,
24
+ Optional,
25
+ Any,
26
+ Union,
27
+ )
28
+
29
+ # Standardize the abstract by replacing all named entities with their entity label.
30
+ # Eg. 3 patients reported at a clinic in England --> CARDINAL patients reported at a clinic in GPE
31
+ # expects the spaCy model en_core_web_lg as input
32
+ def standardizeAbstract(abstract:str, nlp:Any) -> str:
33
+ doc = nlp(abstract)
34
+ newAbstract = abstract
35
+ for e in reversed(doc.ents):
36
+ if e.label_ in {'PERCENT','CARDINAL','GPE','LOC','DATE','TIME','QUANTITY','ORDINAL'}:
37
+ start = e.start_char
38
+ end = start + len(e.text)
39
+ newAbstract = newAbstract[:start] + e.label_ + newAbstract[end:]
40
+ return newAbstract
41
+
42
+ # Same as above but replaces biomedical named entities from scispaCy models
43
+ # Expects as input en_ner_bc5cdr_md and en_ner_bionlp13cg_md
44
+ def standardizeSciTerms(abstract:str, nlpSci:Any, nlpSci2:Any) -> str:
45
+ doc = nlpSci(abstract)
46
+ newAbstract = abstract
47
+ for e in reversed(doc.ents):
48
+ start = e.start_char
49
+ end = start + len(e.text)
50
+ newAbstract = newAbstract[:start] + e.label_ + newAbstract[end:]
51
+
52
+ doc = nlpSci2(newAbstract)
53
+ for e in reversed(doc.ents):
54
+ start = e.start_char
55
+ end = start + len(e.text)
56
+ newAbstract = newAbstract[:start] + e.label_ + newAbstract[end:]
57
+ return newAbstract
58
+
59
+ # Prepare model
60
+ #nlp, nlpSci, nlpSci2, classify_model, classify_tokenizer= init_classify_model()
61
+ def init_classify_model(model:str='my_model_orphanet_final') -> Tuple[Any,Any,Any,Any,Any]:
62
+ #Load spaCy models
63
+ nlp = spacy.load('en_core_web_lg')
64
+ nlpSci = spacy.load("en_ner_bc5cdr_md")
65
+ nlpSci2 = spacy.load('en_ner_bionlp13cg_md')
66
+
67
+ # load the tokenizer
68
+ with open('tokenizer.pickle', 'rb') as handle:
69
+ classify_tokenizer = pickle.load(handle)
70
+
71
+ # load the model
72
+ classify_model = tf.keras.models.load_model(model)
73
+
74
+ return (nlp, nlpSci, nlpSci2, classify_model, classify_tokenizer)
75
+
76
+ #Gets abstract and title (concatenated) from EBI API
77
+ def PMID_getAb(PMID:Union[int,str]) -> str:
78
+ url = 'https://www.ebi.ac.uk/europepmc/webservices/rest/search?query=EXT_ID:'+str(PMID)+'&resulttype=core'
79
+ r = requests.get(url)
80
+ root = ET.fromstring(r.content)
81
+ titles = [title.text for title in root.iter('title')]
82
+ abstracts = [abstract.text for abstract in root.iter('abstractText')]
83
+ if len(abstracts) > 0 and len(abstracts[0])>5:
84
+ return titles[0]+' '+abstracts[0]
85
+ else:
86
+ return ''
87
+
88
+ def search_Pubmed_API(searchterm_list:Union[List[str],str], maxResults:int) -> Dict[str,str]: #returns a dictionary of {pmids:abstracts}
89
+ print('search_Pubmed_API is DEPRECATED. UTILIZE search_NCBI_API for NCBI ENTREZ API results. Utilize search_getAbs for most comprehensive results.')
90
+ return search_NCBI_API(searchterm_list, maxResults)
91
+
92
+ ## DEPRECATED, use search_getAbs for more comprehensive results
93
+ def search_NCBI_API(searchterm_list:Union[List[str],str], maxResults:int) -> Dict[str,str]: #returns a dictionary of {pmids:abstracts}
94
+ print('search_NCBI_API is DEPRECATED. Utilize search_getAbs for most comprehensive results.')
95
+ pmid_to_abs = {}
96
+ i = 0
97
+
98
+ #type validation, allows string or list input
99
+ if type(searchterm_list)!=list:
100
+ if type(searchterm_list)==str:
101
+ searchterm_list = [searchterm_list]
102
+ else:
103
+ searchterm_list = list(searchterm_list)
104
+
105
+ #gathers pmids into a set first
106
+ for dz in searchterm_list:
107
+ # get results from searching for disease name through PubMed API
108
+ term = ''
109
+ dz_words = dz.split()
110
+ for word in dz_words:
111
+ term += word + '%20'
112
+ query = term[:-3]
113
+ url = 'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=pubmed&term='+query
114
+ r = requests.get(url)
115
+ root = ET.fromstring(r.content)
116
+
117
+ # loop over resulting articles
118
+ for result in root.iter('IdList'):
119
+ pmids = [pmid.text for pmid in result.iter('Id')]
120
+ if i >= maxResults:
121
+ break
122
+ for pmid in pmids:
123
+ if pmid not in pmid_to_abs.keys():
124
+ abstract = PMID_getAb(pmid)
125
+ if len(abstract)>5:
126
+ pmid_to_abs[pmid]=abstract
127
+ i+=1
128
+
129
+ return pmid_to_abs
130
+
131
+ ## DEPRECATED, use search_getAbs for more comprehensive results
132
+ # get results from searching for disease name through EBI API
133
+ def search_EBI_API(searchterm_list:Union[List[str],str], maxResults:int) -> Dict[str,str]: #returns a dictionary of {pmids:abstracts}
134
+ print('DEPRECATED. Utilize search_getAbs for most comprehensive results.')
135
+ pmids_abs = {}
136
+ i = 0
137
+
138
+ #type validation, allows string or list input
139
+ if type(searchterm_list)!=list:
140
+ if type(searchterm_list)==str:
141
+ searchterm_list = [searchterm_list]
142
+ else:
143
+ searchterm_list = list(searchterm_list)
144
+
145
+ #gathers pmids into a set first
146
+ for dz in searchterm_list:
147
+ if i >= maxResults:
148
+ break
149
+ term = ''
150
+ dz_words = dz.split()
151
+ for word in dz_words:
152
+ term += word + '%20'
153
+ query = term[:-3]
154
+ url = 'https://www.ebi.ac.uk/europepmc/webservices/rest/search?query='+query+'&resulttype=core'
155
+ r = requests.get(url)
156
+ root = ET.fromstring(r.content)
157
+
158
+ # loop over resulting articles
159
+ for result in root.iter('result'):
160
+ if i >= maxResults:
161
+ break
162
+ pmids = [pmid.text for pmid in result.iter('id')]
163
+ if len(pmids) > 0:
164
+ pmid = pmids[0]
165
+ if pmid[0].isdigit():
166
+ abstracts = [abstract.text for abstract in result.iter('abstractText')]
167
+ titles = [title.text for title in result.iter('title')]
168
+ if len(abstracts) > 0:# and len(abstracts[0])>5:
169
+ pmids_abs[pmid] = titles[0]+' '+abstracts[0]
170
+ i+=1
171
+
172
+ return pmids_abs
173
+
174
+ ## This is the main, most comprehensive search_term function, it can take in a search term or a list of search terms and output a dictionary of {pmids:abstracts}
175
+ ## Gets results from searching through both PubMed and EBI search term APIs, also makes use of the EBI API for PMIDs.
176
+ ## EBI API and PubMed API give different results
177
+ # This makes n+2 API calls where n<=maxResults, which is slow
178
+ # There is a way to optimize by gathering abstracts from the EBI API when also getting pmids but did not pursue due to time constraints
179
+ # Filtering can be
180
+ # 'strict' - must have some exact match to at leastone of search terms/phrases in text)
181
+ # 'lenient' - part of the abstract must match at least one word in the search term phrases.
182
+ # 'none'
183
+ def search_getAbs(searchterm_list:Union[List[str],List[int],str], maxResults:int, filtering:str) -> Dict[str,str]:
184
+ #set of all pmids
185
+ pmids = set()
186
+
187
+ #dictionary {pmid:abstract}
188
+ pmid_abs = {}
189
+
190
+ #type validation, allows string or list input
191
+ if type(searchterm_list)!=list:
192
+ if type(searchterm_list)==str:
193
+ searchterm_list = [searchterm_list]
194
+ else:
195
+ searchterm_list = list(searchterm_list)
196
+
197
+ #gathers pmids into a set first
198
+ for dz in searchterm_list:
199
+ term = ''
200
+ dz_words = dz.split()
201
+ for word in dz_words:
202
+ term += word + '%20'
203
+ query = term[:-3]
204
+
205
+ ## get pmid results from searching for disease name through PubMed API
206
+ url = 'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=pubmed&term='+query
207
+ r = requests.get(url)
208
+ root = ET.fromstring(r.content)
209
+
210
+ # loop over resulting articles
211
+ for result in root.iter('IdList'):
212
+ if len(pmids) >= maxResults:
213
+ break
214
+ pmidlist = [pmid.text for pmid in result.iter('Id')]
215
+ pmids.update(pmidlist)
216
+
217
+ ## get results from searching for disease name through EBI API
218
+ url = 'https://www.ebi.ac.uk/europepmc/webservices/rest/search?query='+query+'&resulttype=core'
219
+ r = requests.get(url)
220
+ root = ET.fromstring(r.content)
221
+
222
+ # loop over resulting articles
223
+ for result in root.iter('result'):
224
+ if len(pmids) >= maxResults:
225
+ break
226
+ pmidlist = [pmid.text for pmid in result.iter('id')]
227
+ #can also gather abstract and title here but for some reason did not work as intended the first time. Optimize in future versions to reduce latency.
228
+ if len(pmidlist) > 0:
229
+ pmid = pmidlist[0]
230
+ if pmid[0].isdigit():
231
+ pmids.add(pmid)
232
+
233
+ #Construct sets for filtering (right before adding abstract to pmid_abs
234
+ # The purpose of this is to do a second check of the abstracts, filters out any abstracts unrelated to the search terms
235
+ #if filtering is 'lenient' or default
236
+ if filtering !='none' or filtering !='strict':
237
+ filter_terms = set(searchterm_list).union(set(str(re.sub(',','',' '.join(searchterm_list))).split()).difference(STOPWORDS))
238
+ '''
239
+ # The above is equivalent to this but uses less memory and may be faster:
240
+ #create a single string of the terms within the searchterm_list
241
+ joined = ' '.join(searchterm_list)
242
+ #remove commas
243
+ comma_gone = re.sub(',','',joined)
244
+ #split the string into list of words and convert list into a Pythonic set
245
+ split = set(comma_gone.split())
246
+ #remove the STOPWORDS from the set of key words
247
+ key_words = split.difference(STOPWORDS)
248
+ #create a new set of the list members in searchterm_list
249
+ search_set = set(searchterm_list)
250
+ #join the two sets
251
+ terms = search_set.union(key_words)
252
+ #if any word(s) in the abstract intersect with any of these terms then the abstract is good to go.
253
+ '''
254
+
255
+ ## get abstracts from EBI PMID API and output a dictionary
256
+ for pmid in pmids:
257
+ abstract = PMID_getAb(pmid)
258
+ if len(abstract)>5:
259
+ #do filtering here
260
+ if filtering == 'strict':
261
+ uncased_ab = abstract.lower()
262
+ for term in searchterm_list:
263
+ if term.lower() in uncased_ab:
264
+ pmid_abs[pmid] = abstract
265
+ break
266
+ elif filtering =='none':
267
+ pmid_abs[pmid] = abstract
268
+
269
+ #Default filtering is 'lenient'.
270
+ else:
271
+ #Else and if are separated for readability and to better understand logical flow.
272
+ if set(filter_terms).intersection(set(word_tokenize(abstract))):
273
+ pmid_abs[pmid] = abstract
274
+
275
+
276
+ print('Found',len(pmids),'PMIDs. Gathered',len(pmid_abs),'Relevant Abstracts.')
277
+
278
+ return pmid_abs
279
+
280
+ # Generate predictions for a PubMed Id
281
+ # nlp: en_core_web_lg
282
+ # nlpSci: en_ner_bc5cdr_md
283
+ # nlpSci2: en_ner_bionlp13cg_md
284
+ # Defaults to load my_model_orphanet_final, the most up-to-date version of the classification model,
285
+ # but can also be run on any other tf.keras model
286
+ #This was originally getPredictions
287
+ def getPMIDPredictions(pmid:Union[str,int], classify_model_vars:Tuple[Any,Any,Any,Any,Any]) -> Tuple[str,float,bool]:
288
+ nlp, nlpSci, nlpSci2, classify_model, classify_tokenizer = classify_model_vars
289
+ abstract = PMID_getAb(pmid)
290
+
291
+ if len(abstract)>5:
292
+ # remove stopwords
293
+ for word in STOPWORDS:
294
+ token = ' ' + word + ' '
295
+ abstract = abstract.replace(token, ' ')
296
+ abstract = abstract.replace(' ', ' ')
297
+
298
+ # preprocess abstract
299
+ abstract_standard = [standardizeAbstract(standardizeSciTerms(abstract, nlpSci, nlpSci2), nlp)]
300
+ sequence = classify_tokenizer.texts_to_sequences(abstract_standard)
301
+ padded = pad_sequences(sequence, maxlen=max_length, padding=padding_type, truncating=trunc_type)
302
+
303
+ y_pred1 = classify_model.predict(padded) # generate prediction
304
+ y_pred = np.argmax(y_pred1, axis=1) # get binary prediction
305
+
306
+ prob = y_pred1[0][1]
307
+ if y_pred == 1:
308
+ isEpi = True
309
+ else:
310
+ isEpi = False
311
+
312
+ return abstract, prob, isEpi
313
+
314
+ else:
315
+ return abstract, 0.0, False
316
+
317
+
318
+ def getTextPredictions(abstract:str, classify_model_vars:Tuple[Any,Any,Any,Any,Any]) -> Tuple[float,bool]:
319
+
320
+ nlp, nlpSci, nlpSci2, classify_model, classify_tokenizer = classify_model_vars
321
+
322
+ if len(abstract)>5:
323
+ # remove stopwords
324
+ for word in STOPWORDS:
325
+ token = ' ' + word + ' '
326
+ abstract = abstract.replace(token, ' ')
327
+ abstract = abstract.replace(' ', ' ')
328
+
329
+ # preprocess abstract
330
+ abstract_standard = [standardizeAbstract(standardizeSciTerms(abstract, nlpSci, nlpSci2), nlp)]
331
+ sequence = classify_tokenizer.texts_to_sequences(abstract_standard)
332
+ padded = pad_sequences(sequence, maxlen=max_length, padding=padding_type, truncating=trunc_type)
333
+
334
+ y_pred1 = classify_model.predict(padded) # generate prediction
335
+ y_pred = np.argmax(y_pred1, axis=1) # get binary prediction
336
+
337
+ prob = y_pred1[0][1]
338
+ if y_pred == 1:
339
+ isEpi = True
340
+ else:
341
+ isEpi = False
342
+
343
+ return prob, isEpi
344
+
345
+ else:
346
+ return 0.0, False
347
+
348
+ if __name__ == '__main__':
349
+ print('Loading 5 NLP models...')
350
+ classify_model_vars= init_classify_model()
351
+ print('All models loaded.')
352
+ pmid = input('\nEnter PubMed PMID (or DONE): ')
353
+ while pmid != 'DONE':
354
+ abstract, prob, isEpi = getPredictions(pmid, classify_model_vars)
355
+ print(abstract, prob, isEpi)
356
+ pmid = input('\nEnter PubMed PMID (or DONE): ')