import json
import logging
import os
import re

import chardet
import gradio as gr
import pandas as pd
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from nltk.corpus import words

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
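
# Secrets pasted into a hosting UI can pick up stray non-ASCII characters
# (smart quotes, zero-width spaces) that break HTTP auth headers, so the key
# is stripped down to plain ASCII before use.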
def clean_api_key(key):
    """Drop any non-ASCII characters from the key."""
    return ''.join(c for c in key if ord(c) < 128)
# Load the GROQ API key
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
    raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
api_key = clean_api_key(api_key).strip()
def clean_text(text):
    """Strip non-ASCII characters from document text."""
    return text.encode("ascii", errors="ignore").decode()
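
# Each CSV row, each JSON entry, and each whole text file becomes one Document,
# so the retriever can match at the level of individual records. chardet is
# used to sniff CSV encodings, since exported spreadsheets are often not UTF-8.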
def load_documents(file_paths):
    docs = []
    for file_path in file_paths:
        ext = os.path.splitext(file_path)[-1].lower()
        try:
            if ext == ".csv":
                # Detect the encoding first; fall back to UTF-8 if chardet is unsure.
                with open(file_path, 'rb') as f:
                    result = chardet.detect(f.read())
                encoding = result['encoding'] or 'utf-8'
                data = pd.read_csv(file_path, encoding=encoding)
                for _, row in data.iterrows():
                    content = clean_text(row.to_string())
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".json":
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if isinstance(data, list):
                    for entry in data:
                        content = clean_text(json.dumps(entry))
                        docs.append(Document(page_content=content, metadata={"source": file_path}))
                elif isinstance(data, dict):
                    content = clean_text(json.dumps(data))
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".txt":
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = clean_text(f.read())
                docs.append(Document(page_content=content, metadata={"source": file_path}))
            else:
                logger.warning(f"Unsupported file format: {file_path}")
        except Exception as e:
            logger.error(f"Error processing file {file_path}: {e}")
    return docs
# Enhanced input validation: load the NLTK English word list, downloading it
# on first run if the corpus is not already present.
try:
    english_words = set(words.words())
except LookupError:
    import nltk
    nltk.download('words')
    english_words = set(words.words())
def is_valid_input(text):
    """Validate the user's input question."""
    if not text or text.strip() == "":
        return False, "Input cannot be empty. Please provide a meaningful question."
    if len(text.strip()) < 2:
        return False, "Input is too short. Please provide more context or details."
    # Check for valid words
    words_in_text = re.findall(r'\b\w+\b', text.lower())
    recognized_words = [word for word in words_in_text if word in english_words]
    if not recognized_words:
        return False, "Input appears unclear. Please use valid words in your question."
    return True, "Valid input."
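
# Roughly 20% of the requested token budget is reserved for the prompt and
# retrieved context; the remainder caps the length of the model's response.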
def initialize_llm(model, temperature, max_tokens):
    prompt_allocation = int(max_tokens * 0.2)
    response_max_tokens = max_tokens - prompt_allocation
    if response_max_tokens <= 50:
        raise ValueError("max_tokens is too small; fewer than 50 tokens would remain for the response.")
    llm = ChatGroq(
        model=model,
        temperature=temperature,
        max_tokens=response_max_tokens,
        api_key=api_key
    )
    return llm
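
# The RAG pipeline: load documents, split them into overlapping 1000-character
# chunks, embed the chunks with a MiniLM sentence-transformer, index them in a
# Chroma store, and expose a "stuff"-type RetrievalQA chain over the retriever.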
def create_rag_pipeline(file_paths, model, temperature, max_tokens):
    llm = initialize_llm(model, temperature, max_tokens)
    docs = load_documents(file_paths)
    if not docs:
        return None, "No documents were loaded."
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embedding_model,
        persist_directory="/tmp/chroma_db"
    )
    retriever = vectorstore.as_retriever()
    custom_prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
You are an AI assistant specialized in daily wellness. Provide a concise, thorough, and stand-alone answer to the user's question based on the given context. Include relevant examples or schedules where beneficial. **When listing steps or guidelines, format them as a numbered list with appropriate markdown formatting.** The final answer should be coherent, self-contained, and end with a complete sentence.
Context:
{context}
Question:
{question}
Final Answer:
"""
    )
    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": custom_prompt_template}
    )
    return rag_chain, "Pipeline created successfully."
file_paths = ['AIChatbot.csv']
model = "llama3-8b-8192"
temperature = 0.7
max_tokens = 500

# Cache one pipeline per (model, temperature, max_tokens) setting so that the
# UI controls actually take effect instead of always using the startup values.
pipeline_cache = {}

def get_rag_chain(model, temperature, max_tokens):
    key = (model, temperature, max_tokens)
    if key not in pipeline_cache:
        pipeline_cache[key] = create_rag_pipeline(file_paths, model, temperature, max_tokens)
    return pipeline_cache[key]

def answer_question(model, temperature, max_tokens, question):
    is_valid, validation_message = is_valid_input(question)
    if not is_valid:
        return validation_message
    try:
        rag_chain, _ = get_rag_chain(model, temperature, max_tokens)
        if rag_chain is None:
            return "The system is currently unavailable. Please try again later."
        answer = rag_chain.run(question)
        return answer.strip()
    except Exception as e_inner:
        logger.error(f"Error while answering the question: {e_inner}")
        return "An error occurred while processing your request."
interface = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(label="Model Name", value=model),
        gr.Slider(label="Temperature", minimum=0, maximum=1, step=0.01, value=temperature),
        gr.Slider(label="Max Tokens", minimum=200, maximum=2048, step=1, value=max_tokens),
        gr.Textbox(label="Question", placeholder="e.g., What is box breathing and how does it help reduce anxiety?")
    ],
    outputs=gr.Markdown(label="Answer"),
    title="Daily Wellness AI",
    description="Ask questions about daily wellness and receive a concise, complete answer.",
    examples=[
        ["llama3-8b-8192", 0.7, 500, "What is box breathing and how does it help reduce anxiety?"],
        ["llama3-8b-8192", 0.6, 600, "Give me a weekly fitness schedule incorporating mindfulness exercises."]
    ],
    allow_flagging="never"
)
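
# Port 7860 is the default port a Hugging Face Space expects a Gradio app to
# listen on; binding 0.0.0.0 makes the server reachable from outside the container.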
if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860, debug=True)