Spaces:
Sleeping
Sleeping
yinong333
commited on
Commit
•
7a8a241
0
Parent(s):
demo day code updates
Browse files- .chainlit/config.toml +84 -0
- Dockerfile +11 -0
- README.md +2 -0
- __pycache__/app.cpython-311.pyc +0 -0
- app.py +434 -0
- chainlit.md +14 -0
- prototype_mvp3.ipynb +0 -0
- requirements.txt +106 -0
.chainlit/config.toml
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
# Whether to enable telemetry (default: true). No personal data is collected.
|
3 |
+
enable_telemetry = true
|
4 |
+
|
5 |
+
# List of environment variables to be provided by each user to use the app.
|
6 |
+
user_env = []
|
7 |
+
|
8 |
+
# Duration (in seconds) during which the session is saved when the connection is lost
|
9 |
+
session_timeout = 3600
|
10 |
+
|
11 |
+
# Enable third parties caching (e.g LangChain cache)
|
12 |
+
cache = false
|
13 |
+
|
14 |
+
# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
|
15 |
+
# follow_symlink = false
|
16 |
+
|
17 |
+
[features]
|
18 |
+
# Show the prompt playground
|
19 |
+
prompt_playground = true
|
20 |
+
|
21 |
+
# Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
|
22 |
+
unsafe_allow_html = false
|
23 |
+
|
24 |
+
# Process and display mathematical expressions. This can clash with "$" characters in messages.
|
25 |
+
latex = false
|
26 |
+
|
27 |
+
# Authorize users to upload files with messages
|
28 |
+
multi_modal = true
|
29 |
+
|
30 |
+
# Allows user to use speech to text
|
31 |
+
[features.speech_to_text]
|
32 |
+
enabled = false
|
33 |
+
# See all languages here https://github.com/JamesBrill/react-speech-recognition/blob/HEAD/docs/API.md#language-string
|
34 |
+
# language = "en-US"
|
35 |
+
|
36 |
+
[UI]
|
37 |
+
# Name of the app and chatbot.
|
38 |
+
name = "Chatbot"
|
39 |
+
|
40 |
+
# Show the readme while the conversation is empty.
|
41 |
+
show_readme_as_default = true
|
42 |
+
|
43 |
+
# Description of the app and chatbot. This is used for HTML tags.
|
44 |
+
# description = ""
|
45 |
+
|
46 |
+
# Large size content are by default collapsed for a cleaner ui
|
47 |
+
default_collapse_content = true
|
48 |
+
|
49 |
+
# The default value for the expand messages settings.
|
50 |
+
default_expand_messages = false
|
51 |
+
|
52 |
+
# Hide the chain of thought details from the user in the UI.
|
53 |
+
hide_cot = false
|
54 |
+
|
55 |
+
# Link to your github repo. This will add a github button in the UI's header.
|
56 |
+
# github = ""
|
57 |
+
|
58 |
+
# Specify a CSS file that can be used to customize the user interface.
|
59 |
+
# The CSS file can be served from the public directory or via an external link.
|
60 |
+
# custom_css = "/public/test.css"
|
61 |
+
|
62 |
+
# Override default MUI light theme. (Check theme.ts)
|
63 |
+
[UI.theme.light]
|
64 |
+
#background = "#FAFAFA"
|
65 |
+
#paper = "#FFFFFF"
|
66 |
+
|
67 |
+
[UI.theme.light.primary]
|
68 |
+
#main = "#F80061"
|
69 |
+
#dark = "#980039"
|
70 |
+
#light = "#FFE7EB"
|
71 |
+
|
72 |
+
# Override default MUI dark theme. (Check theme.ts)
|
73 |
+
[UI.theme.dark]
|
74 |
+
#background = "#FAFAFA"
|
75 |
+
#paper = "#FFFFFF"
|
76 |
+
|
77 |
+
[UI.theme.dark.primary]
|
78 |
+
#main = "#F80061"
|
79 |
+
#dark = "#980039"
|
80 |
+
#light = "#FFE7EB"
|
81 |
+
|
82 |
+
|
83 |
+
[meta]
|
84 |
+
generated_by = "0.7.700"
|
Dockerfile
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.9
|
2 |
+
RUN useradd -m -u 1000 user
|
3 |
+
USER user
|
4 |
+
ENV HOME=/home/user \
|
5 |
+
PATH=/home/user/.local/bin:$PATH
|
6 |
+
WORKDIR $HOME/app
|
7 |
+
COPY --chown=user . $HOME/app
|
8 |
+
COPY ./requirements.txt ~/app/requirements.txt
|
9 |
+
RUN pip install -r requirements.txt
|
10 |
+
COPY . .
|
11 |
+
CMD ["chainlit", "run", "app.py", "--port", "7860"]
|
README.md
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# AIE4 Final Demo Day App
|
2 |
+
Literature Review App
|
__pycache__/app.cpython-311.pyc
ADDED
Binary file (18.8 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chainlit as cl
|
2 |
+
from Bio import Entrez
|
3 |
+
from langchain.tools import StructuredTool
|
4 |
+
from langchain_openai import ChatOpenAI
|
5 |
+
from pydantic import BaseModel
|
6 |
+
from langgraph.graph.message import add_messages
|
7 |
+
from langgraph.prebuilt import ToolNode
|
8 |
+
from langgraph.graph import StateGraph, END
|
9 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
10 |
+
|
11 |
+
from IPython.display import display, Markdown
|
12 |
+
from sentence_transformers import SentenceTransformer, util
|
13 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
14 |
+
from langchain.tools import StructuredTool
|
15 |
+
from langchain.agents import initialize_agent, Tool, AgentType
|
16 |
+
from langchain_openai import ChatOpenAI
|
17 |
+
from langgraph.graph.message import add_messages
|
18 |
+
from typing import List, TypedDict, Annotated
|
19 |
+
import xml.etree.ElementTree as ET
|
20 |
+
import uuid
|
21 |
+
import re
|
22 |
+
from langchain_qdrant import QdrantVectorStore
|
23 |
+
from qdrant_client import QdrantClient
|
24 |
+
from qdrant_client.http.models import Distance, VectorParams
|
25 |
+
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
|
26 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
27 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
28 |
+
from langchain.chains import (
|
29 |
+
ConversationalRetrievalChain,
|
30 |
+
)
|
31 |
+
from langchain.docstore.document import Document
|
32 |
+
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
|
33 |
+
from transformers import GPT2Tokenizer
|
34 |
+
|
35 |
+
# Load the pre-trained model for embeddings (you can choose a different model if preferred)
|
36 |
+
semantic_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
|
37 |
+
|
38 |
+
def pretty_print(message: str) -> None:
|
39 |
+
display(Markdown(f"```markdown\n{message}\n```"))
|
40 |
+
|
41 |
+
# Set your Entrez email for PubMed queries
|
42 |
+
Entrez.email = "[email protected]"
|
43 |
+
|
44 |
+
# 1. Define PubMed Search Tool
|
45 |
+
class PubMedSearchInput(BaseModel):
|
46 |
+
query: str
|
47 |
+
#max_results: int = 5
|
48 |
+
|
49 |
+
# PubMed search tool using Entrez (now with structured inputs)
|
50 |
+
def pubmed_search(query: str, max_results: int = 3):
|
51 |
+
"""Search PubMed using Entrez API and return abstracts."""
|
52 |
+
handle = Entrez.esearch(db="pubmed", term=query, retmax=max_results)
|
53 |
+
record = Entrez.read(handle)
|
54 |
+
handle.close()
|
55 |
+
pmids = record["IdList"]
|
56 |
+
|
57 |
+
# Fetch abstracts
|
58 |
+
handle = Entrez.efetch(db="pubmed", id=",".join(pmids), retmode="xml")
|
59 |
+
records = Entrez.read(handle)
|
60 |
+
handle.close()
|
61 |
+
|
62 |
+
abstracts = []
|
63 |
+
for record in records['PubmedArticle']:
|
64 |
+
try:
|
65 |
+
title = record['MedlineCitation']['Article']['ArticleTitle']
|
66 |
+
abstract = record['MedlineCitation']['Article']['Abstract']['AbstractText'][0]
|
67 |
+
pmid = record['MedlineCitation']['PMID']
|
68 |
+
abstracts.append({"PMID": pmid, "Title": title, "Abstract": abstract})
|
69 |
+
except KeyError:
|
70 |
+
pass
|
71 |
+
return abstracts
|
72 |
+
|
73 |
+
# Define the AbstractScreeningInput using Pydantic BaseModel
|
74 |
+
class AbstractScreeningInput(BaseModel):
|
75 |
+
abstracts: List[dict]
|
76 |
+
criteria: str
|
77 |
+
|
78 |
+
def screen_abstracts_semantic(abstracts: List[dict], criteria: str, similarity_threshold: float = 0.4):
|
79 |
+
"""Screen abstracts based on semantic similarity to the criteria."""
|
80 |
+
|
81 |
+
# Compute the embedding of the criteria
|
82 |
+
criteria_embedding = semantic_model.encode(criteria, convert_to_tensor=True)
|
83 |
+
|
84 |
+
screened = []
|
85 |
+
for paper in abstracts:
|
86 |
+
abstract_text = paper['Abstract']
|
87 |
+
|
88 |
+
# Compute the embedding of the abstract
|
89 |
+
abstract_embedding = semantic_model.encode(abstract_text, convert_to_tensor=True)
|
90 |
+
|
91 |
+
# Compute cosine similarity between the abstract and the criteria
|
92 |
+
similarity_score = util.cos_sim(abstract_embedding, criteria_embedding).item()
|
93 |
+
|
94 |
+
if similarity_score >= similarity_threshold:
|
95 |
+
screened.append({
|
96 |
+
"PMID": paper['PMID'],
|
97 |
+
"Decision": "Include",
|
98 |
+
"Reason": f"Similarity score {similarity_score:.2f} >= threshold {similarity_threshold}"
|
99 |
+
})
|
100 |
+
else:
|
101 |
+
screened.append({
|
102 |
+
"PMID": paper['PMID'],
|
103 |
+
"Decision": "Exclude",
|
104 |
+
"Reason": f"Similarity score {similarity_score:.2f} < threshold {similarity_threshold}"
|
105 |
+
})
|
106 |
+
|
107 |
+
return screened
|
108 |
+
|
109 |
+
# Define the PubMed Search Tool as a StructuredTool with proper input schema
|
110 |
+
pubmed_tool = StructuredTool(
|
111 |
+
name="PubMed_Search_Tool",
|
112 |
+
func=pubmed_search,
|
113 |
+
description="Search PubMed for research papers and retrieve abstracts. Pass the abstracts (returned results) to another tool.",
|
114 |
+
args_schema=PubMedSearchInput # Use Pydantic BaseModel for schema
|
115 |
+
)
|
116 |
+
|
117 |
+
# Define the Abstract Screening Tool with semantic screening
|
118 |
+
semantic_screening_tool = StructuredTool(
|
119 |
+
name="Semantic_Abstract_Screening_Tool",
|
120 |
+
func=screen_abstracts_semantic,
|
121 |
+
description="""Screen PubMed abstracts based on semantic similarity to inclusion/exclusion criteria. Uses cosine similarity between abstracts and criteria. Requires 'abstracts' and 'screening criteria' as input.
|
122 |
+
The 'abstracts' is a list of dictionary with keys as PMID, Title, Abstract.
|
123 |
+
Output a similarity scores for each abstract and send the list of pmids that passed the screening to Fetch_Extract_Tool.""",
|
124 |
+
args_schema=AbstractScreeningInput # Pydantic schema remains the same
|
125 |
+
)
|
126 |
+
|
127 |
+
# 3. Define Full-Text Retrieval Tool
|
128 |
+
class FetchExtractInput(BaseModel):
|
129 |
+
pmids: List[str] # List of PubMed IDs to fetch full text for
|
130 |
+
query: str
|
131 |
+
|
132 |
+
def extract_text_from_pmc_xml(xml_content: str) -> str:
|
133 |
+
"""a function to format and clean text from PMC full-text XML."""
|
134 |
+
try:
|
135 |
+
root = ET.fromstring(xml_content)
|
136 |
+
|
137 |
+
# Find all relevant text sections (e.g., <body>, <sec>, <p>)
|
138 |
+
body_text = []
|
139 |
+
for elem in root.iter():
|
140 |
+
if elem.tag in ['p', 'sec', 'title', 'abstract', 'body']: # Add more tags as needed
|
141 |
+
if elem.text:
|
142 |
+
body_text.append(elem.text.strip())
|
143 |
+
|
144 |
+
# Join all the text elements to form the complete full text
|
145 |
+
full_text = "\n\n".join(body_text)
|
146 |
+
|
147 |
+
return full_text
|
148 |
+
except ET.ParseError:
|
149 |
+
print("Error parsing XML content.")
|
150 |
+
return ""
|
151 |
+
|
152 |
+
def fetch_and_extract(pmids: List[str], query: str):
|
153 |
+
"""Fetch full text from PubMed Central for given PMIDs, split into chunks,
|
154 |
+
store in a Qdrant vector database, and perform RAG for each paper.
|
155 |
+
Retrieves exactly 3 chunks per paper (if available) and generates a consolidated answer for each paper.
|
156 |
+
"""
|
157 |
+
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
158 |
+
corpus = {}
|
159 |
+
consolidated_results={}
|
160 |
+
|
161 |
+
# Fetch records from PubMed Central (PMC)
|
162 |
+
handle = Entrez.efetch(db="pubmed", id=",".join(pmids), retmode="xml")
|
163 |
+
records = Entrez.read(handle)
|
164 |
+
handle.close()
|
165 |
+
|
166 |
+
full_articles = []
|
167 |
+
for record in records['PubmedArticle']:
|
168 |
+
try:
|
169 |
+
title = record['MedlineCitation']['Article']['ArticleTitle']
|
170 |
+
pmid = record['MedlineCitation']['PMID']
|
171 |
+
pmc_id = 'nan'
|
172 |
+
pmc_id_temp = record['PubmedData']['ArticleIdList']
|
173 |
+
|
174 |
+
# Extract PMC ID if available
|
175 |
+
for ele in pmc_id_temp:
|
176 |
+
if ele.attributes['IdType'] == 'pmc':
|
177 |
+
pmc_id = ele.replace('PMC', '')
|
178 |
+
break
|
179 |
+
|
180 |
+
# Fetch full article from PMC
|
181 |
+
if pmc_id != 'nan':
|
182 |
+
handle = Entrez.efetch(db="pmc", id=pmc_id, rettype="full", retmode="xml")
|
183 |
+
full_article = handle.read()
|
184 |
+
handle.close()
|
185 |
+
|
186 |
+
# Split the full article into chunks
|
187 |
+
cleaned_full_article = extract_text_from_pmc_xml(full_article)
|
188 |
+
full_articles.append({
|
189 |
+
"PMID": pmid,
|
190 |
+
"Title": title,
|
191 |
+
"FullText": cleaned_full_article # Add chunked text
|
192 |
+
})
|
193 |
+
else:
|
194 |
+
full_articles.append({"PMID": pmid, "Title": title, "FullText": "cannot fetch"})
|
195 |
+
except KeyError:
|
196 |
+
pass
|
197 |
+
|
198 |
+
# Create corpus for each chunk
|
199 |
+
for article in full_articles:
|
200 |
+
article_id = str(uuid.uuid4())
|
201 |
+
corpus[article_id] = {
|
202 |
+
"page_content": article["FullText"],
|
203 |
+
"metadata": {
|
204 |
+
"PMID": article["PMID"],
|
205 |
+
"Title": article["Title"]
|
206 |
+
}
|
207 |
+
}
|
208 |
+
|
209 |
+
documents = [
|
210 |
+
Document(page_content=content["page_content"], metadata=content["metadata"])
|
211 |
+
for content in corpus.values()
|
212 |
+
]
|
213 |
+
CHUNK_SIZE = 1000
|
214 |
+
CHUNK_OVERLAP = 200
|
215 |
+
|
216 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
217 |
+
chunk_size=CHUNK_SIZE,
|
218 |
+
chunk_overlap=CHUNK_OVERLAP,
|
219 |
+
length_function=len,
|
220 |
+
)
|
221 |
+
|
222 |
+
split_chunks = text_splitter.split_documents(documents)
|
223 |
+
|
224 |
+
id_set = set()
|
225 |
+
for document in split_chunks:
|
226 |
+
id = str(uuid.uuid4())
|
227 |
+
while id in id_set:
|
228 |
+
id = uuid.uuid4()
|
229 |
+
id_set.add(id)
|
230 |
+
document.metadata["uuid"] = id
|
231 |
+
|
232 |
+
LOCATION = ":memory:"
|
233 |
+
COLLECTION_NAME = "pmd_data"
|
234 |
+
VECTOR_SIZE = 384
|
235 |
+
|
236 |
+
# Initialize Qdrant client
|
237 |
+
qdrant_client = QdrantClient(location=LOCATION)
|
238 |
+
|
239 |
+
# Create a collection in Qdrant
|
240 |
+
qdrant_client.create_collection(
|
241 |
+
collection_name=COLLECTION_NAME,
|
242 |
+
vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
|
243 |
+
)
|
244 |
+
|
245 |
+
# Initialize the Qdrant vector store without the embedding argument
|
246 |
+
vdb = QdrantVectorStore(
|
247 |
+
client=qdrant_client,
|
248 |
+
collection_name=COLLECTION_NAME,
|
249 |
+
embedding=embedding_model,
|
250 |
+
)
|
251 |
+
|
252 |
+
# Add embedded documents to Qdrant
|
253 |
+
vdb.add_documents(split_chunks)
|
254 |
+
|
255 |
+
# Query for each paper and consolidate answers
|
256 |
+
for pmid in pmids:
|
257 |
+
# Correctly structure the filter using Qdrant Filter model
|
258 |
+
qdrant_filter = Filter(
|
259 |
+
must=[
|
260 |
+
FieldCondition(key="metadata.PMID", match=MatchValue(value=pmid))
|
261 |
+
]
|
262 |
+
)
|
263 |
+
|
264 |
+
# Custom filtering for the retriever to only fetch chunks related to the current PMID
|
265 |
+
retriever_with_filter = vdb.as_retriever(
|
266 |
+
search_kwargs={
|
267 |
+
"filter": qdrant_filter, # Correctly passing the Qdrant filter
|
268 |
+
"k": 3 # Retrieve 3 chunks per PMID
|
269 |
+
}
|
270 |
+
)
|
271 |
+
|
272 |
+
# Reset message history and memory for each query to avoid interference
|
273 |
+
message_history = ChatMessageHistory()
|
274 |
+
memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", chat_memory=message_history, return_messages=True)
|
275 |
+
|
276 |
+
# Create the ConversationalRetrievalChain with the filtered retriever
|
277 |
+
qa_chain = ConversationalRetrievalChain.from_llm(
|
278 |
+
ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
|
279 |
+
retriever=retriever_with_filter,
|
280 |
+
memory=memory,
|
281 |
+
return_source_documents=True
|
282 |
+
)
|
283 |
+
|
284 |
+
# Query the vector store for relevant documents and extract information
|
285 |
+
result = qa_chain({"question": query})
|
286 |
+
|
287 |
+
# Generate the final answer based on the retrieved chunks
|
288 |
+
generated_answer = result["answer"] # This contains the LLM's generated answer based on the retrieved chunks
|
289 |
+
generated_source = result["source_documents"]
|
290 |
+
|
291 |
+
# Consolidate the results for each paper
|
292 |
+
paper_info = {
|
293 |
+
"PMID": pmid,
|
294 |
+
"Title": result["source_documents"][0].metadata["Title"] if result["source_documents"] else "Unknown Title",
|
295 |
+
"Generated Answer": generated_answer, # Store the generated answer,
|
296 |
+
"Sources": generated_source
|
297 |
+
}
|
298 |
+
|
299 |
+
consolidated_results[pmid] = paper_info
|
300 |
+
|
301 |
+
# Return consolidated results for all papers
|
302 |
+
return consolidated_results
|
303 |
+
|
304 |
+
rag_tool = StructuredTool(
|
305 |
+
name="Fetch_Extract_Tool",
|
306 |
+
func=fetch_and_extract,
|
307 |
+
description="""Fetch full-text articles based on PMIDs and store them in a Qdrant vector database.
|
308 |
+
Then extract information based on user's query via Qdrant retriever using a RAG pipeline.
|
309 |
+
Requires list of PMIDs and user query as input.""",
|
310 |
+
args_schema=FetchExtractInput
|
311 |
+
)
|
312 |
+
|
313 |
+
|
314 |
+
tool_belt = [
|
315 |
+
pubmed_tool,
|
316 |
+
semantic_screening_tool,
|
317 |
+
rag_tool
|
318 |
+
]
|
319 |
+
|
320 |
+
|
321 |
+
# Model setup with tools bound
|
322 |
+
model = ChatOpenAI(model="gpt-4o", temperature=0)
|
323 |
+
model = model.bind_tools(tool_belt)
|
324 |
+
|
325 |
+
# Agent state to handle the messages
|
326 |
+
class AgentState(dict):
|
327 |
+
messages: Annotated[list, add_messages]
|
328 |
+
cycle_count: int # Add a counter to track the number of cycles
|
329 |
+
|
330 |
+
# Function to call the model and handle the flow automatically
|
331 |
+
def call_model(state):
|
332 |
+
messages = state["messages"]
|
333 |
+
response = model.invoke(messages)
|
334 |
+
return {"messages": [response], "cycle_count": state["cycle_count"] + 1} # Increment cycle count
|
335 |
+
|
336 |
+
tool_node = ToolNode(tool_belt)
|
337 |
+
|
338 |
+
# Create the state graph for managing the flow between the agent and tools
|
339 |
+
uncompiled_graph = StateGraph(AgentState)
|
340 |
+
uncompiled_graph.add_node("agent", call_model)
|
341 |
+
uncompiled_graph.add_node("action", tool_node)
|
342 |
+
|
343 |
+
# Set the entry point for the graph
|
344 |
+
uncompiled_graph.set_entry_point("agent")
|
345 |
+
|
346 |
+
# Define a function to check if the process should continue
|
347 |
+
def should_continue(state):
|
348 |
+
# Check if the cycle count exceeds the limit (e.g., 10)
|
349 |
+
if state["cycle_count"] > 20:
|
350 |
+
print(f"Reached the cycle limit of {state['cycle_count']} cycles. Ending the process.")
|
351 |
+
return END
|
352 |
+
|
353 |
+
# If there are tool calls, continue to the action node
|
354 |
+
last_message = state["messages"][-1]
|
355 |
+
if last_message.tool_calls:
|
356 |
+
return "action"
|
357 |
+
|
358 |
+
return END
|
359 |
+
|
360 |
+
# Add conditional edges for the agent to action
|
361 |
+
uncompiled_graph.add_conditional_edges("agent", should_continue)
|
362 |
+
uncompiled_graph.add_edge("action", "agent")
|
363 |
+
|
364 |
+
# Compile the state graph
|
365 |
+
compiled_graph = uncompiled_graph.compile()
|
366 |
+
|
367 |
+
# Function to run the compiled graph asynchronously
|
368 |
+
async def run_graph(inputs):
|
369 |
+
final_message_content = None # Variable to store the final message content
|
370 |
+
|
371 |
+
async for chunk in compiled_graph.astream(inputs, stream_mode="updates"):
|
372 |
+
for node, values in chunk.items():
|
373 |
+
print(values["messages"])
|
374 |
+
|
375 |
+
# Check if the message contains content
|
376 |
+
if "messages" in values and values["messages"]:
|
377 |
+
final_message = values["messages"][-1]
|
378 |
+
if hasattr(final_message, 'content'):
|
379 |
+
final_message_content = final_message.content
|
380 |
+
|
381 |
+
print("\n\n")
|
382 |
+
|
383 |
+
if final_message_content:
|
384 |
+
print("Final message content from the last chunk:")
|
385 |
+
print(final_message_content)
|
386 |
+
|
387 |
+
return final_message_content
|
388 |
+
|
389 |
+
# Chainlit interaction setup
|
390 |
+
@cl.on_chat_start
|
391 |
+
async def on_chat_start():
|
392 |
+
await cl.Message(content="Welcome! Please provide your PubMed query and screening criteria.").send()
|
393 |
+
|
394 |
+
@cl.on_message
|
395 |
+
async def main(message):
|
396 |
+
# Extract query and screening criteria from the user's message
|
397 |
+
user_input = message.content
|
398 |
+
|
399 |
+
# Build inputs for the agent
|
400 |
+
# system_instructions = SystemMessage(content="""
|
401 |
+
# 1. Use the PubMed search tool to search for papers.
|
402 |
+
# 2. Retrieve the abstracts from the search results.
|
403 |
+
# 3. Screen the abstracts based on the criteria provided by the user. If error happens,retry by feeding in both 'abstracts' and 'screening criteria' as input.
|
404 |
+
# The 'abstracts' is a list of dictionary with keys as PMID, Title, Abstract (which is extracted from preivous step). For the decisions of include and exclude, give me the similarity score you calculated.
|
405 |
+
# 4. Please provide a full summary at the end of the entire flow executed, detailing the whole process/reasoning for each paper.
|
406 |
+
# The user will provide the search query and screening criteria.
|
407 |
+
# Make sure you finish everything in one step before moving on to next step.
|
408 |
+
# Do not call more than one tool in one action.""")
|
409 |
+
|
410 |
+
system_instructions = SystemMessage(content="""Please execute the following steps in sequence:
|
411 |
+
1. Use the PubMed search tool to search for papers.
|
412 |
+
2. Retrieve the abstracts from the search results.
|
413 |
+
3. Screen the abstracts based on the criteria provided by the user.
|
414 |
+
4. Fetch full-text articles for all the papers that pass step 3. Store the full-text articles in the Qdrant vector database,
|
415 |
+
and extract the requested information for each article that passed step 3 from the full-text using the query provided by the user.
|
416 |
+
5. Please provide a full summary at the end of the entire flow executed, detailing each paper's title, PMID, and the whole process/screening/reasoning for each paper.
|
417 |
+
The user will provide the search query, screening criteria, and the query for information extraction.
|
418 |
+
Make sure you finish everything in one step before moving on to next step.
|
419 |
+
Do not call more than one tool in one action.""")
|
420 |
+
human_inputs = HumanMessage(content=user_input)
|
421 |
+
|
422 |
+
inputs = {
|
423 |
+
"messages": [system_instructions, human_inputs],
|
424 |
+
"cycle_count": 0,
|
425 |
+
}
|
426 |
+
|
427 |
+
# Run the agent flow and capture the response
|
428 |
+
response = await run_graph(inputs)
|
429 |
+
|
430 |
+
# Display the response in the Chainlit UI
|
431 |
+
if response:
|
432 |
+
await cl.Message(content=response).send()
|
433 |
+
else:
|
434 |
+
await cl.Message(content="Sorry, I couldn't process the request.").send()
|
chainlit.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Welcome to Chainlit! 🚀🤖
|
2 |
+
|
3 |
+
Hi there, Developer! 👋 We're excited to have you on board. Chainlit is a powerful tool designed to help you prototype, debug and share applications built on top of LLMs.
|
4 |
+
|
5 |
+
## Useful Links 🔗
|
6 |
+
|
7 |
+
- **Documentation:** Get started with our comprehensive [Chainlit Documentation](https://docs.chainlit.io) 📚
|
8 |
+
- **Discord Community:** Join our friendly [Chainlit Discord](https://discord.gg/k73SQ3FyUh) to ask questions, share your projects, and connect with other developers! 💬
|
9 |
+
|
10 |
+
We can't wait to see what you create with Chainlit! Happy coding! 💻😊
|
11 |
+
|
12 |
+
## Welcome screen
|
13 |
+
|
14 |
+
To modify the welcome screen, edit the `chainlit.md` file at the root of your project. If you do not want a welcome screen, just leave this file empty.
|
prototype_mvp3.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==23.2.1
|
2 |
+
aiohappyeyeballs==2.4.3
|
3 |
+
aiohttp==3.10.8
|
4 |
+
aiosignal==1.3.1
|
5 |
+
annotated-types==0.7.0
|
6 |
+
anyio==3.7.1
|
7 |
+
async-timeout==4.0.3
|
8 |
+
asyncer==0.0.2
|
9 |
+
attrs==24.2.0
|
10 |
+
bidict==0.23.1
|
11 |
+
certifi==2024.8.30
|
12 |
+
chainlit==0.7.700
|
13 |
+
charset-normalizer==3.3.2
|
14 |
+
click==8.1.7
|
15 |
+
dataclasses-json==0.5.14
|
16 |
+
Deprecated==1.2.14
|
17 |
+
distro==1.9.0
|
18 |
+
exceptiongroup==1.2.2
|
19 |
+
fastapi==0.100.1
|
20 |
+
fastapi-socketio==0.0.10
|
21 |
+
filetype==1.2.0
|
22 |
+
frozenlist==1.4.1
|
23 |
+
googleapis-common-protos==1.65.0
|
24 |
+
greenlet==3.1.1
|
25 |
+
grpcio==1.66.2
|
26 |
+
grpcio-tools==1.62.3
|
27 |
+
h11==0.14.0
|
28 |
+
h2==4.1.0
|
29 |
+
hpack==4.0.0
|
30 |
+
httpcore==0.17.3
|
31 |
+
httpx==0.24.1
|
32 |
+
hyperframe==6.0.1
|
33 |
+
idna==3.10
|
34 |
+
importlib_metadata==8.4.0
|
35 |
+
jiter==0.5.0
|
36 |
+
jsonpatch==1.33
|
37 |
+
jsonpointer==3.0.0
|
38 |
+
langchain==0.2.16
|
39 |
+
langchain-community==0.2.16
|
40 |
+
langchain-core==0.2.38
|
41 |
+
langchain-openai==0.1.23
|
42 |
+
langchain-qdrant==0.1.4
|
43 |
+
langgraph==0.2.19
|
44 |
+
langchain-huggingface==0.0.3
|
45 |
+
langchain-text-splitters==0.2.4
|
46 |
+
langsmith==0.1.121
|
47 |
+
Lazify==0.4.0
|
48 |
+
marshmallow==3.22.0
|
49 |
+
multidict==6.1.0
|
50 |
+
mypy-extensions==1.0.0
|
51 |
+
nest-asyncio==1.6.0
|
52 |
+
numpy==1.26.4
|
53 |
+
openai==1.44.0
|
54 |
+
opentelemetry-api==1.27.0
|
55 |
+
opentelemetry-exporter-otlp==1.27.0
|
56 |
+
opentelemetry-exporter-otlp-proto-common==1.27.0
|
57 |
+
opentelemetry-exporter-otlp-proto-grpc==1.27.0
|
58 |
+
opentelemetry-exporter-otlp-proto-http==1.27.0
|
59 |
+
opentelemetry-instrumentation==0.48b0
|
60 |
+
opentelemetry-proto==1.27.0
|
61 |
+
opentelemetry-sdk==1.27.0
|
62 |
+
opentelemetry-semantic-conventions==0.48b0
|
63 |
+
orjson==3.10.7
|
64 |
+
packaging==23.2
|
65 |
+
portalocker==2.10.1
|
66 |
+
protobuf==4.25.5
|
67 |
+
pydantic==2.9.0
|
68 |
+
pydantic-settings==2.5.2
|
69 |
+
pydantic_core==2.23.2
|
70 |
+
PyJWT==2.9.0
|
71 |
+
PyMuPDF==1.24.10
|
72 |
+
PyMuPDFb==1.24.10
|
73 |
+
python-dotenv==1.0.1
|
74 |
+
python-engineio==4.9.1
|
75 |
+
python-graphql-client==0.4.3
|
76 |
+
python-multipart==0.0.6
|
77 |
+
python-socketio==5.11.4
|
78 |
+
PyYAML==6.0.2
|
79 |
+
qdrant-client==1.11.3
|
80 |
+
regex==2024.9.11
|
81 |
+
requests==2.32.3
|
82 |
+
simple-websocket==1.0.0
|
83 |
+
sniffio==1.3.1
|
84 |
+
SQLAlchemy==2.0.35
|
85 |
+
starlette==0.27.0
|
86 |
+
syncer==2.0.3
|
87 |
+
tenacity==8.5.0
|
88 |
+
tiktoken==0.7.0
|
89 |
+
tomli==2.0.1
|
90 |
+
tqdm==4.66.5
|
91 |
+
typing-inspect==0.9.0
|
92 |
+
typing_extensions==4.12.2
|
93 |
+
uptrace==1.26.0
|
94 |
+
urllib3==2.2.3
|
95 |
+
uvicorn==0.23.2
|
96 |
+
watchfiles==0.20.0
|
97 |
+
websockets==13.1
|
98 |
+
wrapt==1.16.0
|
99 |
+
wsproto==1.2.0
|
100 |
+
yarl==1.13.1
|
101 |
+
zipp==3.20.2
|
102 |
+
Bio==1.84
|
103 |
+
unstructured==0.15.7
|
104 |
+
python-pptx==1.0.2
|
105 |
+
nltk==3.9.1
|
106 |
+
sentence-transformers==3.1.1
|