File size: 3,325 Bytes
1d1a2f4
 
 
ad9d775
68e412e
 
1697d4e
1d1a2f4
 
a4bc9f8
 
1d1a2f4
 
 
 
a1c0854
a4bc9f8
 
 
1d1a2f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
from typing import Any, List, Mapping, Optional

import pandas as pd
import streamlit as st
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from llama_cpp import Llama
from llama_index.core import PromptHelper
#from llama_index.embeddings import LangchainEmbedding
from llama_index.embeddings.langchain import LangchainEmbedding
# Load model directly
from transformers import AutoModel
# Set the page config as the first command: Streamlit requires
# st.set_page_config to run before any other st.* call in the script.
st.set_page_config(page_title='Mental Health chatbot', page_icon=':robot_face:', layout='wide')

# Define constants
# Hugging Face repo the GGUF weights come from; also reported as the LLM's name.
MODEL_NAME = 'TheBloke/MelloGPT-GGUF'
# llama_cpp.Llama(model_path=...) needs the path of a local .gguf file, not a
# transformers model object (and AutoModel.from_pretrained cannot load a
# GGUF-only repo). Download a quantised file from MODEL_NAME and point this at
# it, or override it with the MELLOGPT_MODEL_PATH environment variable.
# NOTE(review): the default filename assumes TheBloke's usual
# "<model>.<quant>.gguf" naming — confirm against the repo's file list.
MODEL_PATH = os.environ.get(
    'MELLOGPT_MODEL_PATH',
    os.path.join('models', 'mellogpt.Q4_K_M.gguf'),
)
KNOWLEDGE_BASE_FILE = "mentalhealth.csv"

# Configuration
NUM_THREADS = 8          # CPU threads handed to llama.cpp
MAX_INPUT_SIZE = 2048    # context window (tokens) given to PromptHelper
NUM_OUTPUT = 256         # tokens PromptHelper reserves for the output
CHUNK_OVERLAP_RATIO = 0.10

# Initialize prompt helper, retrying once with a larger overlap ratio if the
# first configuration is rejected.
try:
    prompt_helper = PromptHelper(
        context_window=MAX_INPUT_SIZE,
        num_output=NUM_OUTPUT,
        chunk_overlap_ratio=CHUNK_OVERLAP_RATIO,
    )
except Exception:
    CHUNK_OVERLAP_RATIO = 0.2
    prompt_helper = PromptHelper(
        context_window=MAX_INPUT_SIZE,
        num_output=NUM_OUTPUT,
        chunk_overlap_ratio=CHUNK_OVERLAP_RATIO,
    )


@st.cache_resource
def _load_embed_model():
    """Build the LangChain-backed embedding model once and reuse it.

    Streamlit re-executes this script on every interaction; without the cache
    the embedding model would be re-instantiated on each rerun.
    """
    return LangchainEmbedding(HuggingFaceEmbeddings())


embed_model = _load_embed_model()

@st.cache_resource
def _load_llama(model_path: str, n_threads: int, n_ctx: int) -> Llama:
    """Load the llama.cpp model once and share it across calls and reruns."""
    return Llama(model_path=model_path, n_threads=n_threads, n_ctx=n_ctx)


class CustomLLM(LLM):
    """LangChain LLM wrapper around a local llama.cpp (GGUF) model.

    Besides returning the completion, ``_call`` records the exchange in
    ``st.session_state.messages`` so the Streamlit UI can render it.
    """

    # Annotated so pydantic (which backs LangChain's LLM) treats these as fields.
    model_name: str = MODEL_NAME
    # Upper bound on the number of tokens generated per call.
    max_tokens: int = 512

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Generate a reply to *prompt* with the local model.

        Args:
            prompt: The user's message.
            stop: Extra stop sequences, in addition to the built-in "Human:".
            **kwargs: Accepted for LangChain compatibility; ignored.

        Returns:
            The assistant's reply, or "" if generation failed (an error is
            shown in the UI in that case).
        """
        p = f"Human: {prompt} Assistant: "
        # Always stop before the model starts writing the next user turn,
        # plus whatever the caller asked for.
        stop_sequences = ["Human:"] + list(stop or [])
        # Cached: reloading the weights on every query is very slow.
        llm = _load_llama(MODEL_PATH, NUM_THREADS, MAX_INPUT_SIZE)
        try:
            # echo=False returns only the completion, so the prompt does not
            # have to be sliced off by character count (the detokenised echo
            # may not reproduce the prompt character-for-character).
            result = llm(p, max_tokens=self.max_tokens, stop=stop_sequences, echo=False)
            response = result['choices'][0]['text'].strip()
        except Exception:
            st.error("An error occurred while processing your request. Please try again.")
            return ""
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.session_state.messages.append({"role": "assistant", "content": response})
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters LangChain uses to identify this LLM instance."""
        return {"name_of_model": self.model_name}

    @property
    def _llm_type(self) -> str:
        """Short type tag reported to LangChain."""
        return "custom"

# Cached loaders: Streamlit reruns the whole script on every interaction,
# so expensive objects are built once and reused.
@st.cache_resource
def load_model():
    """Return the shared CustomLLM instance (built on first use, then cached)."""
    model = CustomLLM()
    return model

@st.cache_data
def load_knowledge_base():
    """Load the canned question/answer pairs from KNOWLEDGE_BASE_FILE.

    The CSV is read with pandas and must have 'Questions' and 'Answers' columns.

    Returns:
        dict mapping each question (whitespace-stripped and lower-cased) to
        its answer. Rows missing either a question or an answer are skipped.
    """
    df = pd.read_csv(KNOWLEDGE_BASE_FILE)
    # A missing answer is read as NaN, which is truthy, so a caller testing
    # `if answer:` would display "nan"; a missing question would become a NaN key.
    df = df.dropna(subset=['Questions', 'Answers'])
    # Normalise keys the way lookups are expected to be normalised.
    questions = df['Questions'].astype(str).str.strip().str.lower()
    return dict(zip(questions, df['Answers']))

def clear_convo():
    """Reset the chat history by replacing it with an empty list."""
    empty_history = []
    st.session_state['messages'] = empty_history

def init():
    """Create an empty 'messages' list in session state unless one already exists."""
    if 'messages' in st.session_state:
        return
    st.session_state['messages'] = []

# Main function
if __name__ == '__main__':
    init()
    knowledge_base = load_knowledge_base()
    llm = load_model()

    if st.sidebar.button("Clear Conversation"):
        clear_convo()

    # A bare st.text_input keeps its value across reruns, so every later
    # interaction (including "Clear Conversation") would process the same
    # query again and duplicate it in the history. A form with
    # clear_on_submit only reports a submission on the run it happened in.
    with st.form("query_form", clear_on_submit=True):
        user_input = st.text_input("Enter your query:", key="user_input")
        submitted = st.form_submit_button("Send")

    query = user_input.strip() if submitted and user_input else ""
    if query:
        # Knowledge-base keys are lower-cased; normalise only the lookup key
        # so the user's own wording is what gets displayed and sent to the LLM.
        answer = knowledge_base.get(query.lower())
        if answer:
            st.session_state.messages.append({"role": "user", "content": query})
            st.session_state.messages.append({"role": "assistant", "content": answer})
        else:
            # CustomLLM._call appends both turns to st.session_state.messages itself.
            llm._call(prompt=query)

    # Render the whole conversation history.
    for message in st.session_state.messages:
        with st.container():
            st.markdown(f"**{message['role'].title()}**: {message['content']}")