# Source: Hugging Face Space upload by ramji-srotas
# Commit message: "Uploading files to have ingestion and all files" (5e273da, verified)
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.schema import Document
import json
import re
from tqdm import tqdm
# Sentence-embedding model used to vectorise trial documents.
# MedEmbed is a medical-domain embedding model — TODO confirm it stays available on the Hub.
embedding_model = HuggingFaceEmbeddings(model_name="abhinand/MedEmbed-small-v0.1")
# On-disk directory where the Chroma collection is persisted between runs.
chroma_dir = "./chroma_trials_data_version_3"
# Module-level vector-store handle shared by the ingestion and query helpers below.
chroma_db = Chroma(
    embedding_function = embedding_model,
    persist_directory=chroma_dir
)
def normalize_clinical_trial_data(data):
    """Flatten one ClinicalTrials.gov study record into a flat dict.

    Args:
        data: Parsed JSON for a single study, following the CTG API v2
            ``protocolSection`` layout.

    Returns:
        dict with title, summary, age limits, gender, inclusion/exclusion
        criteria text, and details (facility/status/city/state/country/contacts)
        of the first listed location only.

    Raises:
        KeyError / IndexError: if a required section or the first location is
            missing (the ingestion loop catches these and skips the record).
    """
    identification = data["protocolSection"]["identificationModule"]
    description = data["protocolSection"]["descriptionModule"]
    eligibility = data["protocolSection"]["eligibilityModule"]
    # Only the first listed location is kept — TODO confirm this is intended.
    locations = data["protocolSection"]["contactsLocationsModule"]["locations"][0]

    # Criteria arrive as one free-text blob: the exclusion section (if present)
    # follows an "Exclusion Criteria:" heading; inclusions follow an
    # "Inclusion Criteria:" heading in the leading part.
    criteria_parts = eligibility.get("eligibilityCriteria", "").split("Exclusion Criteria:")
    exclusions = criteria_parts[1] if len(criteria_parts) > 1 else ""
    # BUGFIX: inclusions were previously parsed only when an exclusion section
    # existed; now they are extracted from the leading part unconditionally.
    inclusion_parts = criteria_parts[0].split("Inclusion Criteria:")
    inclusions = inclusion_parts[-1] if inclusion_parts else ""

    # Build normalized dictionary
    normalized_data = {
        "title": identification.get("officialTitle", ""),
        "summary": description.get("briefSummary", ""),
        "min_age": eligibility.get("minimumAge", ""),
        "max_age": eligibility.get("maximumAge", ""),
        "gender": eligibility.get("sex", ""),
        "inclusions": inclusions,
        "exclusions": exclusions,
        "facility": locations.get("facility", ""),
        "status": locations.get("status", ""),
        "city": locations.get("city", ""),
        "state": locations.get("state", ""),
        "country": locations.get("country", ""),
        "contacts": "\n".join([
            f'Name: {contact.get("name", "")}, Role: {contact.get("role", "")}, Phone: {contact.get("phone", "")}, Email: {contact.get("email", "")}' for contact in locations.get("contacts", [])
        ])
    }
    return normalized_data
def store_data_in_chroma(raw_data, max_records=500):
    """Normalize raw CTG study records and persist them into ChromaDB.

    Args:
        raw_data: Iterable of raw study dicts (shape expected by
            normalize_clinical_trial_data).
        max_records: Soft cap on ingested documents; iteration stops once the
            count exceeds this value (default 500, matching prior behaviour).

    Side effects:
        Adds documents to the module-level ``chroma_db`` and persists it;
        prints progress / error diagnostics.
    """
    documents = []
    count = 0
    for record in tqdm(raw_data):
        try:
            normalized_record = normalize_clinical_trial_data(record)
            # Free-text payload embedded for similarity search; location
            # fields go into metadata instead so they can be filtered on.
            content = f"""Title: {normalized_record['title']}
Summary: {normalized_record['summary']}
Inclusions: {normalized_record['inclusions']}
Exclusions: {normalized_record['exclusions']}
Contacts: {normalized_record['contacts']}
Acceptable Age Range: {normalized_record['min_age']}- {normalized_record['max_age']}
"""
            metadata = {
                "facility": normalized_record['facility'],
                "status": normalized_record['status'],
                "city": normalized_record['city'],
                "state": normalized_record['state'],
                "country": normalized_record['country']
            }
            documents.append(
                Document(
                    page_content=content, metadata=metadata
                )
            )
            count += 1
            if count > max_records:
                break
        except Exception as e:
            # Malformed records (missing sections, no locations) are skipped
            # rather than aborting the whole ingestion run.
            print(e)
    print("Document_size", len(documents))
    chroma_db.add_documents(documents)
    chroma_db.persist()
    print('Data stored in ChromaDB successfully')
