|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
import torch |
|
|
|
# Hugging Face Hub id of the checkpoint to serve.
model_name = "ruslanmv/Medical-Llama3-8B"

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# 4-bit NF4 quantization (float16 compute) is only configured on CUDA;
# on CPU the model is loaded without a quantization config.
bnb_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    if device == "cuda"
    else None
)

# Automatic device placement on GPU, no device map on CPU.
device_map = "auto" if device == "cuda" else None

# NOTE(review): trust_remote_code=True allows code shipped with the model
# repository to run locally -- confirm this checkpoint is trusted.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
    use_cache=False,
    device_map=device_map,
    low_cpu_mem_usage=(device == "cuda"),
)
|
|
|
# Tokenizer belonging to the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Reuse the end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token

# Plain-text chat format: one "Role: content" line per message.
# Each message is rendered exactly once by the loop (the previous template
# rendered the last message a second time after the loop), and the trailing
# "Assistant:" cue is only emitted when the caller asks for a generation
# prompt via add_generation_prompt=True.
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'system' %}"
    "System: {{ message['content'] }}\n"
    "{% elif message['role'] == 'user' %}"
    "Human: {{ message['content'] }}\n"
    "{% elif message['role'] == 'assistant' %}"
    "Assistant: {{ message['content'] }}\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}Assistant:{% endif %}"
)
|
|
|
def process_medical_history(prescription_details):
    """Format free-text prescription details as a structured medical history.

    Builds a system + user chat prompt, runs the module-level ``model`` on it
    and returns only the text the model generated (the prompt is not echoed).

    Args:
        prescription_details: Raw prescription text supplied by the user.

    Returns:
        The generated answer as a stripped string, or a short hint when the
        input is empty or whitespace-only.
    """
    # Nothing to format: skip the (expensive) generation entirely.
    if not prescription_details or not prescription_details.strip():
        return "Please enter prescription details."

    sys_message = '''
    You are an AI Medical Assistant. Given a string of prescription details, generate a structured medical history output.
    Include the following sections with appropriate headings:
    1. Date of Prescription
    2. Duration of Medicines
    3. Problems Recognized
    4. Test Results

    Format the output clearly with each section having its own heading and content on a new line.
    Do not include unnecessary details like additional notes, extra tokens and markers like <|endoftext|> or <|pad|>.
    '''

    question = f"Please format the following prescription details into a structured medical history: {prescription_details}"

    messages = [
        {"role": "system", "content": sys_message},
        {"role": "user", "content": question},
    ]

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            use_cache=True,
            # Explicit so generate() does not have to guess a padding id.
            pad_token_id=tokenizer.pad_token_id,
        )

    # generate() returns prompt + completion for causal LMs; keep only the
    # completion by slicing off the prompt tokens. This is robust even when
    # the user's text or the model's answer contains the word "Assistant:",
    # which broke the previous split-on-marker approach.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_length:]

    # Let the tokenizer drop its own special tokens instead of guessing
    # their textual spelling.
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # The model may still spell these markers out as ordinary text.
    answer = answer.replace('<|endoftext|>', '').replace('<|pad|>', '').strip()

    return answer
|
|
|
# Minimal web UI: one text box in, one text box out, backed by
# process_medical_history.
demo = gr.Interface(
    fn=process_medical_history,
    inputs="text",
    outputs="text",
)

# Start the Gradio server as soon as the script is executed.
demo.launch()