File size: 2,644 Bytes
1c776f7 5ee0c02 2c4a7ea dd48380 5ee0c02 dd48380 2c4a7ea 1c776f7 dd48380 1c776f7 2c4a7ea dd48380 5e66ec0 5ee0c02 5e66ec0 5ee0c02 5e69eac f3c0178 5e69eac 5ee0c02 5e69eac 5ee0c02 1c776f7 dd48380 5e66ec0 1c776f7 5e66ec0 1c776f7 dd48380 1c776f7 dd48380 fbcf846 dd48380 2c4a7ea dd48380 2c4a7ea 1c776f7 5ee0c02 1c776f7 dd48380 73bf78a 5e66ec0 d8e275b 5e66ec0 d8e275b 5e66ec0 5ee0c02 5e66ec0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import os
import time
import requests
import gradio as gr
from huggingface_hub import get_inference_endpoint
# --- Deployment configuration, injected via environment variables (e.g. Space secrets) ---
endpoint_name = os.getenv('ENDPOINT_NAME')          # HF Inference Endpoint name (for status/wake-up control)
endpoint_url = os.getenv('ENDPOINT_URL')            # direct URL used for generation POSTs
personal_secret_token = os.getenv('PERSONAL_HF_TOKEN')  # HF token with access to the endpoint
# Prompt-format pieces; model-specific, kept out of the code on purpose.
turn_breaker = os.getenv('TURN_BREAKER')            # separator inserted between turns
system_symbol = os.getenv('SYSTEM_SYMBOL')          # prefix marking a system message
user_symbol = os.getenv('USER_SYMBOL')              # prefix marking a user turn
assistant_symbol = os.getenv('ASSISTANT_SYMBOL')    # prefix marking an assistant turn
# NOTE(review): all of the above are None if the variable is unset — later string
# concatenation/join would then raise TypeError; assumed to always be configured.
headers = {
	"Accept" : "application/json",
	"Authorization": f"Bearer {personal_secret_token}",
	"Content-Type": "application/json"
}
def query(payload):
    """POST *payload* as JSON to the inference endpoint and return the decoded JSON reply.

    Args:
        payload: JSON-serializable dict, e.g. {"inputs": ..., "parameters": {...}}.

    Returns:
        The endpoint's JSON response (shape depends on the deployed model server).

    Raises:
        requests.HTTPError: if the endpoint answers with a 4xx/5xx status.
        requests.Timeout: if connecting or reading exceeds the timeout below.
    """
    # timeout: without it a stuck endpoint would hang the Gradio worker forever.
    # (10s to connect, 120s to generate — generation can legitimately be slow.)
    response = requests.post(endpoint_url, headers=headers, json=payload, timeout=(10, 120))
    # Fail loudly on HTTP errors instead of returning an error JSON that would
    # be rendered into the chat as if it were a model reply.
    response.raise_for_status()
    return response.json()
endpoint = get_inference_endpoint(endpoint_name, token=personal_secret_token)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_new_tokens,
    temperature,
    top_p,
    progress=gr.Progress()
):
    """Gradio ChatInterface callback: build the prompt from the chat history and
    query the inference endpoint, waking it up first if it is not running.

    Args:
        message: the new user message.
        history: prior (user, assistant) turn pairs from ChatInterface.
        system_message: system prompt text (from the additional-inputs textbox).
        max_new_tokens / temperature / top_p: generation parameters.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The endpoint's JSON response (passed straight through to the UI).
    """
    progress(0, desc="Starting")
    # Scale-to-zero endpoints must be poked awake before they accept real work.
    if endpoint.status != "running":
        endpoint.client.post(json={
            "inputs": "wake up!"
        })
        progress(0.25, desc="Waking up model")
        while endpoint.status != "running":
            time.sleep(1)
            # BUG FIX: `.status` reflects locally cached endpoint data; without
            # re-fetching, the original loop could never observe the transition
            # to "running" and would spin forever.
            endpoint.fetch()
    progress(0.5, desc="Generating")
    # Rebuild the whole conversation as symbol-prefixed turns.
    # FIX: system_symbol was loaded from the environment but never applied;
    # by symmetry with user/assistant turns it must prefix the system message.
    all_messages = [system_symbol + system_message]
    for user_turn, assistant_turn in history:
        if user_turn:
            all_messages.append(user_symbol + user_turn)
        if assistant_turn:
            all_messages.append(assistant_symbol + assistant_turn)
    all_messages.append(user_symbol + message)
    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        # Sampling only makes sense with a positive temperature; greedy otherwise.
        do_sample=temperature > 0,
        top_p=top_p,
        temperature=temperature
    )
    response = query({
        "inputs": turn_breaker.join(all_messages),
        "parameters": generation_kwargs
    })
    progress(1, desc="Generating")
    # NOTE(review): most TGI endpoints return [{"generated_text": ...}] — the raw
    # JSON is returned here as in the original; extracting the text field may be
    # desirable but the response schema cannot be confirmed from this file.
    return response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Extra controls shown under the chat box; defaults mirror the deployed model's
# expected persona and conservative sampling settings.
_system_prompt_box = gr.Textbox(
    value="请你扮演一个开心,积极的角色,名叫贺英旭,今年26岁,工作是程序员。你需要以这个身份和朋友们进行对话。",
    label="System message",
)
_max_tokens_slider = gr.Slider(minimum=1, maximum=2048, value=64, step=1, label="Max new tokens")
_temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
_top_p_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.7,
    step=0.05,
    label="Top-p (nucleus sampling)",
)

# Chat UI wired to respond(); progress bar shown during wake-up/generation.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        _system_prompt_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
    show_progress="full"
)

if __name__ == "__main__":
    demo.launch()