# HRT / app.py
# Last commit: "Update app.py" by kidwaiaun (54abf6b, verified)
import os
import faiss
import numpy as np
import gradio as gr
import PyPDF2
import uuid
from collections import deque
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from huggingface_hub import login
# Log in to the Hugging Face Hub using the token read from the
# HUGGINGFACEHUB_API_TOKEN environment variable.
login(token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"))
# Load the chat model: a 4-bit quantized checkpoint wrapped in a
# text-generation pipeline.
model_name = "Qwen/Qwen2.5-7B-Instruct-1M"

# bitsandbytes settings: 4-bit weights, float16 compute, double quantization.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
    load_in_4bit=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    # Prevent VRAM overflow: cap device 0 at 22 GiB and CPU at 6 GiB.
    max_memory={0: "22GiB", "cpu": "6GiB"},
    quantization_config=quantization_config,
)

text_generator = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
# Sentence-embedding model shared by document chunks and chat turns.
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Vector dimension both FAISS indexes are built with.
embedding_dim = 384

# HNSW indexes (second argument 32) for optimized retrieval; intended to be
# more efficient than IndexFlatL2.
chat_index = faiss.IndexHNSWFlat(embedding_dim, 32)
doc_index = faiss.IndexHNSWFlat(embedding_dim, 32)

# Text chunks appended by process_pdf, in the order their vectors are added
# to doc_index.
doc_texts = list()
# Session-based memory: maps a session id to a bounded deque of chat turns.
SESSION_HISTORY_LIMIT = 5
current_session_id = None
chat_sessions = dict()

# Authentication: password read from the environment plus a module-level
# flag that the handlers below consult.
authenticated = False
SECRET_PASSWORD = os.environ.get("APP_SECRET_PASSWORD")
def verify_password(password):
    """Check ``password`` against the configured secret and record the outcome.

    Updates the module-level ``authenticated`` flag that the other handlers
    consult, then returns a status message for the UI.

    Args:
        password: Text entered by the user; ``None`` is treated as empty.

    Returns:
        ``"Access Granted!"`` on a match, otherwise ``"Invalid Password!"``.
        Access is always refused when no secret is configured.
    """
    global authenticated
    import hmac  # local import keeps this block self-contained

    if not SECRET_PASSWORD:
        # No secret configured: refuse instead of letting an empty or None
        # password compare equal to a missing secret.
        authenticated = False
    else:
        # Constant-time comparison; bytes are used because compare_digest
        # rejects non-ASCII str arguments.
        authenticated = hmac.compare_digest(
            str(password or "").encode("utf-8"),
            str(SECRET_PASSWORD).encode("utf-8"),
        )
    return "Access Granted!" if authenticated else "Invalid Password!"
# Chat Session Management
def start_new_session():
    """Create a fresh chat session and make it the current one.

    The session's history is a deque capped at ``SESSION_HISTORY_LIMIT``
    entries.

    Returns:
        The new session id (a UUID4 rendered as a string).
    """
    global current_session_id
    session_id = str(uuid.uuid4())
    chat_sessions[session_id] = deque(maxlen=SESSION_HISTORY_LIMIT)
    current_session_id = session_id
    return session_id
def get_embedding(text):
    """Encode ``text`` with the shared sentence-embedding model.

    The encoder is called with ``normalize_embeddings=True``.
    """
    vector = embedding_model.encode(text, normalize_embeddings=True)
    return vector
def store_chat_in_session(user_input, response, reference):
    """Record one exchange in the current session and index the response.

    A session is started first if none is active. The exchange is appended
    to the session history as ``(user_input, response, reference)`` and the
    response's embedding is added to ``chat_index``.
    """
    if current_session_id is None:
        start_new_session()
    turn = (user_input, response, reference)
    chat_sessions[current_session_id].append(turn)
    # NOTE(review): chat_index only ever grows, while the session history is
    # a bounded deque -- positions in the two can drift apart; confirm intent.
    response_vector = get_embedding(response)
    chat_index.add(np.array([response_vector]))
