import streamlit as st
from transformers import pipeline, AutoTokenizer, LEDForConditionalGeneration
import requests
from bs4 import BeautifulSoup
import nltk
import string
from streamlit.components.v1 import html
from sentence_transformers.cross_encoder import CrossEncoder as CE
import re
from typing import List, Tuple
from urllib.parse import quote
import torch

SCITE_API_KEY = st.secrets["SCITE_API_KEY"]

# class CrossEncoder:
#     def __init__(self, model_path: str, **kwargs):
#         self.model = CE(model_path, **kwargs)
#     def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 32, show_progress_bar: bool = True) -> List[float]:
#         return self.model.predict(
#             sentences=sentences,
#             batch_size=batch_size,
#             show_progress_bar=show_progress_bar)


def remove_html(x):
    """Return the text content of the HTML string *x*, stripped of surrounding whitespace."""
    soup = BeautifulSoup(x, 'html.parser')
    text = soup.get_text()
    return text.strip()


# 4 searches: strict y/n, supported y/n
# deduplicate
# search per query

# options are abstract search
# all search


def _query_scite(url, context_of):
    """Run one scite search request and turn its hits into (contexts, docs).

    *context_of* maps one hit dict to the (possibly HTML) text used as its
    context. A body that is not JSON, has no ``hits`` key, or contains hits
    missing an expected field yields ``([], [])`` so the two lists returned
    always have the same length. Network errors raised by ``requests.get``
    are not caught here.
    """
    response = requests.get(
        url,
        headers={
            'Authorization': f'Bearer {SCITE_API_KEY}'
        }
    )
    try:
        hits = response.json()['hits']
        new_contexts = [remove_html(context_of(doc)) for doc in hits]
        new_docs = [
            (doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
            for doc in hits
        ]
    except (ValueError, KeyError, TypeError):
        # ValueError covers JSON decoding failures; KeyError/TypeError cover
        # a payload that does not have the expected shape.
        return [], []
    return new_contexts, new_docs


def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=True, abstract_only=False):
    """Search the scite API and return ``(contexts, docs)``.

    Args:
        term: Raw user query; it is normalised with ``clean_query`` first.
        limit: Value sent as the ``limit`` query parameter of each request.
        clean: Forwarded to ``clean_query``.
        strict: Forwarded to ``clean_query``.
        all_mode: Use ``mode=all`` for the citation search; otherwise
            ``mode=citations``.
        abstracts: Also run the ``mode=papers`` abstract search.
        abstract_only: Skip the citation search and run only the abstract
            search.

    Returns:
        A tuple ``(contexts, docs)``. ``contexts`` is a list of plain-text
        strings (joined English citation snippets, or abstracts) and ``docs``
        is a list of ``(doi, citations, title, abstract)`` tuples, one per hit.
    """
    term = clean_query(term, clean=clean, strict=strict)
    # The term is percent-encoded so characters such as '&' or '#' cannot
    # break the query string when punctuation was not stripped.
    encoded_term = quote(term)
    # heuristic, 2 searches strict and not? and then merge?
    # https://api.scite.ai/search?mode=all&term=unit%20testing%20software&limit=10&date_from=2000&date_to=2022&offset=0&supporting_from=1&contrasting_from=0&contrasting_to=0&user_slug=domenic-rosati-keW5&compute_aggregations=true
    contexts, docs = [], []

    if not abstract_only:
        mode = 'all' if all_mode else 'citations'
        url = f"https://api.scite.ai/search?mode={mode}&term={encoded_term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
        new_contexts, new_docs = _query_scite(
            url,
            lambda doc: '\n'.join(
                cite['snippet'] for cite in doc['citations'] if cite['lang'] == 'en'
            ),
        )
        contexts += new_contexts
        docs += new_docs

    if abstracts or abstract_only:
        url = f"https://api.scite.ai/search?mode=papers&abstract={encoded_term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
        new_contexts, new_docs = _query_scite(url, lambda doc: doc['abstract'] or '')
        contexts += new_contexts
        docs += new_docs

    return (
        contexts,
        docs
    )


