metadata
library_name: peft
base_model: mistralai/Mixtral-8x7B-Instruct-v0.1
datasets:
- ruslanmv/ai-medical-chatbot
license: apache-2.0
Model Card for Medical-Mixtral-7B-v1.5k
Model Description
Medical-Mixtral-7B-v1.5k is a fine-tuned Mixtral model for answering medical assistance questions. It is a version of mistralai/Mixtral-8x7B-Instruct-v0.1 fine-tuned on a 1.5k-record subset of the AI Medical Chatbot dataset, which contains 250k records in total. The model is intended as a ready-to-use chatbot for questions related to medical assistance.
How to Get Started with the Model
Installation
pip install -qU transformers==4.36.2 datasets python-dotenv peft bitsandbytes accelerate
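Before loading the model, it can help to confirm that the packages from the install command are present in the active environment. The following is a minimal sketch using only the Python standard library; the helper name `installed_versions` is hypothetical and is not part of any of the libraries listed above.

```python
from importlib import metadata

# Package names taken from the pip command above.
REQUIRED = ["transformers", "datasets", "python-dotenv", "peft", "bitsandbytes", "accelerate"]


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(REQUIRED).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

Any package reported as missing can be installed by rerunning the pip command above.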
Use the code below to get started with the model.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
import torch
# Define the name of your fine-tuned model
finetuned_model = 'ruslanmv/Medical-Mixtral-7B-v1.5k'
# Load fine-tuned model
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)
# 4-bit loading is already requested through bnb_config, so load_in_4bit is not passed again here
model_pretrained = AutoModelForCausalLM.from_pretrained(
    finetuned_model,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(finetuned_model, trust_remote_code=True)
# Set pad_token_id to eos_token_id
model_pretrained.config.pad_token_id = tokenizer.eos_token_id
# max_length caps the total token count, prompt included
pipe = pipeline(task="text-generation", model=model_pretrained, tokenizer=tokenizer, max_length=100)
def build_prompt(question):
    prompt = f"[INST]@Enlighten. {question} [/INST]"
    return prompt
question = "What does abutment of the nerve root mean?"
prompt = build_prompt(question)
# Generate text based on the prompt
result = pipe(prompt)[0]
generated_text = result['generated_text']
# Remove the prompt from the generated text
generated_text = generated_text.replace(prompt, "", 1).strip()
print(generated_text)
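Removing the prompt with `str.replace` only works when the pipeline echoes the prompt back character for character. As a hedged alternative, the sketch below splits on the closing `[/INST]` tag instead. `extract_answer` is a hypothetical helper, and it assumes the generated text follows the prompt format produced by `build_prompt` above. The sample string is made up for illustration and is not real model output.

```python
def build_prompt(question):
    # Same prompt format as in the example above.
    return f"[INST]@Enlighten. {question} [/INST]"


def extract_answer(generated_text, end_tag="[/INST]"):
    """Return the text after the last end_tag, or the whole text if the tag is absent."""
    # rpartition yields ("", "", text) when the tag is not found.
    _, _, answer = generated_text.rpartition(end_tag)
    return answer.strip()


# Illustrative input only: a prompt followed by a placeholder completion.
sample = build_prompt("What does abutment of the nerve root mean?") + " PLACEHOLDER ANSWER"
print(extract_answer(sample))
```

The transformers text-generation pipeline also accepts `return_full_text=False`, which returns only the newly generated text and may make this post-processing unnecessary.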
Framework versions
- PEFT 0.10.0