import warnings

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

warnings.filterwarnings("ignore")

model_name = "google/gemma-2b"
adapters_name = "./lora_weights"

# 4-bit NF4 quantization so the base model fits comfortably on a single GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

hf_token = "YOUR_TOKEN_HERE"

tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

# Load the quantized base model onto GPU 0, attach the LoRA adapters,
# then merge them into the base weights for plain inference.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0},
    token=hf_token,
)
model = PeftModel.from_pretrained(model, adapters_name)
model = model.merge_and_unload()


def format_query(query):
    # Wrap the user query in the same Instruction/Output template used during fine-tuning.
    text = f"Instruction: {query} \n\n Output: "
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=120)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
    # Keep only the model's answer: take the text after "Output:", then cut it at the
    # end-of-sequence token and at any hallucinated follow-up "Instruction:".
    answer = decoded.split("Output:")[1]
    answer = answer.split(tokenizer.eos_token)[0]
    answer = answer.split("Instruction:")[0]
    return answer.strip()


if __name__ == "__main__":
    # Simple interactive loop: type a prompt, get the model's completion.
    while True:
        query = input("> ")
        result = format_query(query)
        print(f"Result: {result}")
        print("=" * 100)