def find_source(text, docs, matched):
    """Locate the document that *text* was extracted from.

    Args:
        text: Answer string to look for.
        docs: ``(doi, citations, title, abstract)`` tuples as built by
            ``search``.
        matched: When truthy, a candidate snippet or abstract is only accepted
            if its HTML-stripped text equals ``matched`` (ignoring surrounding
            whitespace).

    Returns:
        A dict with keys ``citation_statement``, ``text``, ``from``,
        ``supporting``, ``source_title`` and ``source_link`` for the first
        match, or ``None`` when no document contains *text*. Citation snippets
        of a document are checked before its abstract.
    """
    for doc in docs:
        for snippet in doc[1]:
            snippet_text = remove_html(snippet.get('snippet', ''))
            if text not in snippet_text:
                continue
            if matched and snippet_text.strip() != matched.strip():
                continue
            # Prefer the full sentence containing the answer (last one wins).
            new_text = text
            for sent in nltk.sent_tokenize(snippet_text):
                if text in sent:
                    new_text = sent
            return {
                # NOTE(review): both replace() calls have an empty search
                # string in the visible source and are therefore no-ops;
                # whatever markup they were meant to remove is not visible
                # here -- confirm against the upstream file.
                'citation_statement': snippet['snippet'].replace('', '').replace('', ''),
                'text': new_text,
                'from': snippet['source'],
                'supporting': snippet['target'],
                'source_title': remove_html(doc[2] or ''),
                'source_link': f"https://scite.ai/reports/{doc[0]}"
            }

        abstract_text = remove_html(doc[3])
        if text in abstract_text:
            if matched and abstract_text.strip() != matched.strip():
                continue
            new_text = text
            sent_loc = None
            sents = nltk.sent_tokenize(abstract_text)
            for i, sent in enumerate(sents):
                if text in sent:
                    new_text = sent
                    sent_loc = i

            # NOTE(review): empty-string replace() calls, no-ops as written
            # (see the note on 'citation_statement' above).
            context = abstract_text.replace('', '').replace('', '')
            # `is not None` rather than truthiness: index 0 (answer in the
            # first sentence) is a valid location and must be windowed too.
            if sent_loc is not None:
                context_len = 3
                sent_beg = max(sent_loc - context_len, 0)
                sent_end = min(sent_loc + context_len, len(sents))
                context = ''.join(sents[sent_beg:sent_end])

            return {
                'citation_statement': context,
                'text': new_text,
                'from': doc[0],
                'supporting': doc[0],
                'source_title': remove_html(doc[2] or ''),
                'source_link': f"https://scite.ai/reports/{doc[0]}"
            }
    return None


