# CyberChitChat / Inference.py
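"""Inference script for the CyberChitChat LoRA adapter.

Loads the google/gemma-2b base model in 4-bit NF4, applies the LoRA weights
from ./lora_weights, merges them into the base model, and runs an
interactive prompt loop on the command line.
"""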
import warnings
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
warnings.filterwarnings("ignore")
model_name = "google/gemma-2b"
adapters_name = "./lora_weights"
# Load the base model in 4-bit NF4 to reduce GPU memory use
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
hf_token = "YOUR_TOKEN_HERE"  # replace with your Hugging Face access token
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, device_map={"": 0}, token=hf_token
)
# Apply the LoRA adapters and merge them into the base weights
model = PeftModel.from_pretrained(model, adapters_name)
model = model.merge_and_unload()
def format_query(query):
    # Wrap the query in the instruction/output template the adapter was trained on
    text = f"Instruction: {query} \n\n Output: "
    device = "cuda:0"
    inputs = tokenizer(text, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=120)
    # Keep special tokens so the reply can be cut at <eos>, then strip the prompt
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
    return decoded.split("Output:")[1].split("<eos>")[0].split("Instruction:")[0]
if __name__ == "__main__":
    # Simple interactive loop: read a prompt, generate, print
    while True:
        query = input("> ")
        result = format_query(query)
        print(f"Result: {result}")
        print("=" * 100)