Spaces:
Running
Running
import streamlit as st | |
import requests | |
import json | |
import os | |
from dotenv import load_dotenv | |
# Pull variables such as TOGETHER_API_KEY from a .env file, if one is found;
# variables already present in the process environment are kept as-is.
load_dotenv(override=False)
def reset_conversation():
    """Wipe the stored chat state.

    Replaces both ``st.session_state.conversation`` and
    ``st.session_state.messages`` with new empty lists, so the next
    rerun starts from a blank history. Returns ``None``.
    """
    for state_key in ("conversation", "messages"):
        st.session_state[state_key] = []
# Sidebar display name -> Hugging Face Hub repository id of that model.
# Insertion order is the order the names appear in the model selectbox.
model_links = {
    "Mistral": "mistralai/Mistral-7B-Instruct-v0.2",
    "Gemma-7B": "google/gemma-7b-it",
    "Gemma-2B": "google/gemma-2b-it",
    "Zephyr-7B-β": "HuggingFaceH4/zephyr-7b-beta",
    "Nous-Hermes-2-Yi-34B": "NousResearch/Nous-Hermes-2-Yi-34B",
}
# Per-model card shown in the sidebar: a one-line description and a logo URL.
# Keys match the display names used in model_links.
model_info = {
    "Mistral": dict(
        description="The Mistral model is a Large Language Model (LLM) developed by Mistral AI.",
        logo="https://mistral.ai/images/logo_hubc88c4ece131b91c7cb753f40e9e1cc5_2589_256x0_resize_q97_h2_lanczos_3.webp",
    ),
    "Gemma-7B": dict(
        description="The Gemma-7B model is a Large Language Model (LLM) developed by Google with 7 billion parameters.",
        logo="https://pbs.twimg.com/media/GG3sJg7X0AEaNIq.jpg",
    ),
    "Gemma-2B": dict(
        description="The Gemma-2B model is a Large Language Model (LLM) developed by Google with 2 billion parameters.",
        logo="https://pbs.twimg.com/media/GG3sJg7X0AEaNIq.jpg",
    ),
    "Zephyr-7B-β": dict(
        description="The Zephyr-7B-β model is a Large Language Model (LLM) developed by HuggingFace.",
        logo="https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha/resolve/main/thumbnail.png",
    ),
    "Nous-Hermes-2-Yi-34B": dict(
        description="The Nous Hermes model is a Large Language Model (LLM) developed by Nous Research with 34 billion parameters.",
        # NOTE(review): example.com is a placeholder host, so this logo will not
        # load; replace with a real image URL.
        logo="https://example.com/nous_hermes_logo.png",
    ),
}
# Function to interact with Hugging Face models
def interact_with_huggingface_model(messages, model):
    """Placeholder for chatting with a Hugging Face-hosted model.

    Not implemented yet: both arguments are accepted but ignored, and the
    function always returns ``None`` rather than a response stream.

    Args:
        messages: Conversation history (currently unused).
        model: Hugging Face Hub repository id, e.g. a value of ``model_links``
            (currently unused).
    """
    # TODO: query the model and return an iterable of partial replies.
    return None