# @st.experimental_singleton
# def init_models():
#     nltk.download('stopwords')
#     nltk.download('punkt')
#     from nltk.corpus import stopwords
#     stop = set(stopwords.words('english') + list(string.punctuation))
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     question_answerer = pipeline(
#         "question-answering", model='nlpconnect/roberta-base-squad2-nq',
#         device=0 if torch.cuda.is_available() else -1, handle_impossible_answer=False,
#     )
#     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
#     # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
#     # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
#     return question_answerer, reranker, stop, device

# qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer


def clean_query(query, strict=True, clean=True):
    """Normalise a free-text query into a search term.

    The query is lower-cased and split on whitespace. With ``clean`` the
    tokens found in ``stop`` are dropped and punctuation is removed from the
    result. Tokens are joined with ``' AND '`` when ``strict`` is true and
    with a single space otherwise.

    NOTE(review): ``stop`` is a module-level name whose only assignment in the
    visible source is inside the commented-out ``init_models``; it must be
    provided elsewhere before this is called with ``clean=True`` -- confirm.
    """
    operator = ' AND ' if strict else ' '
    # split() (not split(' ')) so repeated spaces do not produce empty terms.
    # With clean=False every token is kept and `stop` is never consulted.
    tokens = [tok for tok in query.lower().split() if not clean or tok not in stop]
    query = operator.join(tokens)
    if clean:
        query = query.translate(str.maketrans('', '', string.punctuation))
    return query


def card(title, context, score, link, supporting):
    """Render one answer in the Streamlit page.

    Writes ``context``, ``score`` and ``title`` through ``st.markdown`` and
    then emits a fixed-height ``html`` component.

    NOTE(review): ``link`` and ``supporting`` are not referenced by the
    template text visible here, and the ``html`` payload is empty as it
    appears in this source -- confirm against the upstream file.
    """
    st.markdown(f"""

{context} [Confidence: {score}%]
From {title}
""", unsafe_allow_html=True)
    html(f"""
""", width=None, height=42, scrolling=False)


st.title("Scientific Question Answering with Citations")

st.write(""" Ask a scientific question and get an answer drawn from [scite.ai](https://scite.ai) corpus of over 1.1bn citation statements. Answers are linked to source documents containing citations where users can explore further evidence from scientific literature for the answer. For example try: Do tanning beds cause cancer? """)

st.markdown(""" """, unsafe_allow_html=True)

# with st.expander("Settings (strictness, context limit, top hits)"):
#     concat_passages = st.radio(
#         "Concatenate passages as one long context?",
#         ('yes', 'no'))
#     present_impossible = st.radio(
#         "Present impossible answers? (if the model thinks its impossible to answer should it still try?)",
#         ('yes', 'no'))
#     support_all = st.radio(
#         "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
#         ('no', 'yes'))
#     support_abstracts = st.radio(
#         "Use abstracts as a source document?",
#         ('yes', 'no', 'abstract only'))
#     strict_lenient_mix = st.radio(
#         "Type of strict+lenient combination: Fallback or Mix? If fallback, strict is run first then if the results are less than context_lim we also search lenient. Mix will search them both and let reranking sort em out",
#         ('mix', 'fallback'))
#     confidence_threshold = st.slider('Confidence threshold for answering questions? This number represents how confident the model should be in the answers it gives. The number is out of 100%', 0, 100, 1)
#     use_reranking = st.radio(
#         "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
#         ('yes', 'no'))
#     top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100)
#     context_lim = st.slider('Context limit? How many documents to use for answering from.
# Larger is slower but higher quality', 10, 300, 25)

# def paraphrase(text, max_length=128):
#     input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
#     generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=suggested_queries or 5, num_beams=suggested_queries or 5, max_length=max_length)
#     queries = set([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids])
#     preds = '\n * '.join(queries)
#     return preds


def group_results_by_context(results):
    """Merge results that share the same ``context`` into one entry each.

    Args:
        results: Iterable of dicts with at least the keys ``context``,
            ``answer`` and ``score``.

    Returns:
        A list with one dict per distinct context, in order of first
        appearance. Each dict is a copy of the first result seen for that
        context, extended with ``texts`` (every ``answer`` for the context, in
        input order) and with ``score`` raised to the highest score in the
        group. The input dicts are left unmodified.
    """
    groups = {}
    for result in results:
        key = result['context']
        group = groups.get(key)
        if group is None:
            # Copy so that adding 'texts' / raising 'score' does not leak
            # into the caller's objects.
            group = groups[key] = {**result, 'texts': []}
        group['texts'].append(result['answer'])
        if result['score'] > group['score']:
            group['score'] = result['score']
    return list(groups.values())


def matched_context(start_i, end_i, contexts_string, seperator='---'):
    """Return the passage of *contexts_string* that contains ``[start_i, end_i]``.

    *contexts_string* is a sequence of passages delimited by *seperator*. A
    passage starts at offset 0 or right after a separator and runs up to the
    end of the next separator (or the end of the string for the last one).

    Args:
        start_i: Start offset of the span inside *contexts_string*.
        end_i: End offset of the span inside *contexts_string*.
        contexts_string: The concatenated passages.
        seperator: Literal delimiter between passages.

    Returns:
        The enclosing passage with the separator text removed, or ``None``
        when the span does not fit inside a single passage (for example when
        it straddles a separator).
    """
    # find seperators to identify start and end
    # re.escape: the separator is a literal string (it is also passed to
    # str.replace below), so regex metacharacters in it must not be
    # interpreted as a pattern.
    doc_starts = [0]
    doc_starts.extend(
        match.end() for match in re.finditer(re.escape(seperator), contexts_string)
    )

    last = len(doc_starts) - 1
    for i, doc_start in enumerate(doc_starts):
        if start_i < doc_start:
            continue
        if i == last:
            # Final passage has no closing separator: it runs to the end.
            return contexts_string[doc_start:].replace(seperator, '')
        if end_i <= doc_starts[i + 1]:
            return contexts_string[doc_start:doc_starts[i + 1]].replace(seperator, '')
    return None


# def run_query_full(query, progress_bar):
#     # if use_query_exp == 'yes':
#     #     query_exp = paraphrase(f"question2question: {query}")
#     #     st.markdown(f"""
#     #     If you are not getting good results try one of:
#     #     * {query_exp}
#     #     """)
#     # could also try fallback if there are no good answers by score...
# limit = top_hits_limit or 100 # context_limit = context_lim or 10 # contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only=support_abstracts == 'abstract only') # if strict_lenient_mix == 'fallback' and len(contexts_strict) < context_limit: # contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only= support_abstracts == 'abstract only') # contexts = list( # set(contexts_strict + contexts_lenient) # ) # orig_docs = orig_docs_strict + orig_docs_lenient # elif strict_lenient_mix == 'mix': # contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False) # contexts = list( # set(contexts_strict + contexts_lenient) # ) # orig_docs = orig_docs_strict + orig_docs_lenient # else: # contexts = list( # set(contexts_strict) # ) # orig_docs = orig_docs_strict # progress_bar.progress(25) # if len(contexts) == 0 or not ''.join(contexts).strip(): # return st.markdown(""" #
#
#
# Sorry... no results for that question! Try another... #
#
#
# """, unsafe_allow_html=True) # if use_reranking == 'yes': # sentence_pairs = [[query, context] for context in contexts] # scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False) # hits = {contexts[idx]: scores[idx] for idx in range(len(scores))} # sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)] # contexts = sorted_contexts[:context_limit] # else: # contexts = contexts[:context_limit] # progress_bar.progress(50) # if concat_passages == 'yes': # context = '\n---'.join(contexts) # model_results = qa_model(question=query, context=context, top_k=10, doc_stride=512 // 2, max_answer_len=128, max_seq_len=512, handle_impossible_answer=present_impossible=='yes') # else: # context = ['\n---\n'+ctx for ctx in contexts] # model_results = qa_model(question=[query]*len(contexts), context=context, handle_impossible_answer=present_impossible=='yes') # results = [] # progress_bar.progress(75) # for i, result in enumerate(model_results): # if concat_passages == 'yes': # matched = matched_context(result['start'], result['end'], context) # else: # matched = matched_context(result['start'], result['end'], context[i]) # support = find_source(result['answer'], orig_docs, matched) # if not support: # continue # results.append({ # "answer": support['text'], # "title": support['source_title'], # "link": support['source_link'], # "context": support['citation_statement'], # "score": result['score'], # "doi": support["supporting"] # }) # grouped_results = group_results_by_context(results) # sorted_result = sorted(grouped_results, key=lambda x: x['score'], reverse=True) # if confidence_threshold == 0: # threshold = 0 # else: # threshold = (confidence_threshold or 10) / 100 # sorted_result = list(filter( # lambda x: x['score'] > threshold, # sorted_result # )) # progress_bar.progress(100) # for r in sorted_result: # ctx = remove_html(r["context"]) # for answer in r['texts']: # ctx = ctx.replace(answer.strip(), 
f"{answer.strip()}") # # .replace( '
Sorry... no results for that question! Try another...
""", unsafe_allow_html=True) for r in resp['results']: ctx = remove_html(r["context"]) for answer in r['texts']: ctx = ctx.replace(answer.strip(), f"{answer.strip()}") # .replace( '