Spaces:
Sleeping
Sleeping
import os | |
import logging | |
import re | |
from langchain_community.vectorstores import Chroma | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_groq import ChatGroq | |
from langchain.schema import Document | |
from langchain.prompts import PromptTemplate | |
from langchain.chains import RetrievalQA | |
import chardet | |
import gradio as gr | |
import pandas as pd | |
import json | |
from nltk.corpus import words | |
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
# Request INFO-level root logging (basicConfig is a no-op if the root logger
# already has handlers configured).
logging.basicConfig(level=logging.INFO)
def clean_api_key(key):
    """Return *key* with every non-ASCII character removed.

    Presumably guards against invisible Unicode characters (e.g. zero-width
    spaces) picked up when the secret was pasted; only ASCII is kept.
    """
    return "".join(filter(str.isascii, key))
# Load the Groq API key from the GROQ_API_KEY environment variable.
# Clean and strip *before* validating, so a value made only of whitespace or
# non-ASCII characters is rejected here with a clear message instead of
# silently becoming an empty key.
api_key = clean_api_key(os.getenv("GROQ_API_KEY") or "").strip()
if not api_key:
    raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
def clean_text(text):
    """Return *text* with every character outside the ASCII range dropped."""
    return "".join(ch for ch in text if ord(ch) < 128)
def _detect_encoding(file_path):
    """Guess the text encoding of *file_path* with chardet (may return None)."""
    with open(file_path, 'rb') as f:
        return chardet.detect(f.read())['encoding']


def _load_csv(file_path):
    """Return one Document per CSV row; the file encoding is auto-detected."""
    data = pd.read_csv(file_path, encoding=_detect_encoding(file_path))
    return [
        Document(page_content=clean_text(row.to_string()), metadata={"source": file_path})
        for _, row in data.iterrows()
    ]


def _load_json(file_path):
    """Return one Document per item of a top-level list, or one for a top-level object."""
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if isinstance(data, list):
        entries = data
    elif isinstance(data, dict):
        entries = [data]
    else:
        # Previously such files were dropped without any trace in the logs.
        logger.warning(
            f"Unsupported JSON top-level type {type(data).__name__} in {file_path}; skipped."
        )
        return []
    return [
        Document(page_content=clean_text(json.dumps(entry)), metadata={"source": file_path})
        for entry in entries
    ]


