|
import os |
|
import logging |
|
import re |
|
from langchain.vectorstores import Chroma |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.runnables import RunnablePassthrough |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain_groq import ChatGroq |
|
from langchain.schema import Document |
|
from langchain.prompts import PromptTemplate |
|
from langchain.chains import RetrievalQA |
|
import chardet |
|
import gradio as gr |
|
import pandas as pd |
|
import json |
|
|
|
|
|
# Configure root logging once, at import time, at INFO level, and create this
# module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
|
def clean_api_key(key):
    """Return *key* with every non-ASCII character removed.

    Any character with a code point of 128 or above is dropped, e.g. a byte-order
    mark picked up when the secret was pasted. Whitespace is left untouched;
    callers strip it separately.
    """
    ascii_bytes = key.encode("ascii", errors="ignore")
    return ascii_bytes.decode("ascii")
|
|
|
|
|
# Read the Groq API key once at import time; the app cannot run without it.
# The value is cleaned *before* the emptiness check so that a key that is blank,
# whitespace-only, or made solely of non-ASCII characters is rejected here
# instead of being passed on to the LLM client as an empty string.
api_key = clean_api_key(os.getenv("GROQ_API_KEY") or "").strip()

if not api_key:
    logger.error("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
    raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
|
|
|
|
|
def clean_text(text):
    """Strip *text* down to its ASCII characters (code points below 128).

    Every other character is dropped rather than replaced, so accented or
    symbolic characters simply disappear from the result.
    """
    return "".join(filter(lambda ch: ord(ch) < 128, text))
|
|
|
|
|
def _load_csv(file_path):
    """Return one Document per CSV row, reading the file in its chardet-detected encoding."""
    with open(file_path, "rb") as f:
        encoding = chardet.detect(f.read())["encoding"]
    # encoding may be None (e.g. empty file); pandas then falls back to its default.
    data = pd.read_csv(file_path, encoding=encoding)
    return [
        Document(page_content=clean_text(row.to_string()), metadata={"source": file_path})
        for _, row in data.iterrows()
    ]


def _load_json(file_path):
    """Return one Document per entry of a top-level JSON list, or one for a top-level object."""
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, list):
        entries = data
    elif isinstance(data, dict):
        entries = [data]
    else:
        logger.warning(f"Unsupported JSON top-level type in {file_path}: {type(data).__name__}")
        return []

    # ensure_ascii=False keeps non-ASCII characters as real characters so that
    # clean_text drops them, as it does for CSV/TXT input. With the default
    # (ensure_ascii=True) they would become literal \uXXXX escape sequences that
    # clean_text cannot remove and that pollute the indexed text.
    return [
        Document(
            page_content=clean_text(json.dumps(entry, ensure_ascii=False)),
            metadata={"source": file_path},
        )
        for entry in entries
    ]


def _read_text(file_path):
    """Read *file_path* as UTF-8, falling back to the chardet-detected encoding if that fails."""
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
    except UnicodeDecodeError:
        with open(file_path, "rb") as f:
            guessed = chardet.detect(f.read())["encoding"]
        if not guessed:
            raise  # nothing better to try; re-raise the UTF-8 decode error
        with open(file_path, "r", encoding=guessed) as f:
            return f.read()


def _load_txt(file_path):
    """Return the whole text file as a single Document."""
    return [Document(page_content=clean_text(_read_text(file_path)), metadata={"source": file_path})]


# Extension (lower-case, with dot) -> loader returning a list of Documents.
_LOADERS = {".csv": _load_csv, ".json": _load_json, ".txt": _load_txt}


def load_documents(file_paths):
    """Load every supported file in *file_paths* into LangChain Documents.

    Supported formats:
      * ``.csv``  - one Document per row (encoding auto-detected with chardet).
      * ``.json`` - one Document per entry of a top-level list, or one for a
        top-level object.
      * ``.txt``  - one Document for the whole file (UTF-8, with chardet fallback).

    All page content is passed through ``clean_text`` (non-ASCII removed) and
    each Document's metadata records its ``source`` path.

    Parameters:
    - file_paths (iterable of str): paths to load.

    Returns:
    - list[Document]: documents from every file that loaded successfully.
      Unsupported extensions are skipped with a warning. A file that fails to
      load is logged and contributes nothing; earlier files are not affected.
    """
    docs = []
    for file_path in file_paths:
        ext = os.path.splitext(file_path)[-1].lower()
        loader = _LOADERS.get(ext)
        if loader is None:
            logger.warning(f"Unsupported file format: {file_path}")
            continue
        try:
            # Each loader builds its full list before returning, so a failure
            # part-way through a file never leaves a partial set of its rows.
            docs.extend(loader(file_path))
        except Exception as e:
            # Best effort: one unreadable file must not stop the others from loading.
            logger.error(f"Error processing file {file_path}: {e}")
            logger.debug("Exception details:", exc_info=True)
    return docs