def get_recent_chat_history():
    """Render the current session's stored turns as plain text.

    Returns:
        One ``User/AI/Reference`` entry per stored turn, joined by newlines,
        or an empty string when there is no current session.
    """
    session = chat_sessions.get(current_session_id)
    if session is None:
        return ""
    rendered = []
    for question, answer, ref in session:
        rendered.append(f"User: {question}\nAI: {answer}\nReference: {ref}")
    return "\n".join(rendered)
# Document Processing
def process_pdf(pdf_file):
    """Extract text from an uploaded PDF, embed it, and add it to the document index.

    Args:
        pdf_file: Value passed straight to ``PyPDF2.PdfReader``; ``None``
            when nothing was uploaded.

    Returns:
        A status message for the UI: ``"Access Denied!"`` when not
        authenticated, ``"No PDF uploaded."`` when ``pdf_file`` is ``None``,
        ``"No extractable text found in PDF."`` when no text could be read,
        otherwise ``"Doc Processed."``.
    """
    if not authenticated:
        return "Access Denied!"
    if pdf_file is None:
        # Clicking the button without a file would otherwise crash PdfReader.
        return "No PDF uploaded."

    pdf_reader = PyPDF2.PdfReader(pdf_file)
    # Extract each page once instead of once for the filter and once for use.
    page_texts = (page.extract_text() for page in pdf_reader.pages)
    document_text = " ".join(
        text.replace("\n", " ") for text in page_texts if text
    )

    # Sentence-like chunks; blank chunks are dropped so they are not indexed.
    stripped_chunks = (chunk.strip() for chunk in document_text.split(". "))
    text_chunks = [chunk for chunk in stripped_chunks if chunk]
    if not text_chunks:
        return "No extractable text found in PDF."

    embeddings = np.array([get_embedding(chunk) for chunk in text_chunks])
    doc_index.add(embeddings)
    # Keep doc_texts aligned with the vectors just added to doc_index.
    doc_texts.extend(text_chunks)
    return "Doc Processed."
# Retrieve Relevant HR Policy Passages
def retrieve_relevant_passage(query, top_k=3):
    """Find the stored document chunks closest to ``query``.

    Args:
        query: The user's question.
        top_k: Maximum number of chunks to retrieve.

    Returns:
        A ``(display_text, reference_text)`` tuple on every path, so callers
        can always unpack two values. ``display_text`` is a bulleted list of
        the matching chunks and ``reference_text`` is the same chunks joined
        by newlines. When access is denied or nothing matches, the first
        item is a message and the second is an empty string.
    """
    if not authenticated:
        return "Access Denied!", ""
    if not doc_texts:
        # Nothing has been indexed yet; skip the embedding and the search.
        return "No relevant document found.", ""

    query_embedding = get_embedding(query)
    _, indices = doc_index.search(np.array([query_embedding]), top_k)
    # The lower bound drops negative labels returned for missing neighbours.
    matches = [doc_texts[i] for i in indices[0] if 0 <= i < len(doc_texts)]
    if not matches:
        return "No relevant document found.", ""

    bulleted = "\n".join(f"- {text}" for text in matches)
    return bulleted, "\n".join(matches)
# Retrieve Chat Context
def retrieve_chat_context(user_input, top_k=3):
    """Build chat context: recent history plus earlier responses similar to the input.

    Args:
        user_input: The user's current message.
        top_k: Maximum number of earlier responses to retrieve.

    Returns:
        The recent chat history, a newline, then the retrieved responses
        concatenated; an empty string when not authenticated.
    """
    if not authenticated:
        return ""

    recent_history = get_recent_chat_history()
    # .get avoids a KeyError when no session has been started yet.
    session = chat_sessions.get(current_session_id)
    retrieved_texts = []
    if session and chat_index.ntotal > 0:
        query_embedding = get_embedding(user_input)
        _, indices = chat_index.search(np.array([query_embedding]), top_k)
        # Require 0 <= i: a negative "no result" label would otherwise pick
        # the last turn through negative indexing.
        # NOTE(review): chat_index positions are not tied to this session's
        # bounded deque, so in-range labels may still be misaligned.
        retrieved_texts = [
            session[int(i)][1] for i in indices[0] if 0 <= i < len(session)
        ]
    return f"{recent_history}\n{''.join(retrieved_texts)}"