def _load_txt(file_path):
    """Return the whole text file as a single Document.

    UTF-8 is tried first (unchanged behaviour for UTF-8 files). If the file is
    not valid UTF-8, its encoding is detected with chardet, as is already done
    for CSV files, instead of the file being rejected.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            raw = f.read()
    except UnicodeDecodeError:
        encoding = _detect_encoding(file_path) or 'utf-8'
        # Undecodable bytes become U+FFFD, which clean_text then strips.
        with open(file_path, 'r', encoding=encoding, errors='replace') as f:
            raw = f.read()
    return [Document(page_content=clean_text(raw), metadata={"source": file_path})]


# Extension (lower-cased, with leading dot) -> loader returning a list of Documents.
_LOADERS = {".csv": _load_csv, ".json": _load_json, ".txt": _load_txt}


def load_documents(file_paths):
    """Load .csv, .json and .txt files into LangChain Documents.

    Args:
        file_paths: iterable of file paths; a single path string is also accepted.

    Returns:
        A list of Document objects whose text has been passed through
        clean_text and whose metadata records ``{"source": file_path}``.
        Files with an unsupported extension, or that fail to load, are logged
        and skipped; the remaining files are still processed.
    """
    if isinstance(file_paths, (str, os.PathLike)):
        # A bare string would otherwise be iterated character by character.
        file_paths = [file_paths]
    docs = []
    for file_path in file_paths:
        try:
            ext = os.path.splitext(file_path)[-1].lower()
            loader = _LOADERS.get(ext)
            if loader is None:
                logger.warning(f"Unsupported file format: {file_path}")
                continue
            docs.extend(loader(file_path))
        except Exception as e:
            # Best-effort per file: record the traceback and move on.
            logger.exception(f"Error processing file {file_path}: {e}")
    return docs
# Input validation vocabulary: NLTK's English word list, downloading the
# 'words' corpus on first use if it is not installed yet.
try:
    _vocabulary = words.words()
except LookupError:
    import nltk
    nltk.download('words')
    _vocabulary = words.words()
# is_valid_input lowercases the question before looking words up, so store the
# vocabulary lowercased too; otherwise any capitalised entries could never match.
english_words = {entry.lower() for entry in _vocabulary}
del _vocabulary
def is_valid_input(text):
    """Decide whether *text* is an acceptable question.

    Returns an ``(ok, reason)`` pair. ``ok`` is False when the input is empty
    or whitespace-only, when it is shorter than two characters after
    stripping, or when none of its lowercased word tokens appear in
    ``english_words``. ``reason`` is the message to show the user.
    """
    stripped = text.strip() if text else ""
    if not stripped:
        return False, "Input cannot be empty. Please provide a meaningful question."
    if len(stripped) < 2:
        return False, "Input is too short. Please provide more context or details."
    tokens = re.findall(r'\b\w+\b', text.lower())
    if any(token in english_words for token in tokens):
        return True, "Valid input."
    return False, "Input appears unclear. Please use valid words in your question."
def initialize_llm(model, temperature, max_tokens):
    """Build the Groq chat model used to generate answers.

    Twenty percent of ``max_tokens`` (rounded down) is held back as a prompt
    allowance; the remainder is passed to ``ChatGroq`` as its ``max_tokens``.
    NOTE(review): ChatGroq's ``max_tokens`` presumably caps only the
    completion, so this reservation may be unnecessary; confirm.

    Raises:
        ValueError: if the remaining completion budget is 50 tokens or fewer.
    """
    reserved_for_prompt = int(max_tokens * 0.2)
    completion_budget = max_tokens - reserved_for_prompt
    if completion_budget > 50:
        return ChatGroq(
            model=model,
            temperature=temperature,
            max_tokens=completion_budget,
            api_key=api_key,
        )
    raise ValueError("max_tokens too small.")
def create_rag_pipeline(file_paths, model, temperature, max_tokens,
                        persist_directory="/tmp/chroma_db"):
    """Build a retrieval-augmented QA chain over the given files.

    The files are loaded with load_documents and split into 1000-character
    chunks with 200-character overlap. The chunks are embedded with
    sentence-transformers/all-MiniLM-L6-v2 and indexed in Chroma. The index
    is then wired to a Groq LLM through a "stuff" RetrievalQA chain that uses
    a wellness-assistant prompt.

    Args:
        file_paths: paths passed to load_documents.
        model, temperature, max_tokens: LLM settings passed to initialize_llm.
        persist_directory: directory where Chroma stores the index. Any index
            already stored there is replaced. ``None`` keeps the index in memory.

    Returns:
        ``(chain, status)``: the RetrievalQA chain (or None when there is
        nothing to index) and a human-readable status message.

    Raises:
        ValueError: propagated from initialize_llm when max_tokens is too small.
    """
    llm = initialize_llm(model, temperature, max_tokens)
    docs = load_documents(file_paths)
    if not docs:
        return None, "No documents were loaded."
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)
    if not splits:
        # Every loaded document was empty: there is nothing to embed or retrieve.
        return None, "No document content could be indexed."
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    if persist_directory:
        # Chroma.from_documents adds to any collection already stored in this
        # directory (e.g. by an earlier run of the app), which would duplicate
        # every chunk. The index is always rebuilt from file_paths, so drop the
        # stale collection first.
        Chroma(
            persist_directory=persist_directory,
            embedding_function=embedding_model,
        ).delete_collection()
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embedding_model,
        persist_directory=persist_directory,
    )
    retriever = vectorstore.as_retriever()
    custom_prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
You are an AI assistant specialized in daily wellness. Provide a concise, thorough, and stand-alone answer to the user's question based on the given context. Include relevant examples or schedules where beneficial. **When listing steps or guidelines, format them as a numbered list with appropriate markdown formatting.** The final answer should be coherent, self-contained, and end with a complete sentence.
Context:
{context}
Question:
{question}
Final Answer:
""",
    )
    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": custom_prompt_template},
    )
    return rag_chain, "Pipeline created successfully."