# Function to interact with the Together API model
def interact_with_together_api(
    messages,
    model="NousResearch/Nous-Hermes-2-Yi-34B",
    temperature=1.05,
    timeout=60,
):
    """Stream a chat completion from the Together API.

    Args:
        messages: Conversation history, oldest first. Each item may be a dict
            with ``"role"`` and ``"content"`` keys or a ``(role, content)``
            pair. The newest user prompt is expected to be the last item.
            An empty history sends a single empty user message.
        model: Together model identifier to query.
        temperature: Sampling temperature forwarded to the API.
        timeout: Passed to ``requests.post``; seconds to wait for the
            connection and between bytes received on the stream.

    Yields:
        The assistant reply accumulated so far. Each value extends the
        previous one, and the last value yielded is the complete reply.

    Raises:
        requests.HTTPError: If the API responds with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    # Normalise history into the role/content dicts the chat endpoint takes.
    all_messages = []
    for message in messages:
        if isinstance(message, dict):
            role, content = message["role"], message["content"]
        else:
            role, content = message
        all_messages.append({"role": role, "content": content})
    if not all_messages:
        # Keep the request well-formed when there is no history yet.
        all_messages.append({"role": "user", "content": ""})

    url = "https://api.together.xyz/v1/chat/completions"
    payload = {
        "model": model,
        "temperature": temperature,
        "top_p": 0.9,
        "top_k": 50,
        "repetition_penalty": 1,
        "n": 1,
        "messages": all_messages,
        "stream_tokens": True,
    }
    # NOTE(review): if TOGETHER_API_KEY is unset this sends "Bearer None";
    # the API then rejects the call and raise_for_status() raises HTTPError.
    api_key = os.getenv("TOGETHER_API_KEY")
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    sse_prefix = "data: "
    entire_assistant_response = ""
    # The with-block closes the streamed response even if the caller stops
    # iterating this generator early.
    with requests.post(
        url, json=payload, headers=headers, stream=True, timeout=timeout
    ) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if not line:
                continue  # blank lines separate server-sent events
            decoded_line = line.decode("utf-8")
            # Remove only the leading SSE field name, so "data: " appearing
            # inside the generated text is preserved.
            if decoded_line.startswith(sse_prefix):
                decoded_line = decoded_line[len(sse_prefix):]
            if decoded_line == "[DONE]":
                break  # completion signal
            try:
                chunk_data = json.loads(decoded_line)
                content = chunk_data["choices"][0]["delta"]["content"]
            except json.JSONDecodeError:
                print(f"Invalid JSON received: {decoded_line}")
                continue
            except (KeyError, IndexError, TypeError) as e:
                print(f"Unexpected chunk structure: {e!r}")
                continue
            # Some chunks carry no text (content may be None or empty).
            if content:
                entire_assistant_response += content
                yield entire_assistant_response
    # Always yield the final text, even if no chunk contained content.
    yield entire_assistant_response
# Sidebar: model picker, temperature slider, reset button, then a short card
# describing the selected model.
with st.sidebar:
    selected_model = st.selectbox("Select Model", list(model_links))
    # NOTE(review): in the code visible here, `temperature` is never passed to
    # a model call; confirm whether it should be wired into the request.
    temperature = st.slider('Select Temperature', 0.0, 1.0, 0.5)
    st.button('Reset Chat', on_click=reset_conversation)

    st.write(f"You're now chatting with **{selected_model}**")
    st.markdown(model_info[selected_model]['description'])
    st.image(model_info[selected_model]['logo'])
    st.markdown("*Generated content may be inaccurate or false.*")
    st.markdown("\nLearn how to build this chatbot [here](https://ngebodh.github.io/projects/2024-03-05/).")
    st.markdown("\nRun into issues? Try the [back-up](https://huggingface.co/spaces/ngebodh/SimpleChatbot-Backup).")
# Initialize chat history. Each entry is a {"role": ..., "content": ...} dict,
# the shape the replay loop below reads back.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input(f"Hi, I'm {selected_model}, ask me a question"):
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)
    # Store a dict, not a tuple, so the replay loop can index it by key.
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        # A single placeholder that is overwritten on each chunk. Calling
        # st.markdown inside the loop would stack a new copy of every
        # partial reply.
        placeholder = st.empty()
        response = ""
        try:
            if selected_model == "Nous-Hermes-2-Yi-34B":
                stream = interact_with_together_api(st.session_state.messages)
            else:
                stream = interact_with_huggingface_model(
                    st.session_state.messages, model_links[selected_model]
                )
            if stream is None:
                # The Hugging Face path may return no stream at all.
                response = f"Sorry, {selected_model} is not available yet."
                placeholder.markdown(response)
            else:
                for chunk in stream:
                    response = chunk
                    placeholder.markdown(response)
        except requests.RequestException as err:
            # Surface network/API failures in the chat instead of crashing;
            # no assistant turn is recorded for a failed request.
            placeholder.error(f"Request to the model failed: {err}")
        else:
            st.session_state.messages.append(
                {"role": "assistant", "content": response}
            )