# Text Generation sample for the "mingkuan/longchat-7b-qlora-customer-support" adapter
# (author: mingkuan; commit 6182537, "update sample run.").
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from llama_condense_monkey_patch import replace_llama_with_condense
from peft import PeftConfig
from peft import PeftModel
import torch
# --- Configure device/precision parameters and load the fine-tuned model ---
peft_model_id = "mingkuan/longchat-7b-qlora-customer-support"
base_model_id = "lmsys/longchat-7b-16k"

# Read the condense ratio from the base model's config and install the RoPE
# condense monkey patch. This is done before the base model is instantiated
# below (presumably required so the patched implementation is picked up).
config = AutoConfig.from_pretrained(base_model_id)
replace_llama_with_condense(config.rope_condense_ratio)

# use_fast=False requests the Python (non-"fast") tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=False)

# Extra loading options: fp16 for non-quantized weights, automatic placement.
# NOTE(review): non-quantized weights are fp16 while the 4-bit matmuls compute
# in bf16 (see nf4_config) — confirm this mix is intentional.
kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}

# 4-bit NF4 quantization with nested (double) quantization; compute in bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# `load_in_4bit` is already carried by `nf4_config`. Passing it again as a
# separate kwarg together with `quantization_config` is redundant and is
# rejected with a ValueError by recent transformers releases, so only the
# config object is passed here.
# NOTE(review): trust_remote_code=True allows the hub repo to execute code;
# keep only if the base repo actually ships custom modeling code.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    return_dict=True,
    trust_remote_code=True,
    quantization_config=nf4_config,
    **kwargs,
)
# Wrap the quantized base model with the LoRA adapter weights.
model = PeftModel.from_pretrained(model, peft_model_id)
def generate_prompt(query):
    """Wrap a user query in the instruction prompt used by this adapter.

    The prompt contains a fixed task description and one worked example,
    followed by an ``### Input:`` section holding *query* and a trailing
    newline.

    Args:
        query: The raw user request; it is formatted with ``str.format``
            semantics (as in an f-string).

    Returns:
        The complete prompt string to feed to the model.
    """
    prompt_lines = [
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
        "### Instruction:",
        "You are a customer support assistant that can extract user request intent and category, and then provide appropriate answers. If the user input is related to customer support domain, please try to generate a json string that contains extracted category and intent, and the proper response.",
        "If user input is unrelated to customer support domain, please try to answer it in natural language.",
        "Example run:",
        "Input: Would it be possible to cancel the order I made?",
        'Output: "Category": "ORDER", "Intent": "cancel_order", "Answer": "Sure, I definitely can help you with that. Can you provide me your order number for the cancelation?"',
        "### Input:",
        f"{query}",
        "",  # yields the trailing newline after the query
    ]
    return "\n".join(prompt_lines)
def getLLMResponse(prompt, max_new_tokens=256, temperature=0.5):
    """Generate a completion for *prompt* with the module-level model/tokenizer.

    Args:
        prompt: Full prompt text, e.g. the output of ``generate_prompt``.
        max_new_tokens: Upper bound on the number of newly generated tokens.
        temperature: Sampling temperature. A falsy value (``0`` / ``None``)
            selects greedy decoding instead of sampling.

    Returns:
        Only the newly generated text, with special tokens removed; the
        prompt itself is not included.
    """
    # Put the inputs on the model's device instead of hard-coding .cuda():
    # the model was loaded with device_map="auto".
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature:
        # `temperature` only takes effect when sampling; without
        # do_sample=True transformers ignores it (greedy decoding) and warns.
        generation_kwargs.update(do_sample=True, temperature=temperature)

    output = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        **generation_kwargs,
    )

    # Drop the prompt by token count rather than by character count of the
    # decoded text: decode(encode(prompt)) is not guaranteed to reproduce
    # `prompt` exactly, so slicing the decoded string by len(prompt) can cut
    # into (or leave behind part of) the boundary.
    prompt_token_count = encoded["input_ids"].shape[-1]
    new_tokens = output[0][prompt_token_count:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
if __name__ == "__main__":
    # Demo run: send one customer-support style request through the adapter.
    # Guarded so that importing this module (to reuse generate_prompt /
    # getLLMResponse) does not trigger a generation run.
    query = 'help me to setup a new shipping address?'
    response = getLLMResponse(generate_prompt(query))
    print(f'\nUserInput:{query}\n\nLLM:\n{response}\n\n')