def get_unique_city_state_country():
    """Collect distinct city/state/country values from stored trial metadata.

    Returns:
        Tuple of three lists: (unique cities, unique states, unique countries).
        Ordering is unspecified (set-derived).
    """
    results = chroma_db._collection.get(include=["metadatas"])
    cities, states, countries = set(), set(), set()
    for meta in results['metadatas']:
        # .get guards against documents stored without a complete metadata
        # dict (the original indexed directly and would raise KeyError).
        cities.add(meta.get('city', ''))
        states.add(meta.get('state', ''))
        countries.add(meta.get('country', ''))
    return list(cities), list(states), list(countries)
# Load the raw studies dump, ingest it into Chroma, then derive the
# ground-truth location vocabularies used by the extraction prompt below.
with open("ctg-studies.json", "r") as d:
    raw_data = json.load(d)
store_data_in_chroma(raw_data)
city, state, country = get_unique_city_state_country()
# Small instruct model used for both metadata extraction and final answers.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",  # let transformers pick dtype from the checkpoint
    device_map="auto"    # place layers on whatever devices are available
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Prompt for the first LLM pass: pull city/state/country out of the user's
# question, constrained to values actually present in the ingested data
# (interpolated from the module-level city/state/country lists above).
metadata_extraction_system_prompt = f"""You are an advanced information extractor.
Your task is to analyze the user input and extract location details (city, state, or country).
Ground_Data:
1. city: {city},
2. state: {state},
3. country: {country}
Instructions:
1. If a specific location type (e.g., city, state, or country) is not mentioned in the input, set its value to "".
2. Always include the keys "city", "state", and "country" in the output.
3. Match the values of "city", "state", and "country" strictly with the corresponding categories in the Ground_Data:
- Assign a value to "city" only if it matches any entry in Ground_Data["city"].
- Assign a value to "state" only if it matches any entry in Ground_Data["state"].
- Assign a value to "country" only if it matches any entry in Ground_Data["country"].
4. If a term does not match any entry in the Ground_Data for its category, leave it as "".
5. Do not make assumptions or infer any details not explicitly stated in the input.
For Example:
Input: "Changchun City, China"
Output: {{"city": "Changchun", "state": "", "country": "China"}}
Reason: state is not available
Input: "What do you think about United States?"
Output: {{"city": "", "state": "", "country": "United States"}}
Wrong Output: {{"city": "", "state": "United States", "country": ""}}
Reason: United States is not available in state Ground_Data
Your response MUST be directly parsed using **json.loads** nothing else.
"""
# Prompt for the answering pass when retrieval produced documents: answer
# strictly from the retrieved trial documents.
final_response_system_prompt = """You are an AI assistant specialized in providing information about ongoing clinical trials.
You will assist users by extracting relevant details from the provided clinical trial documents.
Key Instructions:
1. Use only the information explicitly stated in the documents.
2. Do not rely on general knowledge or assumptions.
3. If the user requests information not covered by the documents, ask clarifying questions or inform them that the required data is not available.
4. When presenting information, include specific details like trial titles, eligibility criteria, contact details, and locations, ensuring they align with the users query.
5. Keep responses concise and tailored to the users request.
6. Avoid speculation or providing unrelated information.
Available Information:
1. Clinical trial details, including titles, summaries, eligibility criteria, exclusions, and contact information.
2. Contacts for trial coordination, including their roles, phone numbers, and emails.
Use these documents as your sole source of truth to address user queries.
"""
# Prompt for the answering pass when retrieval came back empty: ask the user
# location-verification questions instead of answering from thin air.
fallback_system_prompt = f"""
You are an AI assistant specialized in providing information about ongoing clinical trials.
You will assist users by extracting relevant details from the provided clinical trial documents.
If the documents are empty and user has any city, state or country specified in question
ask for some verification questions based on location
Ground_Data:
1. city: {city},
2. state: {state},
3. country: {country}
"""
def generate_llm_response(messages):
    """Run one greedy chat completion on the local model and return the text.

    Args:
        messages: Chat history as a list of {"role": ..., "content": ...} dicts.

    Returns:
        The decoded assistant reply with special tokens stripped.
    """
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer([prompt_text], return_tensors="pt").to(model.device)
    output_ids = model.generate(**encoded, max_new_tokens=512, do_sample=False)
    # model.generate echoes the prompt; keep only the newly generated tokens.
    completions = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(encoded.input_ids, output_ids)
    ]
    return tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
