File size: 4,546 Bytes
7b9ab52
 
 
 
 
 
71ffdaa
8d4c960
 
a759c01
 
7b9ab52
 
 
a759c01
0c65707
 
a759c01
 
 
 
 
 
 
 
 
 
 
 
7b9ab52
 
 
 
 
 
 
 
a759c01
7b9ab52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c65707
7b9ab52
a759c01
7b9ab52
 
 
 
a759c01
7b9ab52
 
 
 
 
 
a759c01
 
 
 
7b9ab52
a759c01
 
7b9ab52
a759c01
 
 
 
 
7b9ab52
 
 
 
 
 
 
 
 
a759c01
 
 
 
 
 
7b9ab52
a759c01
 
 
7b9ab52
a7268d5
 
a759c01
a7268d5
7b9ab52
 
 
 
 
 
 
 
 
 
a759c01
7b9ab52
 
a759c01
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import streamlit as st
import json
import os
import re
import uuid
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# --- Configuration -----------------------------------------------------------
# Hugging Face Hub repository of the fine-tuned chat model.
MODEL_NAME = "ethicsadvisorproject/Llama-2-7b-ethical-chat-finetune"
# Directory (relative to the working directory) holding one JSON file per user.
DB_DIR = "user_data"

# Create the data directory at import time; exist_ok makes reruns a no-op.
os.makedirs(DB_DIR, exist_ok=True)

# Page configuration runs at module level, before any other Streamlit call below.
st.set_page_config(page_title='Ethical GPT Assistant', layout='wide')
@st.cache_resource
def load_model():
    """Load the fine-tuned Llama-2 chat model and wrap it in a generation pipeline.

    ``st.cache_resource`` caches the returned pipeline so the model is not
    rebuilt on every Streamlit script rerun.

    Returns:
        A ``transformers`` "text-generation" pipeline that produces at most
        200 new tokens per call.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir="/tmp")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,  # half precision to reduce weight memory
        device_map="auto",  # let transformers place layers on the available devices
        offload_folder="/tmp",  # where weights that do not fit in memory are offloaded
    )
    # max_new_tokens (rather than max_length) bounds only the generated
    # continuation. A long prompt therefore still leaves room for a full
    # 200-token answer instead of eating into a shared prompt+answer budget.
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=200,
    )


# Built once at import time; every rerun reuses the cached pipeline.
pipe = load_model()

def get_user_id():
    """Return the user ID for the current Streamlit session, creating it if needed.

    The ID is a random UUID4 string kept in ``st.session_state["user_id"]``.
    """
    state = st.session_state
    if "user_id" not in state:
        # First visit in this session: mint a fresh identifier.
        state.user_id = str(uuid.uuid4())
    return state.user_id

def get_user_file(user_id):
    """Build the path ``DB_DIR/<user_id>.json`` where *user_id*'s data is stored."""
    filename = "{}.json".format(user_id)
    return os.path.join(DB_DIR, filename)

def load_user_data(user_id):
    """Load the stored data (chat history) for *user_id*.

    Returns:
        A dict that always contains a ``"chat_history"`` list. The default
        ``{"chat_history": []}`` is returned when any of these holds:

        - the user has no file yet;
        - the file is not valid JSON;
        - the file does not hold a JSON object.

        A missing or non-list ``"chat_history"`` entry is reset to ``[]``.
    """
    user_file = get_user_file(user_id)
    try:
        with open(user_file, 'r', encoding='utf-8') as file:
            data = json.load(file)
    except FileNotFoundError:
        return {"chat_history": []}  # Default empty chat history
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses. A truncated/corrupted file should not crash the app,
        # so start over with an empty history.
        return {"chat_history": []}

    if not isinstance(data, dict):
        return {"chat_history": []}
    if not isinstance(data.get("chat_history"), list):
        # Callers index data["chat_history"] and append to it; guarantee a list.
        data["chat_history"] = []
    return data

