Chatbot / model_setup.py
anhvv200053's picture
Update model_setup.py
dc745aa verified
raw
history blame
847 Bytes
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# BitsAndBytes quantization settings used when loading the model:
# weights are loaded in 4-bit NF4 format, matrix computation runs in
# float16, and nested ("double") quantization of the quantization
# constants is turned off.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype="float16",
)
# Set up the model and tokenizer.
def load_model(
    model_id="anhvv200053/Vinallama-2-7B-updated1-instruction-v2",
    token=None,
):
    """Load the 4-bit quantized causal LM and its tokenizer from the Hub.

    Args:
        model_id: Hugging Face Hub repository id used for BOTH the model and
            the tokenizer. Defaults to the Vinallama instruction checkpoint
            this module was written for.
        token: Optional Hugging Face access token (needed for private or
            gated repositories). When ``None``, ``from_pretrained`` falls
            back to its default credential lookup (e.g. a token saved by
            ``huggingface-cli login``). Previously this name was referenced
            without ever being defined, so every call raised ``NameError``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # 4-bit settings from the module-level BitsAndBytesConfig.
        quantization_config=bnb_config,
        # The "" key maps the whole model onto CUDA device 0.
        device_map={"": 0},
        token=token,
    )
    # pretraining_tp=1 makes Llama-style models use plain linear layers
    # instead of emulating pretraining tensor parallelism.
    model.config.pretraining_tp = 1

    # NOTE(review): trust_remote_code=True allows executing Python code
    # shipped in the Hub repo; the model load above does not set it.
    # Confirm the tokenizer really needs it, otherwise drop it.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        use_fast=True,
        token=token,
    )
    # Reuse EOS as the padding token; presumably the tokenizer ships without
    # a dedicated pad token (typical for Llama tokenizers) — verify.
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer