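"""Smoke-test baichuan-inc/Baichuan2-13B-Chat-4bits: download the checkpoint,
load it (GPU if available, with a CPU fallback), and run a single chat turn."""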
import gc
import os
import time

import rich
import torch
from huggingface_hub import snapshot_download
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"

# Download the full snapshot once so the from_pretrained calls below read the local copy.
loc = snapshot_download(repo_id=model_name, local_dir="model")

os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()
except AttributeError:
    # time.tzset() does not exist on Windows.
    logger.warning("time.tzset() is unavailable on Windows; skipping")

# Drop any previously loaded model and reclaim its memory before reloading.
model = None
gc.collect()

logger.info("start")
has_cuda = torch.cuda.is_available()

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)

if has_cuda:
    # The -4bits checkpoint ships pre-quantized weights; passing load_in_8bit here
    # would request a second, conflicting quantization, so it is dropped.
    model = AutoModelForCausalLM.from_pretrained(
        loc,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
else:
    try:
        model = AutoModelForCausalLM.from_pretrained(
            loc,
            trust_remote_code=True,
        )
    except Exception as exc:
        logger.error(exc)
        logger.warning("The 4-bit checkpoint doesn't seem to load on CPU...")
        raise SystemExit(1) from exc

model = model.eval()

rich.print(f"{model=}")

logger.info("done")

model.generation_config = GenerationConfig.from_pretrained(model_name)

messages = []
# Prompt (Chinese): explain "温故而知新" ("review the old to understand the new").
messages.append({"role": "user", "content": "解释一下“温故而知新”"})
response = model.chat(tokenizer, messages)

rich.print(response)

logger.info(f"{response=}")
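
# A sketch of a second turn, assuming Baichuan2's remote-code chat() accepts the
# accumulated history as {"role", "content"} dicts with the reply appended under
# the "assistant" role (as in the upstream model card); the prompt is illustrative.
messages.append({"role": "assistant", "content": response})
messages.append({"role": "user", "content": "请再举一个例子"})  # "Please give another example"
followup = model.chat(tokenizer, messages)
rich.print(followup)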