def save_user_data(user_id, data):
    """Persist *data* (a JSON-serializable object) as *user_id*'s JSON file.

    The JSON is written to a sibling ``.tmp`` file first. That file is then
    moved over the real one with ``os.replace``, so an interruption mid-write
    never leaves a truncated data file behind.

    Raises:
        TypeError: if *data* is not JSON-serializable (from ``json.dump``).
        OSError: if the file cannot be written or renamed.
    """
    user_file = get_user_file(user_id)
    tmp_file = f"{user_file}.tmp"
    try:
        with open(tmp_file, 'w', encoding='utf-8') as file:
            json.dump(data, file)
        # Same directory => same filesystem, so the replace swaps files in one step.
        os.replace(tmp_file, user_file)
    except Exception:
        # Best-effort cleanup of the partial temp file; the original error is re-raised.
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        raise

def _render_intro():
    """Draw the logo image and the welcome text in the main area."""
    st.image("./logo/images.jpeg", use_column_width=True)

    intro = """
    ## Welcome to EthicsAdvisor

    Ethical GPT is an AI-powered chatbot designed to interact with you in an ethical, safe, and responsible manner. Our goal is to ensure that all responses provided by the assistant are respectful and considerate of various societal and ethical standards.
    Feel free to ask any questions, and rest assured that the assistant will provide helpful and appropriate responses.
    """
    st.markdown(intro)


def _render_sidebar():
    """Draw the sidebar title, caption, fine-tuning info and survey link."""
    st.sidebar.title("❄️EthicsAdvisor 📄")
    st.sidebar.caption("Make AI responses more ethical")

    with st.sidebar.expander("See fine-tuning info"):
        st.caption("Original Data: [Data](https://huggingface.co/datasets/MasahiroKaneko/eagle/)")
        st.caption("Modified Data: [Data](https://huggingface.co/datasets/ethicsadvisorproject/ethical_data_bigger/) 📝")
        st.caption("Used Model and Notebook: [Original model](https://huggingface.co/ethicsadvisorproject/Llama-2-7b-ethical-chat-finetune/) 🎈, Notebook used for fine-tuning [Notebook](https://colab.research.google.com/drive/1eAAjdwwD0i-i9-ehEJYUKXvZoYK0T3ue#scrollTo=ib_We3NLtj2E)")

    with st.sidebar.expander("ℹ️ **Take survey**"):
        st.markdown("You are welcome to give us your input on this research [here](https://forms.office.com/r/H4ARtETV2q).")


def _generate_reply(prompt):
    """Run the model on a single user *prompt* and return only the answer text."""
    # return_full_text=False makes the pipeline return just the continuation.
    # By default generated_text starts with the prompt itself, so stripping
    # only the [INST] tags would echo the user's question back inside the answer.
    response = pipe(f"<s>[INST] {prompt} [/INST]", return_full_text=False)
    text = response[0]["generated_text"]
    return text.replace("<s>[INST]", "").replace("[/INST]", "").strip()


def main():
    """Render the EthicsAdvisor chat page and handle one round of interaction.

    Streamlit re-executes this function on every interaction. The conversation
    is kept in ``st.session_state.messages``. It is written to the user's JSON
    file after each exchange and when the chat is cleared.
    """
    _render_intro()
    _render_sidebar()

    # Initialize chat history (seeded from disk on the first run of a session)
    user_id = get_user_id()
    user_data = load_user_data(user_id)

    if "messages" not in st.session_state:
        st.session_state.messages = user_data["chat_history"]

    # Replay the conversation so far
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # User input
    prompt = st.chat_input("What is up?")
    if prompt:
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        response_text = _generate_reply(prompt)
        with st.chat_message("assistant"):
            st.markdown(response_text)

        st.session_state.messages.append({"role": "assistant", "content": response_text})

        # Save updated chat history
        user_data["chat_history"] = st.session_state.messages
        save_user_data(user_id, user_data)

    # Clear Chat button
    if st.sidebar.button('Clear Chat'):
        st.session_state.messages = []
        user_data["chat_history"] = []
        save_user_data(user_id, user_data)
        # st.experimental_rerun is deprecated in favour of st.rerun (Streamlit
        # 1.27+). Prefer st.rerun and fall back only on older versions; the
        # `or` avoids touching experimental_rerun when rerun exists.
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()

# Start the app when this file is executed as a script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()