|
|
|
|
|
# Longest prefix that ends in a sentence terminator (. ! ?), optionally followed
# by closing quotes/brackets. The greedy ".*" backtracks to the LAST terminator;
# DOTALL lets it span newlines.
_COMPLETE_PREFIX = re.compile(r".*[.!?][\"')\]]*", re.DOTALL)


def ensure_complete_sentences(text):
    """Trim a trailing, unfinished sentence from *text*.

    Everything up to and including the last '.', '!' or '?' (plus any closing
    quotes/brackets right after it) is kept verbatim, then stripped of
    surrounding whitespace. The text is not split and re-joined, so decimals
    such as "3.5", numbered lists and line breaks are preserved.

    Parameters:
    - text (str): model output that may have been cut off mid-sentence.

    Returns:
    - str: the trimmed text, or *text* unchanged if it contains no terminator.
    """
    match = _COMPLETE_PREFIX.match(text)
    if match is None:
        return text
    return match.group(0).strip()
|
|
|
|
|
def is_valid_input(text):
    """
    Decide whether *text* is worth sending to the QA pipeline.

    A valid input is non-empty, contains at least one ASCII letter (A-Z / a-z),
    and is at least 5 characters long once surrounding whitespace is removed.
    Returns a bool.
    """
    trimmed = text.strip() if text else ""
    has_letter = bool(trimmed) and re.search(r"[A-Za-z]", text) is not None
    return has_letter and len(trimmed) >= 5
|
|
|
|
|
def initialize_llm(model, temperature, max_tokens):
    """Create the ChatGroq chat model used by the QA chain.

    20% of *max_tokens* (rounded down) is set aside as a "prompt allocation",
    and only the remainder is passed to ChatGroq as its ``max_tokens`` limit.
    NOTE(review): ChatGroq's max_tokens presumably limits only the completion,
    so this reservation may be unnecessary - confirm before relying on it.

    Parameters:
    - model (str): Groq model name.
    - temperature (float): sampling temperature.
    - max_tokens (int | float): total token budget chosen by the caller.

    Returns:
    - ChatGroq: the configured chat model.

    Raises:
    - ValueError: if 50 or fewer tokens would remain for the response.
    - Any exception raised while constructing ChatGroq (logged, then re-raised).
    """
    try:
        reserved_for_prompt = int(max_tokens * 0.2)
        completion_budget = max_tokens - reserved_for_prompt
        if not completion_budget > 50:
            raise ValueError("max_tokens is too small to allocate for the response.")

        chat_model = ChatGroq(
            model=model,
            temperature=temperature,
            max_tokens=completion_budget,
            api_key=api_key,
        )
    except Exception as err:
        logger.error(f"Error initializing LLM: {err}")
        raise
    else:
        logger.info("LLM initialized successfully.")
        return chat_model
|
|
|
|
|
# Where the Chroma index is persisted between runs.
_CHROMA_DIR = "/tmp/chroma_db"


def _build_prompt():
    """Return the wellness-assistant prompt consumed by the "stuff" QA chain."""
    return PromptTemplate(
        input_variables=["context", "question"],
        template="""
You are an AI assistant with expertise in daily wellness. Your aim is to provide detailed and comprehensive solutions regarding daily wellness topics without unnecessary verbosity.

Context:
{context}

Question:
{question}

Provide a thorough and complete answer, including relevant examples and a suggested schedule. Ensure that the response does not end abruptly.
""",
    )


def _build_vectorstore(splits):
    """Embed *splits* into a freshly reset Chroma collection persisted under _CHROMA_DIR."""
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # The persist directory outlives the process. Chroma.from_documents reuses an
    # existing collection there and adds the new chunks under freshly generated
    # ids, so without this reset every restart appended another full copy of the
    # corpus and retrieval returned duplicate chunks.
    Chroma(persist_directory=_CHROMA_DIR, embedding_function=embedding_model).delete_collection()

    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embedding_model,
        persist_directory=_CHROMA_DIR,
    )
    vectorstore.persist()
    return vectorstore


