Spaces:
Paused
Paused
File size: 3,954 Bytes
8de5029 a1bd8b6 8de5029 927b5de 7563a34 1874bf4 8de5029 27d5e20 8de5029 a4bb3a3 5d679d7 8de5029 b2dd6fb ca44a43 88023ca ca44a43 b5b8829 77d217a 1590544 8de5029 27d5e20 8de5029 27d5e20 8de5029 4694c68 8de5029 927b5de 1874bf4 8de5029 1874bf4 8de5029 1874bf4 edc6972 927b5de 8de5029 1874bf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import os
import math
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
import sentencepiece
title = "Welcome to Tonic's 🐋🐳Orca-2-13B!"
description = "You can use [🐋🐳microsoft/Orca-2-13b](https://huggingface.co/microsoft/Orca-2-13b) via API using Gradio by scrolling down and clicking Use 'Via API' or privately by [cloning this space on huggingface](https://huggingface.co/spaces/Tonic1/TonicsOrca2?duplicate=true) . [Join my active builders' server on discord](https://discord.gg/VqTxc76K3u). Big thanks to the HuggingFace Organisation for the Community Grant."

# Run on the first GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_name = "microsoft/Orca-2-13b"

# use_fast=False selects the SentencePiece-based "slow" tokenizer, which is
# why the `sentencepiece` package is imported at the top of the file.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# Load the weights directly in bfloat16 instead of first materialising a
# float32 copy and casting it afterwards; for a 13B-parameter model that
# roughly halves peak host memory during start-up.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model = model.to(device)
class OrcaChatBot:
    """Multi-turn chat wrapper around a causal LM using Orca-2's ChatML prompt format.

    The whole transcript is kept as a single prompt string in
    ``conversation_history`` and re-encoded on every turn.

    NOTE(review): the app creates one module-level instance for every Gradio
    visitor, so all users share one transcript, and the transcript is never
    truncated, so long sessions will eventually exceed the model's context
    window. Consider per-session state and history trimming.
    """

    def __init__(self, model, tokenizer, system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."):
        """Store the model/tokenizer pair and the system prompt for the first turn.

        Args:
            model: Object exposing ``generate(...)`` and ``device`` (here a
                ``transformers`` causal LM).
            tokenizer: Tokenizer matching ``model``; must provide ``__call__``,
                ``decode`` and ``eos_token_id``.
            system_message: Text placed in the ChatML ``system`` slot of the
                first turn only.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.system_message = system_message
        # Prompt text of all previous turns, or None before the first turn.
        self.conversation_history = None

    def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
        """Generate the assistant's reply to ``user_message`` and extend the transcript.

        Args:
            user_message: Text for the new user turn.
            temperature: Sampling temperature forwarded to ``generate``.
            max_new_tokens: Upper bound on generated tokens; coerced to ``int``
                because UI sliders may hand over floats.
            top_p: Nucleus-sampling threshold forwarded to ``generate``.
            repetition_penalty: Repetition penalty forwarded to ``generate``.

        Returns:
            Only the newly generated text (the prompt is excluded), with
            surrounding whitespace stripped.
        """
        # Build the ChatML prompt: a fresh conversation starts with the system
        # turn; later turns close the previous assistant turn and append a new
        # user turn to the stored transcript.
        if self.conversation_history is None:
            prompt = (
                f"<|im_start|>system\n{self.system_message}<|im_end|>\n"
                f"<|im_start|>user\n{user_message}<|im_end|>\n"
                "<|im_start|>assistant"
            )
        else:
            prompt = (
                self.conversation_history
                + f"<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant"
            )

        # The prompt already carries its own markup, so no BOS/EOS is added.
        inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
        target_device = self.model.device
        input_ids = inputs["input_ids"].to(target_device)
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(target_device)

        output_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=int(max_new_tokens),
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=self.tokenizer.eos_token_id,
            do_sample=True,  # sampling-based generation
        )

        # generate() returns prompt + continuation; keep only the new tokens so
        # the caller does not get the system/user text echoed back.
        new_token_ids = output_ids[0][input_ids.shape[1]:]
        reply = self.tokenizer.decode(new_token_ids, skip_special_tokens=True)

        # Extend the transcript with the plain-text reply. Decoding with special
        # tokens skipped drops markers such as a trailing EOS, so the next turn's
        # "<|im_end|>" separator is appended to clean text.
        self.conversation_history = prompt + reply
        return reply.strip()
# Single chatbot instance used by the Gradio callback below; it holds the
# running conversation transcript.
Orca_bot = OrcaChatBot(model=model, tokenizer=tokenizer)
def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty):
    """Gradio callback: forward one UI submission to the shared ``Orca_bot``.

    Args:
        user_message: Text from the "Your Message" box.
        system_message: Optional scene/character text. When non-blank it is
            prepended to the user turn (separated by a newline); it does not
            replace the bot's ChatML system prompt.
        max_new_tokens: Generation length cap from the slider; coerced to int
            because Gradio's Slider passes its value as a float.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Repetition penalty.

    Returns:
        The bot's reply text.
    """
    # A whitespace-only scene box is treated as empty so it cannot inject a
    # blank prefix line into the user turn.
    if system_message and system_message.strip():
        full_message = f"{system_message}\n{user_message}"
    else:
        full_message = user_message
    # Keyword arguments: this callback's parameter order differs from
    # OrcaChatBot.predict's, so positional passing is easy to get wrong.
    return Orca_bot.predict(
        full_message,
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
# Gradio UI. The input widgets are listed in the same order as
# gradio_predict's parameters.
iface = gr.Interface(
    fn=gradio_predict,
    title=title,
    description=description,
    inputs=[
        gr.Textbox(label="Your Message", type="text", lines=3),
        gr.Textbox(label="Introduce a Character Here or Set a Scene (system prompt)", type="text", lines=2),
        gr.Slider(label="Max new tokens", value=1200, minimum=25, maximum=4096, step=1),
        gr.Slider(label="Temperature", value=0.7, minimum=0.05, maximum=1.0, step=0.05),
        # step=0.01 keeps the 0.90 default and the 0.99 maximum on the
        # slider's step grid (with min 0.01 and step 0.05 the range input
        # would snap them to 0.91 and 0.96).
        gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.01),
        gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05),
    ],
    outputs="text",
    theme="ParityError/Anime",
)
# Launch the Gradio interface
iface.launch()