MedX / app.py
Gopal Agarwal
model updated
71bb1ec
raw
history blame
2.75 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
model_name = "ruslanmv/Medical-Llama3-8B"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pick hardware-dependent loading options.
if device != "cuda":
    # CPU: no automatic device placement and no quantization config.
    device_map = None
    bnb_config = None
else:
    # GPU: let the loader place the weights and quantize them to 4-bit NF4,
    # computing in float16.
    device_map = "auto"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

# Load the causal language model with the options selected above.
# NOTE(review): trust_remote_code=True lets the loader run Python code shipped
# in the model repository -- confirm this repo is trusted.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map,
    low_cpu_mem_usage=(device == "cuda"),
    use_cache=False,
    trust_remote_code=True,
)
# Load the tokenizer that belongs to the model checkpoint.
# NOTE(review): trust_remote_code=True lets the loader run Python code shipped
# in the model repository -- confirm this repo is trusted.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Plain-text chat template: one "System:/Human:/Assistant:" line per message.
# Each message is rendered exactly once by the loop; the previous version
# rendered the last message a second time after the loop, duplicating the
# user's question in every prompt.  The trailing "Assistant:" cue is emitted
# only when the caller asks for a generation prompt
# (apply_chat_template(..., add_generation_prompt=True)).
tokenizer.chat_template = """
{% for message in messages %}
{% if message['role'] == 'system' %}
System: {{ message['content'] }}
{% elif message['role'] == 'user' %}
Human: {{ message['content'] }}
{% elif message['role'] == 'assistant' %}
Assistant: {{ message['content'] }}
{% endif %}
{% endfor %}
{% if add_generation_prompt %}
Assistant:{% endif %}"""
def process_medical_history(prescription_details):
    """Turn free-text prescription details into a structured medical history.

    Builds a system + user chat prompt, generates up to 300 new tokens with
    the module-level ``model`` and returns only the model's reply.

    Args:
        prescription_details: Prescription text to be structured.

    Returns:
        The generated reply as a string, without the prompt, without the
        tokenizer's special tokens, without the literal ``<|endoftext|>`` /
        ``<|pad|>`` markers, and stripped of surrounding whitespace.
    """
    sys_message = '''
You are an AI Medical Assistant. Given a string of prescription details, generate a structured medical history output.
Include the following sections with appropriate headings:
1. Date of Prescription
2. Duration of Medicines
3. Problems Recognized
4. Test Results
Format the output clearly with each section having its own heading and content on a new line.
Do not include unnecessary details like additional notes, extra tokens and markers like <|endoftext|> or <|pad|>.
'''
    question = f"Please format the following prescription details into a structured medical history: {prescription_details}"
    messages = [
        {"role": "system", "content": sys_message},
        {"role": "user", "content": question},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Remember how many tokens belong to the prompt so they can be cut off
    # from the generated sequence below.
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=300, use_cache=True)

    # The returned sequence starts with the prompt tokens.  Decode only the
    # newly generated part instead of splitting the full text on
    # 'Assistant:', which cut the answer short whenever the reply itself
    # contained that marker.  skip_special_tokens drops the tokenizer's own
    # special tokens rather than relying on a hard-coded list.
    generated_ids = outputs[0][prompt_length:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # The model may still spell these markers out as ordinary text.
    answer = answer.replace('<|endoftext|>', '').replace('<|pad|>', '').strip()
    return answer
# Minimal Gradio UI: a single text input routed through
# process_medical_history, with the result shown as text.
demo = gr.Interface(
    fn=process_medical_history,
    inputs="text",
    outputs="text",
)

# Start serving the interface.
demo.launch()