def create_rag_pipeline(file_paths, model, temperature, max_tokens):
    """Build the retrieval-augmented QA chain over the documents in *file_paths*.

    Steps: create the LLM, load and split the documents (1000-char chunks with a
    200-char overlap), embed them into a persisted Chroma store, and wrap the
    store's retriever plus the wellness prompt in a "stuff" RetrievalQA chain.

    Parameters:
    - file_paths (list[str]): source files (see load_documents for formats).
    - model (str), temperature (float), max_tokens (int): LLM settings,
      forwarded to initialize_llm.

    Returns:
    - tuple: (chain, message). chain is the RetrievalQA chain, or None if no
      documents were loaded or any step failed; message is a human-readable
      status string. Exceptions are logged and reported via message, not raised.
    """
    try:
        llm = initialize_llm(model, temperature, max_tokens)

        docs = load_documents(file_paths)
        if not docs:
            logger.warning("No documents were loaded. Please check your file paths and formats.")
            return None, "No documents were loaded. Please check your file paths and formats."

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)

        vectorstore = _build_vectorstore(splits)
        logger.info("Vectorstore initialized and persisted successfully.")

        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=vectorstore.as_retriever(),
            chain_type_kwargs={"prompt": _build_prompt()},
        )
        logger.info("RAG pipeline created successfully.")
        return rag_chain, "Pipeline created successfully."

    except Exception as e:
        logger.error(f"Error creating RAG pipeline: {e}")
        logger.debug("Exception details:", exc_info=True)
        return None, f"Error creating RAG pipeline: {e}"
|
|
|
|
|
POSITIVE_WORDS = {
    "good", "great", "excellent", "amazing", "wonderful", "fantastic", "positive",
    "helpful", "satisfied", "happy", "love", "liked", "enjoyed", "beneficial",
    "superb", "awesome", "nice", "brilliant", "favorable", "pleased"
}

NEGATIVE_WORDS = {
    "bad", "terrible", "awful", "poor", "disappointed", "unsatisfied", "hate",
    "hated", "dislike", "dislikes", "worst", "negative", "not helpful", "frustrated",
    "unhappy", "dissatisfied", "unfortunate", "horrible", "annoyed", "problem", "issues"
}


def _term_regex(term):
    """Compile a pattern matching *term* (lowercase) at the start of a word.

    Only a LEADING word boundary is required, so inflections still match
    ("problem" -> "problems", "dislike" -> "disliked"). Words that merely
    contain the term after a prefix no longer match ("happy" inside "unhappy",
    "satisfied" inside "dissatisfied", "hate" inside "whatever").
    Multi-word phrases accept any run of whitespace between their words.
    """
    return re.compile(r"\b" + r"\s+".join(re.escape(word) for word in term.split()))


def _count_terms(text, vocabulary):
    """Return how many distinct entries of *vocabulary* occur in *text*."""
    return sum(1 for term in vocabulary if _term_regex(term).search(text))


def handle_feedback(feedback_text):
    """
    Handles user feedback by analyzing its sentiment and providing a dynamic response.

    Sentiment is a simple keyword vote: the number of distinct POSITIVE_WORDS
    and NEGATIVE_WORDS entries found in the lower-cased text is compared, and a
    tie (including zero/zero) is "neutral". The feedback and its sentiment are
    appended as one line to /tmp/user_feedback.txt. A failure to write is logged
    and does not affect the returned message.

    Parameters:
    - feedback_text (str | None): The feedback provided by the user.

    Returns:
    - str: Acknowledgment message based on feedback sentiment, or
      "No feedback provided." for empty/whitespace-only/None input.
    """
    if not feedback_text or not feedback_text.strip():
        return "No feedback provided."

    feedback_lower = feedback_text.lower()

    negative_count = _count_terms(feedback_lower, NEGATIVE_WORDS)

    # Blank out multi-word negative phrases before scoring praise, so the
    # positive word inside them ("helpful" in "not helpful") is not counted too.
    positive_text = feedback_lower
    for phrase in NEGATIVE_WORDS:
        if " " in phrase:
            positive_text = _term_regex(phrase).sub(" ", positive_text)
    positive_count = _count_terms(positive_text, POSITIVE_WORDS)

    if positive_count > negative_count:
        sentiment = "positive"
        acknowledgment = "Thank you for your positive feedback! We're glad to hear that you found our service helpful."
    elif negative_count > positive_count:
        sentiment = "negative"
        acknowledgment = "We're sorry to hear that you're not satisfied. Your feedback is valuable to us, and we'll strive to improve."
    else:
        sentiment = "neutral"
        acknowledgment = "Thank you for your feedback. We appreciate your input."

    logger.info(f"User Feedback: {feedback_text} | Sentiment: {sentiment}")

    # Keep the file one-record-per-line: the feedback textbox allows line breaks.
    record = " ".join(feedback_text.splitlines())
    try:
        # Explicit UTF-8 so non-ASCII feedback is not lost on hosts whose
        # default file encoding cannot represent it.
        with open("/tmp/user_feedback.txt", "a", encoding="utf-8") as f:
            f.write(f"{record} | Sentiment: {sentiment}\n")
        logger.debug("Feedback stored successfully in /tmp/user_feedback.txt.")
    except Exception as e:
        logger.error(f"Error storing feedback: {e}")

    return acknowledgment
