Spaces:
Sleeping
Sleeping
File size: 3,712 Bytes
cc5b602 6f619d7 6386510 c4592e6 51a7d9e 6386510 6d1d1e9 51a7d9e 6386510 e6367a7 6d1d1e9 51a7d9e bd34f0b d6256ce 6386510 bd34f0b 6d1d1e9 bd34f0b 51a7d9e 6386510 51a7d9e bd34f0b 51a7d9e da59244 6386510 6d1d1e9 0fffdb4 6d1d1e9 6386510 bbd8145 4ed884e 285cc01 c4592e6 6d1d1e9 4ed884e c4592e6 285cc01 27dc368 51a7d9e 6386510 51a7d9e 82b38de 51a7d9e 4ed884e 51a7d9e 6d1d1e9 51a7d9e bd34f0b 4ed884e bd34f0b 4ed884e bd34f0b 51a7d9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import os
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
# Candidate checkpoints; the first entry is used when MODEL_ID is not configured.
MODEL_LIST = ["openbmb/MiniCPM-1B-sft-bf16", "openbmb/MiniCPM-S-1B-sft"]
# Optional Hugging Face access token (None when the env var is unset).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Hub repo id to load. Without a fallback, an unset MODEL_ID would be None and
# the .split() below would raise AttributeError at import time; `or` also
# treats an empty string as unset.
MODEL_ID = os.environ.get("MODEL_ID") or MODEL_LIST[0]
# Repo name without the owner prefix, shown in the page header.
MODEL_NAME = MODEL_ID.split("/")[-1]
# Page title (raw HTML rendered by gr.HTML).
TITLE = "<h1><center>MiniCPM-1B-chat</center></h1>"
# Subtitle linking to the model's Hub page.
DESCRIPTION = f"""
<h3>MODEL NOW: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></h3>
"""
# Text shown inside the empty chatbot before the first message.
PLACEHOLDER = """
<center>
<p>MiniCPM is an End-Side LLM developed by ModelBest Inc. and TsinghuaNLP, with only 1.2B parameters excluding embeddings.</p>
</center>
"""
# Custom CSS: styles the duplicate button and centers <h3> headings.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""
# Load the checkpoint and tokenizer once, at import time.
# trust_remote_code=True executes modeling code shipped in the Hub repo;
# presumably that is where `model.chat` (used by stream_chat) comes from — verify.
# HF_TOKEN (read from the environment, may be None) is forwarded explicitly so
# the token the module declares is the one actually used for downloads.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map='auto',
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    token=HF_TOKEN,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
def _to_chat_messages(history):
    """Convert Gradio chat history into a fresh list of role/content dicts.

    Accepts both Gradio history shapes: ``[user, assistant]`` pairs (the
    classic "tuples" format) and ``{"role": ..., "content": ...}`` dicts
    ("messages" format). ``None`` entries in a pair are skipped.

    NOTE(review): assumes ``model.chat`` (remote MiniCPM code) expects a list
    of role/content dicts and may append to it in place — returning a new
    list keeps Gradio's own history object untouched. Confirm against the
    model repo's modeling code.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_msg, assistant_msg = turn
        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def stream_chat(
    message: str,
    history: list,
    temperature: float = 0.8,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    penalty: float = 1.2
):
    """Generate one (non-streamed) assistant reply for gr.ChatInterface.

    Args:
        message: The new user message.
        history: Previous turns as supplied by Gradio (pairs or role dicts).
        temperature: Sampling temperature; 0 selects greedy decoding.
        max_new_tokens: Upper bound on tokens generated for this reply
            (independent of how long the conversation already is).
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.
        penalty: Repetition penalty; values <= 0 are not forwarded because
            transformers only accepts strictly positive penalties.

    Returns:
        The reply text (first element of the tuple returned by model.chat).
    """
    print(f'message: {message}')
    print(f'history: {history}')

    # model.chat already forwards extra keyword args (e.g. top_k) to
    # generation, so max_new_tokens and repetition_penalty travel the same way.
    extra_gen_kwargs = {}
    if penalty > 0:
        # transformers requires the penalty to be a float, not an int.
        extra_gen_kwargs["repetition_penalty"] = float(penalty)

    resp, _ = model.chat(
        tokenizer,
        query=message,
        history=_to_chat_messages(history),
        # Previously passed as max_length, which counts the prompt too, so
        # long conversations left little or no room for the reply.
        max_new_tokens=int(max_new_tokens),
        do_sample=temperature > 0,
        top_p=top_p,
        # Slider values may arrive as floats; top-k must be an int.
        top_k=int(top_k),
        temperature=temperature,
        **extra_gen_kwargs,
    )
    return resp
# Chat display component shared with the ChatInterface below.
chatbot = gr.Chatbot(placeholder=PLACEHOLDER, height=600)

with gr.Blocks(theme="soft", css=CSS) as demo:
    # Page header: title, then the model subtitle.
    for header_html in (TITLE, DESCRIPTION):
        gr.HTML(header_html)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button",
    )

    # Collapsible container for the generation controls; created before the
    # sliders to keep the same component creation order.
    params_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)

    # (label, minimum, maximum, step, default). The order must match the
    # extra positional parameters of stream_chat after `history`.
    slider_specs = [
        ("Temperature", 0, 1, 0.1, 0.8),
        ("Max Length", 128, 8192, 1, 1024),
        ("top_p", 0.0, 1.0, 0.1, 1.0),
        ("top_k", 1, 20, 1, 20),
        ("Repetition penalty", 0.0, 2.0, 0.1, 1.2),
    ]
    param_sliders = [
        gr.Slider(
            minimum=low,
            maximum=high,
            step=increment,
            value=default,
            label=label,
            render=False,
        )
        for label, low, high, increment, default in slider_specs
    ]

    # Clickable starter prompts shown under the chat box.
    example_prompts = [
        "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
        "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
        "Tell me a random fun fact about the Roman Empire.",
        "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
    ]

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=params_accordion,
        additional_inputs=param_sliders,
        examples=[[prompt] for prompt in example_prompts],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()
|