# Default settings used to build the shared pipeline at start-up; the UI
# controls are pre-filled with these same values.
file_paths = ['AIChatbot.csv']
model = "llama3-8b-8192"
temperature = 0.7
max_tokens = 500
try:
    rag_chain, message = create_rag_pipeline(file_paths, model, temperature, max_tokens)
except Exception:
    # Keep the UI running: answer_question reports the system as unavailable
    # when rag_chain is None, rather than the whole app failing to start.
    logger.exception("Failed to build the RAG pipeline")
    rag_chain, message = None, "Pipeline creation failed."
# Surface the status message instead of silently discarding it.
if rag_chain is None:
    logger.error("RAG pipeline unavailable: %s", message)
else:
    logger.info(message)
def _chain_for_settings(model, temperature, max_tokens):
    """Return a RetrievalQA chain that answers with the requested LLM settings.

    The retriever and prompt are taken from the start-up ``rag_chain``, so the
    vector index is reused rather than rebuilt; only the LLM is swapped. A new
    chain is built per call, so the shared chain is never mutated.

    Raises:
        ValueError: from initialize_llm when max_tokens is too small.
    """
    # UI sliders may deliver floats; ChatGroq expects an integer token limit.
    llm = initialize_llm(model, float(temperature), int(max_tokens))
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=rag_chain.retriever,
        # The start-up chain is a "stuff" chain; its prompt lives on the inner LLMChain.
        chain_type_kwargs={"prompt": rag_chain.combine_documents_chain.llm_chain.prompt},
    )


def answer_question(model, temperature, max_tokens, question):
    """Answer *question* with the requested model, temperature and token limit.

    Previously these three settings were accepted but ignored, and every answer
    came from the start-up chain's defaults.

    Returns:
        The validation message for unacceptable input; an "unavailable" message
        when the start-up pipeline was not built; a prompt to supply a model
        name when it is blank; the model's stripped answer on success; or a
        generic error message if anything fails while answering.
    """
    is_valid, message = is_valid_input(question)
    if not is_valid:
        return message
    if rag_chain is None:
        return "The system is currently unavailable. Please try again later."
    model = (model or "").strip()
    if not model:
        return "Please provide a model name."
    try:
        chain = _chain_for_settings(model, temperature, max_tokens)
        # Chain.run is deprecated; RetrievalQA takes "query" and returns "result".
        result = chain.invoke({"query": question})
        return result["result"].strip()
    except Exception:
        # UI boundary: keep the traceback in the logs, show the user a generic message.
        logger.exception("Error while answering question")
        return "An error occurred while processing your request."
def gradio_interface(model, temperature, max_tokens, question):
    """Gradio callback: forward the four UI inputs unchanged to answer_question."""
    return answer_question(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        question=question,
    )
# Input components, in the same order as gradio_interface's parameters.
_interface_inputs = [
    gr.Textbox(label="Model Name", value=model),
    gr.Slider(label="Temperature", minimum=0, maximum=1, step=0.01, value=temperature),
    gr.Slider(label="Max Tokens", minimum=200, maximum=2048, step=1, value=max_tokens),
    gr.Textbox(
        label="Question",
        placeholder="e.g., What is box breathing and how does it help reduce anxiety?",
    ),
]
# Example rows: [model name, temperature, max tokens, question].
_interface_examples = [
    ["llama3-8b-8192", 0.7, 500, "What is box breathing and how does it help reduce anxiety?"],
    ["llama3-8b-8192", 0.6, 600, "Give me a weekly fitness schedule incorporating mindfulness exercises."],
]
interface = gr.Interface(
    fn=gradio_interface,
    inputs=_interface_inputs,
    outputs=gr.Markdown(label="Answer"),
    examples=_interface_examples,
    title="Daily Wellness AI",
    description="Ask questions about daily wellness and receive a concise, complete answer.",
    allow_flagging="never",
)
if __name__ == "__main__":
    # Bind to all network interfaces (0.0.0.0) on port 7860, with Gradio debug mode on.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "debug": True}
    interface.launch(**launch_options)