import marimo

__generated_with = "0.9.14"
app = marimo.App(width="medium")

@app.cell
def __():
    import marimo as mo
    import os

    from huggingface_hub import InferenceClient
    return InferenceClient, mo, os

@app.cell
def __():
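    # Model repo id on the Hugging Face Hub; any chat-completion model
    # served by the Inference API should work here.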
    MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
    return (MODEL_NAME,)

@app.cell(hide_code=True)
def __(MODEL_NAME, mo):
    mo.md(f"""
    # Chat with **{MODEL_NAME}**
    """)
    return

@app.cell
def __(max_tokens, mo, system_message, temperature, top_p):
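    # Lay out the system prompt beside the sampling controls.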
    mo.hstack(
        [
            system_message,
            mo.vstack([temperature, top_p, max_tokens], align="end"),
        ],
    )
    return

@app.cell
def __(mo, respond):
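    # mo.ui.chat calls respond() with the running message history on every
    # user turn; {{number}} in the second sample prompt is a template
    # placeholder the user fills in.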
    chat = mo.ui.chat(
        model=respond,
        prompts=["Tell me a joke.", "What is the square root of {{number}}?"],
    )
    chat
    return (chat,)

@app.cell
def __(InferenceClient, MODEL_NAME, os):
    """
    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.26.2/en/guides/inference
    """
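    # The token is optional: calls may still go through without HF_TOKEN,
    # but anonymous access can be rate-limited, as the warning below notes.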
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        print("HF_TOKEN not set, may have limited access.")

    client = InferenceClient(
        MODEL_NAME,
        token=hf_token,
    )
    return client, hf_token

@app.cell
def __(client, mo):
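    # UI controls for the system prompt and sampling parameters; respond()
    # reads their .value on every call.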
    system_message = mo.ui.text_area(
        value="You are a friendly Chatbot.",
        label="System message",
    )
    max_tokens = mo.ui.slider(
        start=1,
        stop=2048,
        value=512,
        step=1,
        label="Max new tokens",
        show_value=True,
    )
    temperature = mo.ui.slider(
        start=0.1,
        stop=4.0,
        value=0.7,
        step=0.1,
        label="Temperature",
        show_value=True,
    )
    top_p = mo.ui.slider(
        start=0.1,
        stop=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
        show_value=True,
    )

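    # respond() bridges marimo's chat UI and the Hugging Face client: it
    # prepends the system prompt, converts each turn to text parts (plus any
    # image attachments), and returns a single non-streamed chat completion.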
    def respond(messages: list[mo.ai.ChatMessage], config):
        chat_messages = [{"role": "system", "content": system_message.value}]

        for message in messages:
            parts = []
            parts.append({"type": "text", "text": message.content})

            if message.attachments:
                for attachment in message.attachments:
                    content_type = attachment.content_type or ""

                    if content_type.startswith("image"):
                        parts.append(
                            {
                                "type": "image_url",
                                "image_url": {"url": attachment.url},
                            }
                        )
                    else:
                        raise ValueError(
                            f"Unsupported content type {content_type}"
                        )

            chat_messages.append({"role": message.role, "content": parts})

        response = client.chat_completion(
            chat_messages,
            max_tokens=max_tokens.value,
            temperature=temperature.value,
            top_p=top_p.value,
            stream=False,
        )
        return response.choices[0].message.content

    return max_tokens, respond, system_message, temperature, top_p

@app.cell
def __():
    return

if __name__ == "__main__":
    app.run()