# ghjkl / app.py
# Srinivasulu kethanaboina
# Update app.py
# 52bd5b0 verified
# raw
# history blame
# 5.26 kB
from dotenv import load_dotenv
import gradio as gr
import os
from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from sentence_transformers import SentenceTransformer
import csv
import os
# Directory that holds the CSV chat log. It has its own name so it is not
# confused with PERSIST_DIR (the vector-store directory defined below), and
# so it can be created explicitly before handle_query() appends to CSV_FILE.
HISTORY_DIR = "history"  # Replace with your actual directory path
CSV_FILE = os.path.join(HISTORY_DIR, "chat_history.csv")
# Load environment variables (must run before HF_TOKEN is read below)
load_dotenv()
# Configure the Llama index settings
Settings.llm = HuggingFaceInferenceAPI(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
    context_window=3000,
    token=os.getenv("HF_TOKEN"),
    max_new_tokens=512,
    generate_kwargs={"temperature": 0.1},
)
Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5"
)
# Define the directory for persistent storage and data
PERSIST_DIR = "db"
PDF_DIRECTORY = 'data'  # Changed to the directory containing PDFs
# Ensure directories exist (including the chat-log directory, which was
# previously never created and made the CSV append fail)
os.makedirs(PDF_DIRECTORY, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)
os.makedirs(HISTORY_DIR, exist_ok=True)
# Variable to store current chat conversation as (query, response) tuples
current_chat_history = []
def data_ingestion_from_directory():
    """Index the files in ``PDF_DIRECTORY`` and persist the index to ``PERSIST_DIR``.

    Reads every document that ``SimpleDirectoryReader`` finds in
    ``PDF_DIRECTORY``, builds a ``VectorStoreIndex`` from them and writes the
    index's storage context to ``PERSIST_DIR``.
    """
    # Use SimpleDirectoryReader on the directory containing the PDF files
    documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
    # The index exposes its own storage_context (persisted below), so the
    # standalone StorageContext that used to be created here was never used.
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)
def handle_query(query):
    """Answer ``query`` from the persisted index and log the exchange.

    Loads the index stored under ``PERSIST_DIR``, queries it with the FernAI
    prompt template, appends the ``(query, response)`` pair to
    ``current_chat_history`` and appends the same pair as a row to
    ``CSV_FILE``.

    Args:
        query: The user's question.

    Returns:
        The response text taken from the query result, or a fixed apology
        string when the result exposes no ``response``.
    """
    # Ensure the directory exists or create it
    os.makedirs(PERSIST_DIR, exist_ok=True)
    chat_text_qa_msgs = [
        (
            "user",
            """
As FernAI, your goal is to offer top-tier service and information about RedFerns Tech company.
Provide concise answers based on the conversation flow. Ultimately, aim to attract users to connect with our services.
Summarize responses effectively in 20-60 words without unnecessary repetition.
{context_str}
Question:
{query_str}
"""
        )
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
    # Load index from storage
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)
    # Summarise earlier turns, newest first, skipping blank queries.
    context_str = "".join(
        f"User asked: '{past_query}'\nBot answered: '{past_response}'\n"
        for past_query, past_response in reversed(current_chat_history)
        if past_query.strip()
    )
    # NOTE(review): presumably `context_str` is not consumed by the query
    # engine as a prompt variable (the template's {context_str} looks like it
    # is filled from retrieved context) -- confirm against the LlamaIndex
    # as_query_engine documentation before relying on chat history here.
    query_engine = index.as_query_engine(text_qa_template=text_qa_template, context_str=context_str)
    answer = query_engine.query(query)
    if hasattr(answer, 'response'):
        response = answer.response
    elif isinstance(answer, dict) and 'response' in answer:
        response = answer['response']
    else:
        response = "Sorry, I couldn't find an answer."
    # Update current chat history
    current_chat_history.append((query, response))
    # Save chat history to CSV. The log's parent directory is created on
    # demand; without it, opening the file in append mode raises
    # FileNotFoundError.
    csv_dir = os.path.dirname(CSV_FILE)
    if csv_dir:
        os.makedirs(csv_dir, exist_ok=True)
    with open(CSV_FILE, 'a', newline='', encoding='utf-8') as file:
        csv.writer(file).writerow([query, response])
    return response
# Build and persist the index at import time so that handle_query() can load
# it from PERSIST_DIR.
print("Processing PDF ingestion from directory:", PDF_DIRECTORY)
data_ingestion_from_directory()
# Define the function to handle predictions
def predict(message, history):
    """Chat callback: answer ``message`` and wrap the reply in logo HTML.

    Args:
        message: The latest user message; forwarded to ``handle_query``.
        history: Conversation history supplied by the caller. A list is
            flattened to a space-separated string; ``None`` becomes an empty
            string; anything else is stringified. The result is appended as
            one line to ``kk.txt``.

    Returns:
        An HTML string containing the logo markup followed by the response
        text.
    """
    # Your logo HTML code
    logo_html = '''
<div class="circle-logo">
<img src="https://rb.gy/8r06eg" alt="FernAi">
</div>
'''
    # Assuming handle_query function handles the message and returns a response
    response = handle_query(message)
    # Prepare the response with logo HTML
    response_with_logo = f'<div class="response-with-logo">{logo_html}<div class="response-text">{response}</div></div>'
    # Convert history to a string without rebinding the parameter. Non-list,
    # non-str values (e.g. None) used to raise TypeError on concatenation.
    if isinstance(history, list):
        history_text = ' '.join(map(str, history))
    elif history is None:
        history_text = ''
    else:
        history_text = str(history)
    # Save history to kk.txt; explicit UTF-8 keeps non-ASCII chat text from
    # failing on platforms whose default encoding cannot represent it.
    with open('kk.txt', 'a', encoding='utf-8') as file:
        file.write(history_text + '\n')
    return response_with_logo
# Custom CSS handed to gr.ChatInterface via its `css` argument. It styles the
# circular logo (.circle-logo) and the logo+text wrapper (.response-with-logo)
# that predict() emits, and hides the page footer.
# NOTE(review): `label.svelte-1b6s6s` looks like a generated class name --
# confirm it still matches the installed Gradio version.
css = '''
.circle-logo {
display: inline-block;
width: 40px;
height: 40px;
border-radius: 50%;
overflow: hidden;
margin-right: 10px;
vertical-align: middle;
}
.circle-logo img {
width: 100%;
height: 100%;
object-fit: cover;
}
.response-with-logo {
display: flex;
align-items: center;
margin-bottom: 10px;
}
footer {
display: none !important;
background-color: #F8D7DA;
}
label.svelte-1b6s6s {display: none}
'''
# Build the chat UI around predict() with the custom CSS, no clear/undo/retry
# buttons and two example prompts, then start it with share=False.
demo = gr.ChatInterface(
    predict,
    css=css,
    description="FernAI",
    clear_btn=None,
    undo_btn=None,
    retry_btn=None,
    examples=['Tell me about Redfernstech?', 'Services in Redfernstech?'],
)
demo.launch(share=False)