Spaces:
Sleeping
Sleeping
File size: 6,496 Bytes
17a7426 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import sys
import time
from pathlib import Path
from typing import Literal, Optional
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from generate.base import generate
from lit_gpt import Tokenizer
from lit_gpt.lora import GPT, Block, Config, merge_lora_weights
from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, gptq_quantization, lazy_load
from scripts.prepare_alpaca import generate_prompt
# LoRA settings forwarded by `main` to `Config.from_json` as `r`, `alpha` and `dropout`.
# NOTE(review): presumably these must match the values used when the adapter in `lora_path`
# was trained (see `finetune/lora.py`) — confirm before changing.
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
# Boolean switches forwarded by `main` to `Config.from_json` as `to_query`, `to_key`, `to_value`,
# `to_projection`, `to_mlp` and `to_head` respectively.
# NOTE(review): presumably each one selects whether that layer carries LoRA weights — confirm
# against `lit_gpt.lora.Config`.
lora_query = True
lora_key = False
lora_value = True
lora_projection = False
lora_mlp = False
lora_head = False
def main(
    prompt: str = "What food do llamas eat?",
    input: str = "",
    lora_path: Path = Path("out/lora/alpaca/lit_model_lora_finetuned.pth"),
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None,
    max_new_tokens: int = 100,
    top_k: Optional[int] = 200,
    temperature: float = 0.8,
    strategy: str = "auto",
    devices: int = 1,
    precision: Optional[str] = None,
) -> None:
    """Generates a response based on a given instruction and an optional input.

    This script will only work with checkpoints from the instruction-tuned GPT-LoRA model.
    See `finetune/lora.py`.

    The response text is printed to stdout; progress, timing and memory statistics are printed to stderr.

    Args:
        prompt: The prompt/instruction (Alpaca style).
        input: Optional input (Alpaca style).
        lora_path: Path to the checkpoint with trained adapter weights, which are the output of
            `finetune/lora.py`.
        checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights.
        quantize: Whether to quantize the model and using which method:
            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
            - bnb.int8: 8-bit quantization from bitsandbytes
            - gptq.int4: 4-bit quantization from GPTQ
            for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
        max_new_tokens: The number of generation steps to take.
        top_k: The number of top most probable tokens to consider in the sampling process.
        temperature: A value controlling the randomness of the sampling process. Higher values result in more random
            samples.
        strategy: Indicates the Fabric strategy setting to use.
        devices: How many devices to use.
        precision: Indicates the Fabric precision setting to use.

    Raises:
        NotImplementedError: If `quantize` is set together with `devices > 1`.
        ValueError: If a `bnb.*` quantization is combined with a mixed precision setting, or if
            `quantize == "gptq.int4"` and `lit_model_gptq.4bit.pth` is missing from `checkpoint_dir`.
        KeyError: If a `bnb.*` quantization is requested with a precision other than `16-true`,
            `bf16-true` or `32-true`.
        IndexError: If the decoded output does not contain the `### Response:` marker.
    """
    precision = precision or get_default_supported_precision(training=False)

    plugins = None
    if quantize is not None:
        # This is the single place where the quantization/multi-device combination is rejected.
        if devices > 1:
            raise NotImplementedError(
                "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the"
                " --quantize flag."
            )
        if quantize.startswith("bnb."):
            if "mixed" in precision:
                raise ValueError("Quantization and mixed precision is not supported.")
            dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
            # `quantize[4:]` drops the "bnb." prefix to obtain the bitsandbytes mode name.
            plugins = BitsandbytesPrecision(quantize[4:], dtype)
            # The plugin now carries the dtype, so no separate precision setting is handed to Fabric.
            precision = None

    if strategy == "fsdp":
        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)

    fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
    fabric.launch()

    check_valid_checkpoint_dir(checkpoint_dir)

    config = Config.from_json(
        checkpoint_dir / "lit_config.json",
        r=lora_r,
        alpha=lora_alpha,
        dropout=lora_dropout,
        to_query=lora_query,
        to_key=lora_key,
        to_value=lora_value,
        to_projection=lora_projection,
        to_mlp=lora_mlp,
        to_head=lora_head,
    )

    # GPTQ uses its own pre-quantized weight file; every other mode loads the regular checkpoint.
    if quantize == "gptq.int4":
        model_file = "lit_model_gptq.4bit.pth"
        if not (checkpoint_dir / model_file).is_file():
            raise ValueError("Please run `python quantize/gptq.py` first")
    else:
        model_file = "lit_model.pth"
    checkpoint_path = checkpoint_dir / model_file

    tokenizer = Tokenizer(checkpoint_dir)
    sample = {"instruction": prompt, "input": input}
    prompt = generate_prompt(sample)
    encoded = tokenizer.encode(prompt, device=fabric.device)
    prompt_length = encoded.size(0)
    max_returned_tokens = prompt_length + max_new_tokens

    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
        model = GPT(config)
    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
    with fabric.init_tensor():
        # set the max_seq_length to limit the memory usage to what we need
        model.max_seq_length = max_returned_tokens
        # enable the kv cache
        model.set_kv_cache(batch_size=1)
    model.eval()

    t0 = time.perf_counter()
    checkpoint = lazy_load(checkpoint_path)
    lora_checkpoint = lazy_load(lora_path)
    # The adapter file may either nest its weights under a "model" key or be the flat state dict itself.
    checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
    model.load_state_dict(checkpoint)
    fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

    merge_lora_weights(model)
    model = fabric.setup(model)

    # Fixed seed so that repeated runs sample the same tokens.
    L.seed_everything(1234)
    t0 = time.perf_counter()
    y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
    t = time.perf_counter() - t0

    output = tokenizer.decode(y)
    # The decoded text still contains the prompt; keep only what follows the response marker.
    output = output.split("### Response:")[1].strip()
    fabric.print(output)

    tokens_generated = y.size(0) - prompt_length
    fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
if __name__ == "__main__":
    # Only needed for the command-line entry point, so it is imported here rather than at module level.
    import jsonargparse

    # Set the float32 matmul precision mode to "high" before any model code runs.
    torch.set_float32_matmul_precision("high")

    # Expose `main`'s parameters as command-line arguments and invoke it.
    jsonargparse.CLI(main)
|