# AI Chatbot with Streaming
def chat_with_pdf(user_input, chat_history=None):
    """Answer ``user_input`` using retrieved policy passages and recent chat.

    Args:
        user_input: The user's question.
        chat_history: Existing chat history; a new list is created per call
            when omitted, so calls never share a default.

    Returns:
        ``(response, chat_history)``. ``response`` is a generator yielding
        the formatted answer once, or the plain string ``"Access Denied!"``
        when not authenticated. ``chat_history`` is returned unmodified.
    """
    if chat_history is None:
        chat_history = []
    if not authenticated:
        return "Access Denied!", chat_history

    relevant_passage, reference = retrieve_relevant_passage(user_input)
    past_chat_context = get_recent_chat_history()
    prompt = (
        "You are an HR assistant. Provide responses based on company policies. If unsure, say 'Please contact HR'.\n\n"
        f"Recent Chat:\n{past_chat_context}\nHR Policy Context:\n{relevant_passage}\nUser Inquiry: {user_input}\nAI Response:"
    )

    def response_generator():
        # Generation runs lazily, on the first iteration of this generator.
        response = text_generator(
            prompt,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.3,
            top_p=0.85,
            repetition_penalty=1.2,
            return_full_text=False,
        )
        # Keep only the text after the last "AI Response:" marker, if any.
        answer = response[0]['generated_text'].split("AI Response:")[-1].strip()
        store_chat_in_session(user_input, answer, reference)
        formatted_response = f"{answer}\n\n*Reference:* _{reference}_"
        yield formatted_response

    return response_generator(), chat_history
# Gradio Interface
# NOTE(review): the nesting below is reconstructed; confirm which components
# belong inside each accordion.
with gr.Blocks() as chat_ui:
    gr.Markdown("# 📄 HR-Talk")

    # Password gate: sets the module-level authentication flag.
    with gr.Accordion("Authenticator", open=False):
        password_input = gr.Textbox(placeholder="Enter Password", type="password", interactive=True, scale=3, show_label=False)
        verify_button = gr.Button("✅ Verify", variant="primary", scale=1)
        access_status = gr.Label(value="Status", scale=2)
        verify_button.click(verify_password, inputs=[password_input], outputs=[access_status])

    # PDF upload: feeds the document index.
    with gr.Accordion("Document Feeder", open=False):
        file_upload = gr.File(label="📂 Upload PDF", file_types=[".pdf"], interactive=True, scale=5)
        upload_btn = gr.Button("📀 Process PDF", variant="primary", scale=2)
        status = gr.Label(value="Waiting for upload...", scale=3)
        upload_btn.click(process_pdf, inputs=[file_upload], outputs=[status])

    chatbot = gr.Chatbot()
    user_input = gr.Textbox(placeholder="Type your message...", show_label=False, scale=8)
    send_btn = gr.Button("Send", scale=2)

    def stream_response(user_input, chat_history):
        """Yield successive chatbot states while the reply is produced.

        Args:
            user_input: The message typed by the user.
            chat_history: Current chatbot value, a list of
                ``(user, assistant)`` pairs; ``None`` is treated as empty.

        Yields:
            The full history with the in-progress reply appended, then the
            final history once the reply is complete.
        """
        chat_history = list(chat_history or [])
        response_generator, chat_history = chat_with_pdf(user_input, chat_history)
        full_response = ""
        for piece in response_generator:
            full_response += piece
            # Keep every earlier turn visible while streaming; slicing with
            # [:-1] here would hide the previous exchange.
            yield chat_history + [(user_input, full_response)]
        chat_history.append((user_input, full_response))
        yield chat_history

    send_btn.click(stream_response, inputs=[user_input, chatbot], outputs=[chatbot])
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    chat_ui.launch()