# Page residue from the hosting site, kept as comments so the file parses:
# zavavan's picture / commit "Update app.py" / d7f4b9f (verified)
# File: app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig
from peft import PeftModel, PeftConfig
import torch
import regex as re
# Identifier handed to every PEFT `from_pretrained` call below.
ADAPTER_ID = "unica/CLiMA"

# The adapter configuration records which base model the adapter sits on.
peft_config = PeftConfig.from_pretrained(ADAPTER_ID)
BASE_MODEL_ID = peft_config.base_model_name_or_path

# BitsAndBytes settings: 4-bit NF4 weights with double quantization.
# bfloat16 is used for compute; switch to float16 depending on your GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Quantized base model; device placement is delegated to device_map="auto".
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
)

# Attach the adapter weights to the quantized base model.
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)

# The tokenizer is loaded from the base model, not from the adapter.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Instruction text placed at the start of every prompt built by format_prompt.
# It defines the nine relation labels (0-8) and asks for the answer in the form
# "LABEL: X", which is the pattern generate_relations searches for.
# Deliberately a plain string (not an f-string): there is nothing to
# interpolate, and a plain literal keeps any future "{" / "}" in the wording
# from being treated as a replacement field.
# NOTE(review): presumably this wording matches what the adapter was trained
# on - confirm before editing the text.
prompt_instruction_drug_reviews = """Given a drug review enclosed in triple quotes and a pair of entities E1 corresponding to the drug name and E2 corresponding to the treated condition, classify the relation holding between E1 and E2.
The relations are identified with 9 labels from 0 to 8. The meaning of the labels is the following:
0 means that E1 causes E2
1 means that E2 causes E1
2 means that E1 enables E2
3 means that E2 enables E1
4 means that E1 prevents E2
5 means that E2 prevents E1
6 means that E1 hinders E2
7 means that E2 hinders E1
8 means that E1 and E2 are in a relation different than any of the previous ones.
Given X the label that you predicted, for the output use the format LABEL: X
"""
# Format prompt
def format_prompt(user_input, entity1, entity2):
    """Build the full prompt for one text and one entity pair.

    The result has the layout
    ``<USER> {instruction} Text:'''...''' \\nEntities: E1: '''...''', E2: '''...''' <ASSISTANT>``
    where the instruction is the module-level ``prompt_instruction_drug_reviews``.

    Args:
        user_input: The review or note text to classify.
        entity1: Text of entity E1.
        entity2: Text of entity E2.

    Returns:
        The prompt string handed to the tokenizer.
    """
    text_section = f"Text:'''{user_input}'''"
    # The leading newline is part of the prompt format; keep it.
    entity_section = f"\nEntities: E1: '''{entity1}''', E2: '''{entity2}'''"
    sections = (prompt_instruction_drug_reviews, text_section, entity_section)
    return "<USER> " + " ".join(sections) + " <ASSISTANT>"
# Prediction function
# Finds the "LABEL: X" answer requested by the prompt instruction. Compiled
# once at import time instead of on every request.
_ANSWER_LABEL_PATTERN = re.compile(r'LABEL:?\s?(\d+)')

# Sentence template for each relation label; {e1}/{e2} are the two entities.
# Keys are strings because they are compared against the raw regex capture.
_RELATION_TEMPLATES = {
    "0": "'{e1}' causes '{e2}'",
    "1": "'{e2}' causes '{e1}'",
    "2": "'{e1}' enables '{e2}'",
    "3": "'{e2}' enables '{e1}'",
    "4": "'{e1}' prevents '{e2}'",
    "5": "'{e2}' prevents '{e1}'",
    "6": "'{e1}' hinders '{e2}'",
    "7": "'{e2}' hinders '{e1}'",
    "8": "No causal relation between '{e1}' and '{e2}'",
}

# Returned whenever no usable label can be read from the model output.
_NO_RELATION_MESSAGE = 'No causal relation could be extracted'


def _relation_from_output(model_output, entity1, entity2):
    """Turn raw model output into a one-sentence relation description.

    Args:
        model_output: Text generated by the model (without the prompt).
        entity1: Text of entity E1.
        entity2: Text of entity E2.

    Returns:
        The sentence for the first "LABEL: X" found in ``model_output``, or
        the fallback message when no label is present or the label is not
        one of 0-8. Never returns ``None``.
    """
    found = _ANSWER_LABEL_PATTERN.search(model_output)
    template = _RELATION_TEMPLATES.get(found.group(1)) if found else None
    if template is None:
        return _NO_RELATION_MESSAGE
    return template.format(e1=entity1, e2=entity2)


def generate_relations(text, entity1, entity2):
    """Classify the relation between two entities mentioned in ``text``.

    Builds the prompt, runs generation with sampling disabled and at most
    256 new tokens, and maps the predicted label to a readable sentence.

    Args:
        text: The review or note text.
        entity1: Text of entity E1.
        entity2: Text of entity E2.

    Returns:
        A sentence describing the predicted relation, or
        'No causal relation could be extracted' when the output contains no
        recognised label.
    """
    prompt = format_prompt(text, entity1, entity2)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)

    # Drop the echoed prompt by token count rather than by character count:
    # decoding does not always reproduce the prompt text exactly, so slicing
    # the decoded string at len(prompt) can cut into the answer.
    prompt_token_count = inputs["input_ids"].shape[-1]
    completion_ids = outputs[0][prompt_token_count:]
    model_output = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    return _relation_from_output(model_output, entity1, entity2)
# Gradio UI
# Input widgets, in the positional order of generate_relations' parameters
# (text, entity1, entity2).
input_fields = [
    gr.Textbox(lines=10, label="Clinical Note or Drug Review Text"),
    gr.Textbox(label="Entity 1 (e.g., Drug)"),
    gr.Textbox(label="Entity 2 (e.g., Condition or Symptom)"),
]

# Single text box that shows the sentence returned by generate_relations.
output_field = gr.Textbox(label="Extracted Causal Relations")

# Example rows offered in the UI; each row is [text, entity 1, entity 2].
example_rows = [
    [
        "Odynophagia: Was presumed due to mucositis from recent chemotherapy.",
        "chemotherapy",
        "mucositis",
    ],
    [
        "patient's wife noticed erythema on patient's face. On [**3-27**]the visiting nurse [**First Name (Titles) 8706**][**Last Name (Titles)11282**]of a rash on his arms as well. The patient was noted to be febrile and was admitted to the [**Company 191**] Firm. In the EW, patient's Dilantin was discontinued and he was given Tegretol instead.",
        "Dilantin",
        "erythema on patient's face",
    ],
    [
        "i had a urinary tract infection so bad that when i pee it smells but when i started taking ciprofloxacin it worked it’s a good medicine for a urinary tract infections.",
        "ciprofloxacin",
        "urinary tract infection",
    ],
    [
        "when i first started using ziana, i only had acne in between my eyebrows, chin, and the nose area. my acne worsened while using it and then it got better. but after about 4 months of using it, it became ineffective. so i now have acne between my eyebrows, chin, cheeks, forehead, and the nose area. its great at first but after a while it made my face even worse than before i used the product.",
        "ziana",
        "acne",
    ],
]

demo = gr.Interface(
    fn=generate_relations,
    inputs=input_fields,
    outputs=output_field,
    title="Causal Relation Extractor with MedLlama",
    description="Paste your clinical note or drug review, and specify two target entities. This AI agent extracts drug-condition or symptom causal relations using a fine-tuned LLM adapter model.",
    examples=example_rows,
)

# Start the web app as soon as the script is executed.
demo.launch()