Spaces:
Running
Running
import streamlit as st | |
import json | |
import os | |
import re | |
import uuid | |
import time | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
# Configuration constants.
# Hugging Face Hub id of the fine-tuned model that load_model() fetches.
MODEL_NAME = "ethicsadvisorproject/Llama-2-7b-ethical-chat-finetune"
# Folder holding one "<user_id>.json" chat-history file per user.
DB_DIR = "user_data"
# Create the folder up front; this is a no-op when it already exists.
os.makedirs(name=DB_DIR, exist_ok=True)
# Page-level settings (browser tab title and wide layout). Streamlit — at least
# in older versions — requires set_page_config to be the first st.* command on
# a page, which is why it runs here at import time rather than inside main().
st.set_page_config(layout="wide", page_title="Ethical GPT Assistant")
# Load the Hugging Face model once per server process.
@st.cache_resource
def load_model(max_length=200):
    """Build the text-generation pipeline for MODEL_NAME.

    Streamlit re-executes this whole script on every user interaction, so the
    loader is wrapped in ``st.cache_resource``: the tokenizer and fp16 model
    are loaded on the first call and the same pipeline object is returned on
    every later rerun and session, instead of reloading the 7B model each time.

    Args:
        max_length: ``max_length`` passed to the pipeline (default 200, the
            previously hard-coded value).

    Returns:
        A Transformers ``text-generation`` pipeline.
    """
    # NOTE(review): only the tokenizer is given cache_dir="/tmp"; the model
    # weights use the default Hugging Face cache — confirm this is intended.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir="/tmp")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",  # let accelerate place the weights on available devices
        offload_folder="/tmp",  # disk location for any weights that must be offloaded
    )
    # NOTE(review): in Transformers, max_length bounds prompt + generated tokens
    # together, so long prompts leave little room for the answer; consider
    # max_new_tokens instead.
    return pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=max_length)


pipe = load_model()
def get_user_id():
    """Return this session's user id, minting a random UUID4 string on first use.

    The id is stored only in ``st.session_state``.
    NOTE(review): session state does not survive a page reload, so a reload
    presumably gets a fresh id and the previously saved history file is not
    picked up again — confirm whether that is intended.
    """
    state = st.session_state
    if "user_id" in state:
        return state.user_id
    state.user_id = str(uuid.uuid4())
    return state.user_id
def get_user_file(user_id):
    """Map a user id to the path of its ``<user_id>.json`` file inside DB_DIR."""
    file_name = "{}.json".format(user_id)
    return os.path.join(DB_DIR, file_name)
def load_user_data(user_id):
    """Load the persisted data (chat history) for *user_id*.

    Returns:
        A dict that always contains a ``"chat_history"`` key. When the user has
        no file yet, or the file is not valid JSON / not a JSON object, a fresh
        ``{"chat_history": []}`` is returned instead of raising.
    """
    user_file = get_user_file(user_id)
    try:
        with open(user_file, "r", encoding="utf-8") as file:
            data = json.load(file)
    except FileNotFoundError:
        # First visit for this id: nothing saved yet.
        return {"chat_history": []}
    except (json.JSONDecodeError, UnicodeDecodeError):
        # A truncated/corrupt file (e.g. an interrupted write) must not crash
        # the page; start over with an empty history, which the next save
        # will overwrite.
        return {"chat_history": []}
    if not isinstance(data, dict):
        return {"chat_history": []}
    # Callers index data["chat_history"] directly, so guarantee the key exists.
    data.setdefault("chat_history", [])
    return data
def save_user_data(user_id, data):
    """Write *data* for *user_id* to its JSON file, replacing the old file.

    The JSON is first written to a sibling ``.tmp`` file and then moved into
    place with ``os.replace``. A crash or serialization error mid-write
    therefore leaves the previous data file intact instead of a truncated one.
    """
    user_file = get_user_file(user_id)
    tmp_file = f"{user_file}.tmp"
    with open(tmp_file, "w", encoding="utf-8") as file:
        json.dump(data, file)
    # os.replace overwrites the destination in a single rename on the same
    # filesystem, so readers see either the old file or the complete new one.
    os.replace(tmp_file, user_file)