|
|
|
|
|
|
|
# Default settings: used to build the pipeline at startup and as the UI's
# initial values.
file_paths = ["AIChatbot.csv"]
model = "llama3-8b-8192"
temperature = 0.7
max_tokens = 500

# Built once at import time. On failure rag_chain is None (create_rag_pipeline
# has already logged the reason) and the app still starts.
rag_chain, message = create_rag_pipeline(
    file_paths=file_paths,
    model=model,
    temperature=temperature,
    max_tokens=max_tokens,
)
if rag_chain is None:
    logger.error("Failed to initialize RAG pipeline at startup.")
|
|
|
|
|
|
|
def _chain_for_settings(model, temperature, max_tokens):
    """Return a RetrievalQA chain that reuses the startup retriever/prompt with a fresh LLM."""
    llm = initialize_llm(model, temperature, int(max_tokens))  # sliders may deliver floats
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=rag_chain.retriever,
        chain_type_kwargs={"prompt": rag_chain.combine_documents_chain.llm_chain.prompt},
    )


def answer_question(model, temperature, max_tokens, question, feedback):
    """Answer *question* with the RAG pipeline and acknowledge *feedback*.

    The retriever and prompt come from the pipeline built at startup. The LLM is
    rebuilt from the caller's model/temperature/max_tokens, so the UI controls
    actually take effect; previously they were accepted and silently ignored.

    Parameters:
    - model (str), temperature (float), max_tokens (int | float): LLM settings.
    - question (str): the user's question.
    - feedback (str): optional feedback text, passed to handle_feedback.

    Returns:
    - tuple[str, str]: (answer or error message, feedback acknowledgment). The
      second element is "" when the question is invalid, the pipeline is
      unavailable, or an error occurred.
    """
    if not is_valid_input(question):
        logger.info("Received invalid input from user.")
        return "Please provide a valid question or input containing meaningful text.", ""

    if rag_chain is None:
        logger.error("RAG pipeline is not initialized.")
        return "The system is currently unavailable. Please try again later.", ""

    try:
        chain = _chain_for_settings(model, temperature, max_tokens)
        answer = chain.run(question)
        logger.info("Question answered successfully.")

        complete_answer = ensure_complete_sentences(answer)
        feedback_response = handle_feedback(feedback)
        return complete_answer, feedback_response
    except Exception as e_inner:
        logger.error(f"Error during RAG pipeline execution: {e_inner}")
        logger.debug("Exception details:", exc_info=True)
        return f"Error during RAG pipeline execution: {e_inner}", ""
|
|
|
|
|
def gradio_interface(model, temperature, max_tokens, question, feedback):
    """Gradio callback: forward the UI inputs unchanged to answer_question.

    Returns whatever answer_question returns: the (answer, feedback
    acknowledgment) pair shown in the two output boxes.
    """
    return answer_question(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        question=question,
        feedback=feedback,
    )
|
|
|
|
|
# Gradio UI. Input order must match gradio_interface's parameters:
# (model, temperature, max_tokens, question, feedback).
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(
            label="Model Name",
            value=model,
            placeholder="e.g., llama3-8b-8192"
        ),
        gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=1,
            step=0.01,
            value=temperature,
            info="Controls the randomness of the response. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic."
        ),
        gr.Slider(
            label="Max Tokens",
            minimum=200,
            maximum=2048,
            step=1,
            value=max_tokens,
            info="Determines the maximum number of tokens in the response. Higher values allow for longer answers."
        ),
        gr.Textbox(
            label="Question",
            placeholder="e.g., What is box breathing and how does it help reduce anxiety?"
        ),
        gr.Textbox(
            label="Feedback",
            placeholder="Provide your feedback here...",
            lines=2
        ),
    ],
    # Labeled explicitly: with the bare "text" shorthand both boxes were unlabeled
    # and users could not tell the answer from the feedback acknowledgment.
    # Order matches answer_question's (answer, feedback_response) return value.
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Textbox(label="Feedback Response"),
    ],
    title="Daily Wellness AI",
    description="Ask questions about daily wellness and get detailed solutions.",
    examples=[
        ["llama3-8b-8192", 0.7, 500, "What is box breathing and how does it help reduce anxiety?", "Great explanation!"],
        ["llama3-8b-8192", 0.6, 600, "Provide a daily wellness schedule incorporating box breathing techniques.", "Very helpful, thank you!"]
    ],
    allow_flagging="never"
)
|
|
|
|
|
if __name__ == "__main__":
    # Serve on all network interfaces, port 7860, with Gradio's debug flag on.
    # NOTE(review): debug=True is likely meant for development - confirm it is
    # wanted in deployment.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, debug=True)
    interface.launch(**launch_options)
|
|