from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
import json
from tqdm import tqdm

embedding_model = HuggingFaceEmbeddings(model_name="abhinand/MedEmbed-small-v0.1")

chroma_dir = "./chroma_trials_data_version_3"
chroma_db = Chroma(
    embedding_function=embedding_model,
    persist_directory=chroma_dir
)

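# Flatten one ClinicalTrials.gov study record (the "protocolSection" layout of
# the v2 API) into a flat dict of the fields we index. Only the first listed
# location is used, so multi-site trials are represented by a single site here.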
def normalize_clinical_trial_data(data):
    identification = data["protocolSection"]["identificationModule"]
    description = data["protocolSection"]["descriptionModule"]
    eligibility = data["protocolSection"]["eligibilityModule"]
    locations = data["protocolSection"]["contactsLocationsModule"]["locations"][0]

    # The eligibility criteria arrive as one free-text blob; split it into the
    # inclusion and exclusion halves around the standard section headers.
    inclusions_exclusions = eligibility.get("eligibilityCriteria", "").split("Exclusion Criteria:")
    inclusions = ""
    exclusions = ""
    if len(inclusions_exclusions) > 1:
        exclusions = inclusions_exclusions[1].strip()
    # str.split always returns at least one element, so indexing [-1] is safe.
    inclusions = inclusions_exclusions[0].split("Inclusion Criteria:")[-1].strip()

    normalized_data = {
        "title": identification.get("officialTitle", ""),
        "summary": description.get("briefSummary", ""),
        "min_age": eligibility.get("minimumAge", ""),
        "max_age": eligibility.get("maximumAge", ""),
        "gender": eligibility.get("sex", ""),
        "inclusions": inclusions,
        "exclusions": exclusions,
        "facility": locations.get("facility", ""),
        "status": locations.get("status", ""),
        "city": locations.get("city", ""),
        "state": locations.get("state", ""),
        "country": locations.get("country", ""),
        "contacts": "\n".join(
            f'Name: {contact.get("name", "")}, Role: {contact.get("role", "")}, '
            f'Phone: {contact.get("phone", "")}, Email: {contact.get("email", "")}'
            for contact in locations.get("contacts", [])
        )
    }
    return normalized_data

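# Normalize each record and wrap it as a LangChain Document: criteria and
# contacts go into page_content for semantic search, location fields become
# filterable metadata. Ingestion is capped at 500 records for this demo.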
def store_data_in_chroma(raw_data):
    documents = []
    count = 0
    for record in tqdm(raw_data):
        try:
            normalized_record = normalize_clinical_trial_data(record)
            content = f"""Title: {normalized_record['title']}
Summary: {normalized_record['summary']}
Inclusions: {normalized_record['inclusions']}
Exclusions: {normalized_record['exclusions']}
Contacts: {normalized_record['contacts']}
Acceptable Age Range: {normalized_record['min_age']} - {normalized_record['max_age']}
"""
            metadata = {
                "facility": normalized_record['facility'],
                "status": normalized_record['status'],
                "city": normalized_record['city'],
                "state": normalized_record['state'],
                "country": normalized_record['country']
            }
            documents.append(Document(page_content=content, metadata=metadata))
            count += 1
            if count >= 500:
                break
        except Exception as e:
            # Records missing locations or other modules are skipped.
            print(e)
    print("Document_size", len(documents))
    chroma_db.add_documents(documents)
    chroma_db.persist()
    print("Data stored in ChromaDB successfully")

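# Collect the distinct city/state/country values stored in the index; these
# ground the metadata-extraction prompt below. Note that _collection is a
# private Chroma attribute, so this leans on LangChain internals rather than
# a public API.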
def get_unique_city_state_country():
    results = chroma_db._collection.get(include=["metadatas"])
    metadata = {'city': [], 'state': [], 'country': []}
    for doc in results['metadatas']:
        metadata['city'].append(doc.get('city', ''))
        metadata['state'].append(doc.get('state', ''))
        metadata['country'].append(doc.get('country', ''))
    return list(set(metadata['city'])), list(set(metadata['state'])), list(set(metadata['country']))

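# "ctg-studies.json" is expected to be a JSON array of study records exported
# from ClinicalTrials.gov, each following the protocolSection layout assumed
# by normalize_clinical_trial_data above.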
with open("ctg-studies.json", "r") as d:
    raw_data = json.load(d)
store_data_in_chroma(raw_data)

city, state, country = get_unique_city_state_country()

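# Qwen2.5-0.5B-Instruct is a small instruction-tuned chat model, loaded once
# and reused for both metadata extraction and answer generation.
# device_map="auto" requires the accelerate package to be installed.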
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

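# Three system prompts drive the pipeline: one extracts location metadata as
# strict JSON grounded in the values actually present in the index, one
# answers strictly from retrieved documents, and one handles empty retrievals.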
metadata_extraction_system_prompt = f"""You are an advanced information extractor.
Your task is to analyze the user input and extract location details (city, state, or country).
Ground_Data:
1. city: {city},
2. state: {state},
3. country: {country}

Instructions:
1. If a specific location type (e.g., city, state, or country) is not mentioned in the input, set its value to "".
2. Always include the keys "city", "state", and "country" in the output.
3. Match the values of "city", "state", and "country" strictly against the corresponding categories in the Ground_Data:
   - Assign a value to "city" only if it matches an entry in Ground_Data["city"].
   - Assign a value to "state" only if it matches an entry in Ground_Data["state"].
   - Assign a value to "country" only if it matches an entry in Ground_Data["country"].
4. If a term does not match any entry in the Ground_Data for its category, leave it as "".
5. Do not make assumptions or infer any details not explicitly stated in the input.

For example:
Input: "Changchun City, China"
Output: {{"city": "Changchun", "state": "", "country": "China"}}
Reason: no state is mentioned

Input: "What do you think about United States?"
Output: {{"city": "", "state": "", "country": "United States"}}
Wrong Output: {{"city": "", "state": "United States", "country": ""}}
Reason: "United States" does not appear in Ground_Data["state"]

Your response MUST be valid JSON that can be parsed directly with **json.loads**, and nothing else.
"""

final_response_system_prompt = """You are an AI assistant specialized in providing information about ongoing clinical trials.
You will assist users by extracting relevant details from the provided clinical trial documents.

Key Instructions:
1. Use only the information explicitly stated in the documents.
2. Do not rely on general knowledge or assumptions.
3. If the user requests information not covered by the documents, ask clarifying questions or inform them that the required data is not available.
4. When presenting information, include specific details like trial titles, eligibility criteria, contact details, and locations, ensuring they align with the user's query.
5. Keep responses concise and tailored to the user's request.
6. Avoid speculation or providing unrelated information.

Available Information:
1. Clinical trial details, including titles, summaries, eligibility criteria, exclusions, and contact information.
2. Contacts for trial coordination, including their roles, phone numbers, and emails.
Use these documents as your sole source of truth to address user queries.
"""

fallback_system_prompt = f"""
You are an AI assistant specialized in providing information about ongoing clinical trials.
You will assist users by extracting relevant details from the provided clinical trial documents.

If the documents are empty and the user mentions a city, state, or country in the question,
ask verification questions about that location, grounded in the data below.
Ground_Data:
1. city: {city},
2. state: {state},
3. country: {country}
"""

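# Render the chat through the tokenizer's chat template, generate greedily
# (do_sample=False), and decode only the newly generated tokens -- the prompt
# tokens are sliced off before decoding.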
def generate_llm_response(messages):
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=False
    )
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response

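# Retrieval with dynamic metadata filtering: the LLM maps the question onto
# known city/state/country values; non-empty matches become a Chroma `where`
# filter ($and when several fields are present). If the filtered search
# returns nothing, retry without the filter.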
def query_chroma_dynamic(question):
    search_kwargs = {"k": 5}
    extracted_metadata = {}
    messages = [
        {"role": "system", "content": metadata_extraction_system_prompt},
        {"role": "user", "content": question}
    ]
    llm_response = generate_llm_response(messages)
    print("Extraction LLM Response: ", llm_response)
    try:
        # Strip any markdown code fence the model may wrap around its JSON.
        cleaned = llm_response.strip().removeprefix("```json").removeprefix("```").removesuffix("```").strip()
        extracted_metadata = json.loads(cleaned)
    except Exception as e:
        print(e, llm_response, type(llm_response))

    # Keep only the non-empty fields as metadata filters.
    cleaned_data = [{k: v} for k, v in extracted_metadata.items() if len(v) > 0]
    if len(cleaned_data) > 1:
        search_kwargs["filter"] = {"$and": cleaned_data}
    elif len(cleaned_data) == 1:
        search_kwargs["filter"] = cleaned_data[0]

    retriever = chroma_db.as_retriever(search_kwargs=search_kwargs)
    retrieved_results = retriever.get_relevant_documents(question)
    if len(retrieved_results) == 0:
        retriever = chroma_db.as_retriever(search_kwargs={"k": 5})
        retrieved_results = retriever.get_relevant_documents(question)

    return retrieved_results

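# One chat turn: retrieve, choose the grounded or fallback system prompt,
# replay the last four history messages, and generate the final answer.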
def main(user_input, history):
    retrieved_results = query_chroma_dynamic(user_input)
    if len(retrieved_results) > 0:
        context = '\n\n'.join(
            f"Document{i+1}:\n{doc.page_content}" for i, doc in enumerate(retrieved_results)
        )
        final_response_message = [{"role": "system", "content": final_response_system_prompt}]
    else:
        context = "Sorry, I couldn't find any matching documents in the database."
        final_response_message = [{"role": "system", "content": fallback_system_prompt}]

    for turn in history[-4:]:
        final_response_message.append(turn)
    user_message = {"role": "user", "content": f"\nDocuments:\n{context}\n\n{user_input}"}
    final_response_message.append(user_message)

    final_response = generate_llm_response(final_response_message)
    history.append(user_message)
    history.append({"role": "assistant", "content": final_response})
    return final_response, history
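

# A minimal usage sketch: a simple console loop driving main(). The loop, its
# prompt strings, and the exit keywords are illustrative assumptions, not part
# of the pipeline above; the history format matches what main() expects.
if __name__ == "__main__":
    chat_history = []
    while True:
        user_question = input("You: ")
        if user_question.strip().lower() in {"exit", "quit"}:
            break
        answer, chat_history = main(user_question, chat_history)
        print("Assistant:", answer)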