# Hugging Face Space (status: Paused)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
import torch | |
# Model name and configuration
model_name = "ruslanmv/Medical-Llama3-8B"
device_map = "auto"
# Load the weights in 4-bit NF4 and run the matmuls in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the model and tokenizer
# NOTE(review): trust_remote_code=True allows Python code shipped in the model
# repository to run locally -- confirm the repository is trusted.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
    use_cache=False,
    device_map=device_map,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Set pad_token_id to eos_token_id if None
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Define the chat template
# The tokenizer renders `chat_template` as a Jinja template over `messages`.
# Single-brace placeholders such as "{system}" are NOT substituted by Jinja and
# would be emitted literally, so the template iterates over the messages
# instead. The rendered layout is ChatML:
#   <|im_start|>ROLE\nCONTENT\n<|im_end|>\n ... <|im_start|>assistant\n
chat_template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n"
    "{{ message['content'] }}\n"
    "<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
tokenizer.chat_template = chat_template
# Function to generate a response
def askme(question: str, *, max_new_tokens: int = 100) -> str:
    """Generate the assistant's answer to a single medical question.

    Builds a two-message chat (fixed system prompt + the user's question),
    renders it with the tokenizer's chat template, runs generation on the
    module-level ``model`` and returns only the newly generated text.

    Args:
        question: The user's question, inserted as the "user" message.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded completion with the prompt removed, cut at the first
        ``<|im_end|>`` marker (if any) and stripped of surrounding whitespace.
    """
    sys_message = (
        "\n"
        "You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and\n"
        "provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.\n"
    )
    # Structure messages for the chat
    messages = [
        {"role": "system", "content": sys_message},
        {"role": "user", "content": question},
    ]
    # Apply the chat template
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Send the inputs to wherever the model was placed instead of assuming
    # "cuda": the model is loaded with device_map="auto".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate response
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    # The returned sequence starts with the prompt tokens; keep only what was
    # generated rather than string-splitting the decoded prompt + completion.
    prompt_length = inputs["input_ids"].shape[-1]
    response_text = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    # If the end-of-turn marker comes back as plain text, drop it and anything
    # the model produced after it.
    return response_text.split("<|im_end|>", 1)[0].strip()
# Example usage: ask one sample question and print the model's answer.
question = (
    "\n"
    "I'm a 35-year-old male and for the past few months, I've been experiencing fatigue,\n"
    "increased sensitivity to cold, and dry, itchy skin.\n"
    "Could these symptoms be related to hypothyroidism?\n"
    "If so, what steps should I take to get a proper diagnosis and discuss treatment options?\n"
)
answer = askme(question)
print(answer)