import pandas as pd
import os
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import gradio as gr
import logging

# Set up basic logging (optional, but useful)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
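
# Everything below is wrapped in try/except so that a failure during setup
# (missing CSV, missing API key, network issues) still produces a working
# Gradio UI that tells the user to check the logs.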
try:
    # Load the FAQ data; each row becomes one "Question: ... Answer: ..." context string
    df = pd.read_csv('./Mental_Health_FAQ.csv')
    context_data = [
        f"Question: {q} Answer: {a}"
        for q, a in zip(df['Questions'], df['Answers'])
    ]
    # No need to pre-embed the contexts here: the Chroma vector store below
    # embeds them itself via its embedding_function.

    # Get the API key - fail fast if it is not set in the environment
    groq_key = os.environ.get('new_chatAPI_key')
    if not groq_key:
        raise ValueError("Groq API key not found in environment variables.")
    # LLM used for RAG answer generation
    llm = ChatGroq(model="llama-3.3-70b-versatile", api_key=groq_key)

    # Embedding model used by the vector store
    embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
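    # The embedding model weights are downloaded from the Hugging Face Hub on
    # first run, so the initial startup can take a while.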
    # Create the vector store
    vectorstore = Chroma(
        collection_name="medical_dataset_store",
        embedding_function=embed_model,
    )
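    # Note: with no persist_directory set, this collection lives in memory and
    # is rebuilt (and re-embedded) every time the app starts.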
    # Add data to the vector store and expose it as a retriever
    vectorstore.add_texts(context_data)
    retriever = vectorstore.as_retriever()
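    # By default the retriever performs similarity search and returns the top
    # 4 matching contexts for each query.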
    # Create the prompt template
    template = """You are a mental health professional.
    Use the provided context to answer the question.
    If you don't know the answer, say so. Explain your answer in detail.
    Do not discuss the context in your response; just provide the answer directly.
    Context: {context}
    Question: {question}
    Answer:"""
    rag_prompt = PromptTemplate.from_template(template)
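    # from_template infers {context} and {question} as the prompt's input
    # variables, matching the keys produced by the dict at the head of the chain.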
    # Join retrieved documents into plain text; without this, the prompt would
    # receive the raw Document objects' repr (metadata included) as context.
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | rag_prompt
        | llm
        | StrOutputParser()
    )
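    # Gradio calls this generator with the user's message and the chat history;
    # yielding the growing string streams tokens into the UI as they arrive.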
    def rag_memory_stream(message, history):
        partial_text = ""
        for new_text in rag_chain.stream(message):
            partial_text += new_text
            yield partial_text
    examples = [
        "I am not in a good mood",
        "What are the possible symptoms of depression?"
    ]
    description = "A real-time AI app built with the Groq API and LangChain to answer mental health questions"
    title = "ThriveTalk Expert :) Try me!"
    demo = gr.ChatInterface(
        fn=rag_memory_stream,
        type="messages",
        title=title,
        description=description,
        fill_height=True,
        examples=examples,
        theme="glass",
    )
except Exception as e:
    logging.error(f"An error occurred during initialization: {e}")

    # If initialization fails, fall back to a dummy interface that tells the
    # user to check the logs instead of crashing the app.
    def error_function(message, history):
        yield "An error has occurred. Please check the logs."

    demo = gr.ChatInterface(
        fn=error_function,
        type="messages",
        title="ThriveTalk :( Check the logs",
        description="Please check the logs",
        fill_height=True,
        theme="glass",
    )
if __name__ == "__main__":
    demo.launch()