# Introductory markdown rendered at the top of the page.
_INTRO = """
## Welcome to EthicsAdvisor
Ethical GPT is an AI-powered chatbot designed to interact with you in an ethical, safe, and responsible manner. Our goal is to ensure that all responses provided by the assistant are respectful and considerate of various societal and ethical standards.
Feel free to ask any questions, and rest assured that the assistant will provide helpful and appropriate responses.
"""


def _render_header():
    """Render the logo image and the introductory text."""
    st.image("./logo/images.jpeg", use_column_width=True)
    st.markdown(_INTRO)


def _render_sidebar():
    """Render the sidebar title, fine-tuning info links, and survey link."""
    st.sidebar.title("❄️EthicsAdvisor 📄")
    st.sidebar.caption("Make AI responses more ethical")
    with st.sidebar.expander("See fine-tuning info"):
        st.caption("Original Data: [Data](https://huggingface.co/datasets/MasahiroKaneko/eagle/)")
        st.caption("Modified Data: [Data](https://huggingface.co/datasets/ethicsadvisorproject/ethical_data_bigger/) 📝")
        st.caption("Used Model and Notebook: [Original model](https://huggingface.co/ethicsadvisorproject/Llama-2-7b-ethical-chat-finetune/) 🎈, Notebook used for fine-tuning [Notebook](https://colab.research.google.com/drive/1eAAjdwwD0i-i9-ehEJYUKXvZoYK0T3ue#scrollTo=ib_We3NLtj2E)")
    with st.sidebar.expander("ℹ️ **Take survey**"):
        st.markdown("You are welcome to give us your input on this research [here](https://forms.office.com/r/H4ARtETV2q).")


def _generate_reply(prompt):
    """Run the model on a single user prompt and return the reply text to display."""
    # By default the text-generation pipeline's "generated_text" is the prompt
    # followed by the continuation, which echoed the user's question into every
    # answer. return_full_text=False returns only the newly generated text.
    outputs = pipe(f"<s>[INST] {prompt} [/INST]", return_full_text=False)
    reply = outputs[0]["generated_text"]
    # Defensive: drop any prompt-template tags the model itself may emit.
    return reply.replace("<s>[INST]", "").replace("[/INST]", "").strip()


def _rerun():
    """Restart the script run: st.rerun on newer Streamlit, st.experimental_rerun on older."""
    # st.experimental_rerun was removed in recent Streamlit releases.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def main():
    """Render the EthicsAdvisor chat page and handle one round of interaction.

    Shows the header and sidebar, replays this session's chat history, and,
    when the user submits a message, generates a reply, appends both messages
    to the history, and persists it. The sidebar "Clear Chat" button wipes the
    history (in memory and on disk) and reruns the script.

    NOTE(review): each prompt is sent to the model on its own; earlier turns
    are displayed but not included in the model input.
    """
    # set_page_config is deliberately called at module level, not here,
    # because it has to precede every other Streamlit command.
    _render_header()
    _render_sidebar()

    # Initialize chat history once per session; later reruns reuse session state.
    user_id = get_user_id()
    user_data = load_user_data(user_id)
    if "messages" not in st.session_state:
        st.session_state.messages = user_data.get("chat_history", [])

    # Display chat history.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # User input.
    prompt = st.chat_input("What is up?")
    if prompt:
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        response_text = _generate_reply(prompt)
        with st.chat_message("assistant"):
            st.markdown(response_text)
        st.session_state.messages.append({"role": "assistant", "content": response_text})

        # Save updated chat history.
        user_data["chat_history"] = st.session_state.messages
        save_user_data(user_id, user_data)

    # Clear Chat button (rendered last so it sits at the bottom of the sidebar).
    if st.sidebar.button('Clear Chat'):
        st.session_state.messages = []
        user_data["chat_history"] = []
        save_user_data(user_id, user_data)
        _rerun()


if __name__ == '__main__':
    main()