import streamlit as st
from transformers import pipeline
import requests
from bs4 import BeautifulSoup
from nltk.corpus import stopwords
import nltk
import string
from streamlit.components.v1 import html
from sentence_transformers.cross_encoder import CrossEncoder as CE
import numpy as np
from typing import List, Tuple
import torch


class CrossEncoder:
    """Thin wrapper around sentence-transformers' CrossEncoder."""

    def __init__(self, model_path: str, **kwargs):
        self.model = CE(model_path, **kwargs)

    def predict(self, sentences: List[Tuple[str, str]],
                batch_size: int = 32,
                show_progress_bar: bool = True) -> List[float]:
        return self.model.predict(
            sentences=sentences,
            batch_size=batch_size,
            show_progress_bar=show_progress_bar)


SCITE_API_KEY = st.secrets["SCITE_API_KEY"]


def remove_html(x):
    """Strip HTML tags from a snippet and return plain text."""
    soup = BeautifulSoup(x, 'html.parser')
    return soup.get_text()


def search(term, limit=10, clean=True, strict=True):
    """Query the scite.ai citation search API.

    Returns a tuple of (joined snippet text per hit,
    (doi, citations, title) per hit).
    """
    term = clean_query(term, clean=clean, strict=strict)
    # heuristic: maybe run two searches, strict and not, and then merge?
    search_url = (
        "https://api.scite.ai/search?mode=citations"
        f"&term={term}&limit={limit}&offset=0"
        "&user_slug=domenic-rosati-keW5&compute_aggregations=false"
    )
    req = requests.get(
        search_url,
        headers={'Authorization': f'Bearer {SCITE_API_KEY}'}
    )
    hits = req.json()['hits']  # parse the response once
    return (
        [remove_html('\n'.join(cite['snippet'] for cite in doc['citations']))
         for doc in hits],
        [(doc['doi'], doc['citations'], doc['title']) for doc in hits]
    )


def find_source(text, docs):
    """Trace an extracted answer back to the citation snippet it came from."""
    for doc in docs:
        snippet = remove_html(doc[1][0]['snippet'])
        if text in snippet:
            # widen the answer to the full sentence containing it
            new_text = text
            for snip in snippet.split('.'):
                if text in snip:
                    new_text = snip
            return {
                # strip the highlight markup scite embeds in snippets
                'citation_statement': remove_html(doc[1][0]['snippet']),
                'text': new_text,
                'from': doc[1][0]['source'],
                'supporting': doc[1][0]['target'],
                'source_title': doc[2],
                'source_link': f"https://scite.ai/reports/{doc[0]}"
            }
    return None


@st.experimental_singleton
def init_models():
    nltk.download('stopwords')
    stop = set(stopwords.words('english') + list(string.punctuation))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    question_answerer = pipeline(
        "question-answering",
        model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
        device=device
    )
    reranker = CrossEncoder(
        'cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
    return question_answerer, reranker, stop, device


qa_model, reranker, stop, device = init_models()


def clean_query(query, strict=True, clean=True):
    """Lowercase the query, optionally drop stopwords and punctuation,
    and AND the terms together for a strict boolean search."""
    operator = ' AND ' if strict else ' '
    query = operator.join(
        # only filter stopwords when cleaning is requested
        [i for i in query.lower().split(' ') if not clean or i not in stop])
    if clean:
        query = query.translate(str.maketrans('', '', string.punctuation))
    return query


def card(title, context, score, link, supporting):
    st.markdown(f"""

    <div class="container-fluid">
        <br>
        <span>{context} [<b>Score:</b> {score}]</span>
        <br>
        <b>From <a href="{link}">{title}</a></b>
    </div>
""", unsafe_allow_html=True) html(f"""
""", width=None, height=42, scrolling=False) st.title("Scientific Question Answering with Citations") st.write(""" Ask a scientific question and get an answer drawn from [scite.ai](https://scite.ai) corpus of over 1.1bn citation statements. Answers are linked to source documents containing citations where users can explore further evidence from scientific literature for the answer. """) st.markdown(""" """, unsafe_allow_html=True) def run_query(query): if device == 'cpu': limit = 50 context_limit = 10 else: limit = 100 context_limit = 25 contexts, orig_docs = search(query, limit=limit) if len(contexts) == 0 or not ''.join(contexts).strip(): return st.markdown("""
Sorry... no results for that question! Try another...
""", unsafe_allow_html=True) sentence_pairs = [[query, context] for context in contexts] scores = reranker.predict(sentence_pairs, batch_size=limit, show_progress_bar=False) hits = {contexts[idx]: scores[idx] for idx in range(len(scores))} sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)] context = '\n'.join(sorted_contexts[:context_limit]) results = [] model_results = qa_model(question=query, context=context, top_k=10) for result in model_results: support = find_source(result['answer'], orig_docs) if not support: continue results.append({ "answer": support['text'], "title": support['source_title'], "link": support['source_link'], "context": support['citation_statement'], "score": result['score'], "doi": support["supporting"] }) sorted_result = sorted(results, key=lambda x: x['score'], reverse=True) sorted_result = list({ result['context']: result for result in sorted_result }.values()) sorted_result = sorted( sorted_result, key=lambda x: x['score'], reverse=True) for r in sorted_result: answer = r["answer"] ctx = remove_html(r["context"]).replace(answer, f"{answer}").replace( '