def query_chroma_dynamic(question):
    """Retrieve trial documents for a question, filtered by extracted location.

    First asks the LLM to extract city/state/country from the question, builds
    a Chroma metadata filter from any non-empty values, runs a filtered
    similarity search, and falls back to an unfiltered search when the
    filtered one returns nothing.

    Args:
        question: Raw user query string.

    Returns:
        List of retrieved langchain Documents (possibly empty).
    """
    search_kwargs = {"k": 5}
    extracted_metadata = {}
    messages = [
        {
            "role": "system",
            "content": metadata_extraction_system_prompt
        },
        {"role": "user", "content": f"{question}"}
    ]
    llm_response = generate_llm_response(messages)
    print("Extraction LLM Response: ", llm_response)
    try:
        # The model sometimes wraps its JSON in ```json fences; strip them.
        extracted_metadata = json.loads(llm_response.replace('`', '').replace('json', ''))
    except Exception as e:
        print(e, llm_response, type(llm_response))
    # Keep only the keys with non-empty values and turn them into a Chroma
    # metadata filter ($and is required when combining multiple conditions).
    cleaned_data = [{k: v} for k, v in extracted_metadata.items() if v]
    if len(cleaned_data) > 1:
        search_kwargs["filter"] = {"$and": cleaned_data}
    elif len(cleaned_data) == 1:
        search_kwargs["filter"] = cleaned_data[0]
    retriever = chroma_db.as_retriever(
        search_kwargs=search_kwargs
    )
    retrieved_results = retriever.get_relevant_documents(question)
    # BUGFIX: the original wrote `if retrieved_results == 0:` — comparing a
    # list to the int 0 is always False, so the unfiltered fallback never ran.
    # Retry without the filter only when a filter was actually applied.
    if len(retrieved_results) == 0 and "filter" in search_kwargs:
        retriever = chroma_db.as_retriever(
            search_kwargs={"k": 5}
        )
        retrieved_results = retriever.get_relevant_documents(question)
    return retrieved_results
def main(user_input, history):
    """Answer a user query with RAG over the trials store and update history.

    Args:
        user_input: The user's question.
        history: Running chat history (list of role/content dicts); mutated in
            place and also returned.

    Returns:
        Tuple of (assistant reply text, updated history).
    """
    retrieved_results = query_chroma_dynamic(user_input)
    if retrieved_results:
        context = '\n\n'.join(
            [f"Document{i+1}:\n{doc.page_content}" for i, doc in enumerate(retrieved_results)]
        )
        system_prompt = final_response_system_prompt
    else:
        # No documents came back: switch to the fallback persona that asks
        # the user location-verification questions.
        context = "Sorry I couldnt find any documents from database"
        system_prompt = fallback_system_prompt
    final_response_message = [{"role": "system", "content": f"{system_prompt}"}]
    # Carry over at most the last four turns of prior conversation.
    final_response_message.extend(history[-4:])
    user_content = f"\nDocuments:\n{context}\n\n{user_input}"
    final_response_message.append({"role": "user", "content": user_content})
    final_response = generate_llm_response(final_response_message)
    history.append({"role": "user", "content": user_content})
    history.append({"role": "assistant", "content": f"{final_response}"})
    return final_response, history