Spaces:
Running
Running
File size: 2,269 Bytes
4eaa76b 7e12b4f 4eaa76b 786bb0f 4527045 adec948 4527045 7e12b4f 786bb0f 7e12b4f 6003553 7e12b4f 786bb0f 4527045 7e12b4f 4527045 7e12b4f 4527045 7e12b4f 4527045 7e12b4f 4959bf1 4527045 6e74755 4959bf1 4527045 6e74755 7377b55 7e12b4f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import os
import gradio as gr
from openai import OpenAI
from optillm.moa import mixture_of_agents
from optillm.mcts import chat_with_mcts
from optillm.bon import best_of_n_sampling
# Hugging Face access token taken from the HF_TOKEN environment variable
# (None when the variable is unset); respond() hands it to the OpenAI client.
API_KEY = os.getenv("HF_TOKEN")
def respond(
    message: str,
    history: list[tuple[str, str]],
    model: str,
    approach: str,
    system_message: str,
    max_tokens,
    temperature,
    top_p,
):
    """Answer one chat turn with the optillm approach selected in the UI.

    Args:
        message: The latest user message; passed to the approach as the query.
        history: Previous (user, assistant) turns supplied by gr.ChatInterface.
            NOTE(review): not forwarded -- the optillm entry points called
            here take only a system prompt and a single query.
        model: Hugging Face model id. It is used in the Inference API base URL
            and also passed to the approach as the model name.
        approach: Which optillm technique to run: "bon" (best-of-n sampling),
            "mcts" (Monte Carlo tree search) or "moa" (mixture of agents).
        system_message: System prompt passed to the approach.
        max_tokens, temperature, top_p: Supplied by the UI sliders but not
            forwarded; the call signature used here has no slot for them.

    Returns:
        Whatever the selected optillm function returns (presumably the final
        response text -- confirm against the installed optillm version).

    Raises:
        ValueError: If ``approach`` is not one of the supported keys.
    """
    # Map the UI's approach identifiers to the optillm entry points. All three
    # are called with the same positional arguments below.
    approaches = {
        "bon": best_of_n_sampling,
        "mcts": chat_with_mcts,
        "moa": mixture_of_agents,
    }
    try:
        run_approach = approaches[approach]
    except KeyError:
        raise ValueError(
            f"Unknown approach {approach!r}; expected one of {sorted(approaches)}"
        ) from None

    # The HF Inference API exposes an OpenAI-compatible endpoint per model.
    client = OpenAI(
        api_key=API_KEY,
        base_url=f"https://api-inference.huggingface.co/models/{model}/v1",
    )
    return run_approach(system_message, message, client, model)
# For information on how to customize the ChatInterface, see the Gradio docs:
# https://www.gradio.app/docs/chatinterface
#
# gr.ChatInterface calls respond(message, history, *additional_inputs), so the
# order of additional_inputs below must match respond's parameters after
# (message, history): model, approach, system_message, max_tokens,
# temperature, top_p.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Dropdown(
            [
                "meta-llama/Meta-Llama-3.1-70B-Instruct",
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
                "HuggingFaceH4/zephyr-7b-beta",
            ],
            value="meta-llama/Meta-Llama-3.1-70B-Instruct",
            label="Model",
            info="Choose the base model",
        ),
        gr.Dropdown(
            ["bon", "mcts", "moa"],
            value="moa",
            label="Approach",
            info="Choose the approach",
        ),
        gr.Textbox(value="", label="System message"),
        # NOTE(review): the three sliders below are passed to respond(), which
        # does not currently forward them to the optillm approaches.
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
# Launch the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()