# Vicuna_ChatBot / app.py
# Gradio demo for the RinInori/vicuna_finetuned_6_sentiments model.
# Source revision: commit 77a2725 ("Update app.py").
import torch
from peft import PeftModel
import transformers
import gradio as gr
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from transformers import Trainer
# Hugging Face Hub id of the checkpoint this app serves; both the model and
# its tokenizer are loaded from it.
BASE_MODEL = "RinInori/vicuna_finetuned_6_sentiments"

# Loading options: 8-bit quantized weights, float16 for the non-quantized
# parts, automatic device placement, and ./cache as the folder for any
# weights that device_map decides to offload to disk.
_model_load_kwargs = dict(
    torch_dtype=torch.float16,
    load_in_8bit=True,
    device_map="auto",
    offload_folder="./cache",
)
model = LlamaForCausalLM.from_pretrained(BASE_MODEL, **_model_load_kwargs)

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
# Pad with token id 0 (presumably <unk> in the LLaMA vocabulary -- confirm).
# Padding goes on the left, the usual side for decoder-only generation.
# generate_text currently sends one unpadded prompt at a time.
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"
def format_prompt(prompt: str) -> str:
    """Wrap a user message in the single-turn "### Human / ### Assistant" template.

    Args:
        prompt: The raw user message.

    Returns:
        The templated text. It ends with "### Assistant:" so the model
        continues with its reply.
    """
    template = "### Human: {}\n### Assistant:"
    return template.format(prompt)
# Decoding settings shared by every request.
# do_sample=True is required for `temperature` to have any effect. Without it,
# transformers decodes greedily, ignores temperature=0.2, and warns about it.
# A low temperature keeps sampling close to greedy.
generation_config = GenerationConfig(
    max_new_tokens=128,
    do_sample=True,
    temperature=0.2,
    repetition_penalty=1.0,  # 1.0 = no repetition penalty
)
def generate_text(prompt: str) -> str:
    """Return the model's reply to a single user message.

    The message is wrapped with ``format_prompt``. The reply is generated with
    the module-level ``model``, ``tokenizer`` and ``generation_config``.

    Args:
        prompt: Raw user text, as passed in by the Gradio text box.

    Returns:
        The assistant's reply with surrounding whitespace removed. If the model
        starts inventing another "### Human:" turn, that text is dropped.
    """
    formatted_prompt = format_prompt(prompt)
    inputs = tokenizer(
        formatted_prompt,
        padding=False,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)
    # For a causal LM, generate() returns the prompt tokens followed by the
    # continuation, so record how many tokens the prompt occupies.
    prompt_length = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        tokens = model.generate(**inputs, generation_config=generation_config)
    # Decode only the newly generated tokens. Searching the full decoded text
    # for "### Assistant:" has two failure modes:
    # - it slices from a bogus offset when the marker is absent (find() == -1);
    # - it matches the user's own text when the prompt contains the marker.
    reply = tokenizer.decode(tokens[0, prompt_length:], skip_special_tokens=True)
    # The app is single-turn: keep only the text before any next-turn marker.
    reply, _, _ = reply.partition("### Human:")
    return reply.strip()
# Text-in / text-out web UI that routes each submission through generate_text,
# then starts the Gradio server.
iface = gr.Interface(
    generate_text,
    "text",
    "text",
    title="Chatbot",
    description=(
        "This vicuna app is using this model: "
        "https://huggingface.co/RinInori/vicuna_finetuned_6_sentiments"
    ),
)
iface.launch()