from typing import Optional, List, Mapping, Any

import pandas as pd
import streamlit as st
from llama_cpp import Llama
from langchain.llms.base import LLM
from langchain.embeddings import HuggingFaceEmbeddings
from llama_index import LangchainEmbedding, PromptHelper

st.set_page_config(page_title='Mental Health Chatbot', page_icon=':robot_face:', layout='wide')

MODEL_NAME = 'TheBloke/MelloGPT-GGUF'
# NOTE: llama_cpp loads a local GGUF file, so MODEL_PATH should point at the
# downloaded model file rather than the Hugging Face repo id.
MODEL_PATH = 'TheBloke/MelloGPT-GGUF'
KNOWLEDGE_BASE_FILE = "mentalhealth.csv"
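
# One way to fetch the weights locally (a sketch using huggingface_hub; the
# exact .gguf filename inside the repo is an assumption, not confirmed here):
#   from huggingface_hub import hf_hub_download
#   MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename='mellogpt.Q4_K_M.gguf')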

NUM_THREADS = 8              # CPU threads used by llama.cpp
MAX_INPUT_SIZE = 2048        # maximum prompt size handed to PromptHelper
NUM_OUTPUT = 256             # tokens reserved for the model's response
CHUNK_OVERLAP_RATIO = 0.10   # overlap between chunks when text is split

# PromptHelper validates its arguments (and its signature has varied across
# llama_index releases); fall back to a larger overlap ratio if the
# configured one is rejected.
try:
    prompt_helper = PromptHelper(MAX_INPUT_SIZE, NUM_OUTPUT, CHUNK_OVERLAP_RATIO)
except Exception:
    CHUNK_OVERLAP_RATIO = 0.2
    prompt_helper = PromptHelper(MAX_INPUT_SIZE, NUM_OUTPUT, CHUNK_OVERLAP_RATIO)

# Embedding model wrapped for llama_index (prepared for building an index;
# not yet used further down).
embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
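
# Loading GGUF weights is expensive, so cache a single Llama instance across
# Streamlit reruns rather than constructing a new one on every request. A
# minimal helper added here (not part of the original script), assuming
# MODEL_PATH points at a local GGUF file.
@st.cache_resource
def load_llama():
    return Llama(model_path=MODEL_PATH, n_threads=NUM_THREADS)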


class CustomLLM(LLM):
    model_name: str = MODEL_NAME

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        p = f"Human: {prompt} Assistant: "
        llm = load_llama()
        try:
            # With echo=False the model returns only the completion, so the
            # prompt no longer needs to be sliced off the output.
            output = llm(p, max_tokens=512, stop=["Human:"], echo=False)
            return output['choices'][0]['text'].strip()
        except Exception:
            st.error("An error occurred while processing your request. Please try again.")
            return ""

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"name_of_model": self.model_name}

    @property
    def _llm_type(self) -> str:
        return "custom"


@st.cache_resource
def load_model():
    # Build the LangChain wrapper once per session.
    return CustomLLM()


@st.cache_data
def load_knowledge_base():
    # Assumes the CSV has 'Questions' and 'Answers' columns; questions are
    # lowercased so lookups are case-insensitive.
    df = pd.read_csv(KNOWLEDGE_BASE_FILE)
    return dict(zip(df['Questions'].str.lower(), df['Answers']))
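
# Illustrative knowledge-base rows (this layout is an assumption; only the
# column names above are taken from the code):
#   Questions,Answers
#   what is anxiety?,"Anxiety is a feeling of worry or fear ..."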


def clear_convo():
    st.session_state['messages'] = []


def init():
    if 'messages' not in st.session_state:
        st.session_state['messages'] = []


if __name__ == '__main__':
    init()
    knowledge_base = load_knowledge_base()
    llm = load_model()

    clear_button = st.sidebar.button("Clear Conversation")
    if clear_button:
        clear_convo()

    user_input = st.text_input("Enter your query:", key="user_input")
    if user_input:
        user_input = user_input.lower()
        # Try the FAQ knowledge base first; fall back to the local model.
        answer = knowledge_base.get(user_input)
        if answer is None:
            answer = llm._call(prompt=user_input)
        if answer:
            st.session_state.messages.append({"role": "user", "content": user_input})
            st.session_state.messages.append({"role": "assistant", "content": answer})

    for message in st.session_state.messages:
        with st.container():
            st.markdown(f"**{message['role'].title()}**: {message['content']}")
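
# Launch the app with: streamlit run <this_file>.py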