# Hugging Face Space: Falcon-based "Fetch" chatbot (Gradio app).
# NOTE(review): the Space showed "Runtime error" — see query_model, where the
# response-parsing line was commented out, leaving `parsed_result` undefined.
import json
import os
from pprint import pprint

import bitsandbytes as bnb
import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
from huggingface_hub import notebook_login
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
# 4-bit NF4 quantization so the Falcon base model fits in limited GPU memory.
# (Fix: stripped the " | |" extraction artifacts that made this invalid Python.)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,  # second quantization pass on the quant constants
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# LoRA adapter repo on the Hub; its config points at the Falcon base model.
PEFT_MODEL = "cdy3870/Falcon-Fetch-Bot"

config = PeftConfig.from_pretrained(PEFT_MODEL)

# Load the base model 4-bit quantized, sharded automatically across devices.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,  # Falcon ships custom modeling code on the Hub
)

tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token  # Falcon defines no pad token; reuse EOS.

# Attach the fine-tuned LoRA weights on top of the quantized base model.
model = PeftModel.from_pretrained(model, PEFT_MODEL)
# Decoding settings shared by every chat turn.
generation_config = model.generation_config
generation_config.max_new_tokens = 200
generation_config.temperature = 0.7
generation_config.top_p = 0.7
generation_config.num_return_sequences = 1
# Use EOS for padding too, so generation stops cleanly instead of warning
# about a missing pad token.
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id
# Text-generation pipeline wrapping the PEFT-adapted model and its tokenizer.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
def query_model(message, history):
    """Generate one chatbot reply for the Gradio ChatInterface.

    Args:
        message: The user's latest message.
        history: Prior (user, bot) turns supplied by ChatInterface; currently
            unused — each turn is answered without conversation context.

    Returns:
        The model's reply text, with the echoed prompt stripped off.
    """
    prompt = f"""
<human>: {message}
<assistant>:
""".strip()
    result = pipeline(
        prompt,
        generation_config=generation_config,
    )
    # BUG FIX: this parsing line was commented out, so `parsed_result` was
    # undefined and every request raised NameError (the Space's runtime error).
    # The pipeline echoes the prompt, so keep only the text after the
    # assistant tag; [1:] drops the leading space/newline after the colon.
    parsed_result = result[0]["generated_text"].split("<assistant>:")[1][1:]
    return parsed_result
# Launch the chat UI; query_model handles each user turn.
gr.ChatInterface(
    query_model,
    textbox=gr.Textbox(placeholder="Ask anything about Fetch!", container=False, scale=7),
).launch()