import warnings

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

warnings.filterwarnings("ignore")
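
# Base checkpoint on the Hugging Face Hub and the local directory holding the trained LoRA adapter weights.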
model_name = "google/gemma-2b"
adapters_name = "./lora_weights"

# Quantize the base model to 4-bit NF4 with bfloat16 compute to reduce GPU memory use.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
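
# Hugging Face access token; the Gemma checkpoints are gated, so a valid token is needed to download them.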
hf_token = "YOUR_TOKEN_HERE"

tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0},
    token=hf_token,
)

# Load the LoRA adapter on top of the base model and merge its weights in,
# so generation runs on a single merged model without the adapter layers.
model = PeftModel.from_pretrained(model, adapters_name)
model = model.merge_and_unload()
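

# format_query() wraps the user query in the Instruction/Output prompt template,
# runs generation on the GPU, and returns only the answer text.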
def format_query(query):
    text = f"Instruction: {query} \n\n Output: "

    device = "cuda:0"
    inputs = tokenizer(text, return_tensors="pt").to(device)

    outputs = model.generate(**inputs, max_new_tokens=120)
    # Decode with special tokens kept so "<eos>" remains available as a delimiter,
    # then keep only the text between "Output:" and the first end marker.
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
    return decoded.split("Output:")[1].split("<eos>")[0].split("Instruction:")[0]
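

# Minimal interactive loop: read a query from stdin, generate, and print the result.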
if __name__ == "__main__":
    while True:
        query = input("> ")
        result = format_query(query)
        print(f"Result: {result}")
        print("=" * 100)