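"""Daily Wellness AI: a Gradio app that answers daily-wellness questions with a
LangChain RetrievalQA pipeline (Chroma vector store, HuggingFace embeddings,
and a Groq-hosted LLM)."""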
import os
import logging
import json

import chardet
import gradio as gr
import nltk
import pandas as pd
import spacy
from nltk.tokenize import sent_tokenize

# The Chroma integration lives in langchain_community; the old
# langchain.vectorstores path is deprecated.
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA

# sent_tokenize requires the NLTK "punkt" tokenizer models.
nltk.download('punkt', quiet=True)

# Load the small English spaCy model, downloading it on first run if missing.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    from spacy.cli import download
    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def clean_api_key(key):
    """Strip non-ASCII characters that can corrupt a copied API key."""
    return ''.join(c for c in key if ord(c) < 128)


api_key = os.getenv("GROQ_API_KEY")
if not api_key:
    logger.error("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
    raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
api_key = clean_api_key(api_key).strip()


def clean_text(text):
    """Drop non-ASCII characters so downstream components receive plain text."""
    return text.encode("ascii", errors="ignore").decode()


def load_documents(file_paths):
    """Load .csv, .json, and .txt files into LangChain Document objects."""
    docs = []
    for file_path in file_paths:
        ext = os.path.splitext(file_path)[-1].lower()
        try:
            if ext == ".csv":
                # Detect the encoding first; user-supplied CSVs vary widely.
                with open(file_path, 'rb') as f:
                    result = chardet.detect(f.read())
                encoding = result['encoding'] or 'utf-8'  # chardet may return None
                data = pd.read_csv(file_path, encoding=encoding)
                # One document per row keeps retrieval granular.
                for _, row in data.iterrows():
                    content = clean_text(row.to_string())
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".json":
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if isinstance(data, list):
                    for entry in data:
                        content = clean_text(json.dumps(entry))
                        docs.append(Document(page_content=content, metadata={"source": file_path}))
                elif isinstance(data, dict):
                    content = clean_text(json.dumps(data))
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".txt":
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = clean_text(f.read())
                docs.append(Document(page_content=content, metadata={"source": file_path}))
            else:
                logger.warning(f"Unsupported file format: {file_path}")
        except Exception as e:
            logger.error(f"Error processing file {file_path}: {e}")
    return docs


def ensure_complete_sentences(text):
    """Trim a possibly truncated answer so it ends on a complete sentence."""
    sentences = sent_tokenize(text)
    if not sentences:
        return text
    # Heuristic: drop a trailing fragment that lacks terminal punctuation,
    # unless it is the only sentence we have.
    if len(sentences) > 1 and not sentences[-1].rstrip().endswith(('.', '!', '?')):
        sentences = sentences[:-1]
    return ' '.join(sentences).strip()


def is_valid_input_nlp(text, threshold=0.5):
    """
    Validates input text using spaCy's NLP capabilities.

    Parameters:
    - text (str): The input text to validate.
    - threshold (float): The minimum ratio of meaningful tokens required.

    Returns:
    - bool: True if the input is valid, False otherwise.
    """
    if not text or text.strip() == "":
        return False
    doc = nlp(text)
    meaningful_tokens = [token for token in doc if token.is_alpha]
    if not meaningful_tokens:
        return False
    ratio = len(meaningful_tokens) / len(doc)
    return ratio >= threshold
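
# Example: "asdf 123 !!!" fails the default threshold (most of its tokens are
# non-alphabetic), while an ordinary question passes.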


def estimate_prompt_tokens(prompt):
    """
    Estimates the number of tokens in the prompt by counting whitespace-separated
    words. This is a rough placeholder; replace it with model-specific token
    counting for tighter budgeting.

    Parameters:
    - prompt (str): The prompt text.

    Returns:
    - int: Estimated number of tokens.
    """
    return len(prompt.split())
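
# A sketch of a more accurate estimator, should you replace the placeholder
# above. It assumes the optional `transformers` package is installed; the
# model name is illustrative only.
#
#     from transformers import AutoTokenizer
#     _tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
#     def estimate_prompt_tokens(prompt):
#         return len(_tokenizer.encode(prompt))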


def initialize_llm(model, temperature, max_tokens, prompt_template):
    try:
        # Reserve part of the token budget for the prompt itself so the
        # response allocation stays within max_tokens.
        estimated_prompt_tokens = estimate_prompt_tokens(prompt_template)
        response_max_tokens = max_tokens - estimated_prompt_tokens
        if response_max_tokens <= 100:
            raise ValueError("max_tokens is too small to leave room for the response.")

        llm = ChatGroq(
            model=model,
            temperature=temperature,
            max_tokens=response_max_tokens,
            api_key=api_key
        )
        logger.debug("LLM initialized successfully.")
        return llm
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        raise


def create_rag_pipeline(file_paths, model, temperature, max_tokens):
    try:
        custom_prompt_template = PromptTemplate(
            input_variables=["context", "question"],
            template="""
You are an AI assistant with expertise in daily wellness. Your aim is to provide detailed and comprehensive solutions regarding daily wellness topics without unnecessary verbosity.

Context:
{context}

Question:
{question}

Provide a thorough and complete answer, including relevant examples and a suggested schedule. Ensure that the response does not end abruptly.
"""
        )

        llm = initialize_llm(model, temperature, max_tokens, custom_prompt_template.template)

        docs = load_documents(file_paths)
        if not docs:
            logger.warning("No documents were loaded. Please check your file paths and formats.")
            return None, "No documents were loaded. Please check your file paths and formats."

        # Overlapping chunks preserve context across chunk boundaries.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)

        embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embedding_model,
            persist_directory="./chroma_db"
        )
        vectorstore.persist()
        logger.debug("Vectorstore initialized and persisted successfully.")

        retriever = vectorstore.as_retriever()

        # "stuff" packs all retrieved chunks into a single prompt.
        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": custom_prompt_template}
        )
        logger.debug("RAG pipeline created successfully.")
        return rag_chain, "Pipeline created successfully."
    except Exception as e:
        logger.error(f"Error creating RAG pipeline: {e}")
        return None, f"Error creating RAG pipeline: {e}"


def handle_feedback(feedback_text):
    """
    Handles user feedback by logging it.
    In a production environment, consider storing feedback in a database or external service.

    Parameters:
    - feedback_text (str): The feedback provided by the user.

    Returns:
    - str: Acknowledgment message.
    """
    if feedback_text and feedback_text.strip() != "":
        logger.info(f"User Feedback: {feedback_text}")
        return "Thank you for your feedback!"
    else:
        return "No feedback provided."


def answer_question(file_paths, model, temperature, max_tokens, question, feedback):
    # Reject empty or gibberish input before spending time on retrieval.
    if not is_valid_input_nlp(question):
        return "Please provide a valid question or input containing meaningful text.", ""

    rag_chain, message = create_rag_pipeline(file_paths, model, temperature, max_tokens)
    if rag_chain is None:
        return message, ""

    try:
        # Chain.run is deprecated in recent LangChain releases; invoke takes
        # the "query" key and returns the answer under "result".
        answer = rag_chain.invoke({"query": question})["result"]
        logger.debug("Question answered successfully.")

        complete_answer = ensure_complete_sentences(answer)
        feedback_response = handle_feedback(feedback)

        return complete_answer, feedback_response
    except Exception as e:
        logger.error(f"Error during RAG pipeline execution: {e}")
        return f"Error during RAG pipeline execution: {e}", ""


def gradio_interface(model, temperature, max_tokens, question, feedback):
    # The knowledge base shipped with the app; adjust the path if the CSV
    # lives elsewhere.
    file_paths = ['AIChatbot.csv']
    return answer_question(file_paths, model, temperature, max_tokens, question, feedback)


interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(
            label="Model Name",
            value="llama3-8b-8192",
            placeholder="e.g., llama3-8b-8192"
        ),
        gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=1,
            step=0.01,
            value=0.7,
            info="Controls the randomness of the response. Higher values like 0.8 make the output more random; lower values like 0.2 make it more focused and deterministic."
        ),
        gr.Slider(
            label="Max Tokens",
            minimum=200,
            maximum=2048,
            step=1,
            value=500,
            info="The maximum number of tokens in the response. Higher values allow for longer answers."
        ),
        gr.Textbox(
            label="Question",
            placeholder="e.g., What is box breathing and how does it help reduce anxiety?"
        ),
        gr.Textbox(
            label="Feedback",
            placeholder="Provide your feedback here...",
            lines=2
        )
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Textbox(label="Feedback Response")
    ],
    title="Daily Wellness AI",
    description="Ask questions about daily wellness and get detailed solutions.",
    examples=[
        ["llama3-8b-8192", 0.7, 500, "What is box breathing and how does it help reduce anxiety?", "Great explanation!"],
        ["llama3-8b-8192", 0.6, 600, "Provide a daily wellness schedule incorporating box breathing techniques.", "Very helpful, thank you!"]
    ],
    allow_flagging="never"
)


if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860, debug=True)