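# HR-Talk: a password-gated Gradio RAG chatbot over HR policy PDFs.
# Expects the HUGGINGFACEHUB_API_TOKEN and APP_SECRET_PASSWORD environment
# variables to be set before launch.
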
import os
import faiss
import numpy as np
import gradio as gr
import PyPDF2
import uuid
from collections import deque
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from huggingface_hub import login

# Authentication
login(token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))

# Load AI Model
model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
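# 4-bit quantization roughly quarters weight memory versus fp16, which lets the
# 7B model fit within the max_memory budget declared below.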
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_use_double_quant=True
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
    max_memory={0: "22GiB", "cpu": "6GiB"}  # Prevent VRAM overflow
)
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Sentence Embedding Model
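# (all-MiniLM-L6-v2 produces 384-dimensional vectors, matching embedding_dim below)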
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# FAISS indexes (HNSW performs approximate nearest-neighbour search, which
# scales better than exhaustive IndexFlatL2; 32 is the graph's M parameter,
# i.e. neighbours per node)
embedding_dim = 384  # must match the embedding model's output dimension
doc_index = faiss.IndexHNSWFlat(embedding_dim, 32)
chat_index = faiss.IndexHNSWFlat(embedding_dim, 32)
doc_texts = []
chat_texts = []  # parallel store for chat_index: FAISS id i -> chat_texts[i]

# Session-based Memory
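# Each session keeps only its last SESSION_HISTORY_LIMIT exchanges for prompt
# context; older turns remain searchable through chat_index.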
chat_sessions = {}
current_session_id = None
SESSION_HISTORY_LIMIT = 5

# Authentication
SECRET_PASSWORD = os.getenv("APP_SECRET_PASSWORD")
authenticated = False

def verify_password(password):
    global authenticated
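    # A hardened version would use secrets.compare_digest for a constant-time
    # comparison; plain == is kept here for simplicity.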
    authenticated = password == SECRET_PASSWORD
    return "Access Granted!" if authenticated else "Invalid Password!"

# Chat Session Management
def start_new_session():
    global current_session_id, chat_sessions
    current_session_id = str(uuid.uuid4())
    chat_sessions[current_session_id] = deque(maxlen=SESSION_HISTORY_LIMIT)
    return current_session_id

def get_embedding(text):
    # normalize_embeddings=True returns unit vectors, so L2 distance in the
    # HNSW indexes ranks results the same way cosine similarity would.
    return embedding_model.encode(text, normalize_embeddings=True)

def store_chat_in_session(user_input, response, reference):
    if current_session_id is None:
        start_new_session()
    chat_sessions[current_session_id].append((user_input, response, reference))
    # Keep chat_texts aligned with chat_index so FAISS ids map back to text.
    chat_texts.append(response)
    chat_index.add(np.array([get_embedding(response)]))

def get_recent_chat_history():
    if current_session_id in chat_sessions:
        return "\n".join([f"User: {q}\nAI: {r}\nReference: {ref}" for q, r, ref in chat_sessions[current_session_id]])
    return ""

# Document Processing
def process_pdf(pdf_file):
    if not authenticated:
        return "Access Denied!"
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    page_texts = [page.extract_text() for page in pdf_reader.pages]
    document_text = " ".join(t.replace("\n", " ") for t in page_texts if t)
    # Naive chunking on sentence boundaries; overlapping fixed-size chunks
    # would preserve more context across splits.
    text_chunks = [c for c in document_text.split(". ") if c.strip()]
    if not text_chunks:
        return "No extractable text found."
    embeddings = np.array([get_embedding(chunk) for chunk in text_chunks])
    doc_index.add(embeddings)
    doc_texts.extend(text_chunks)
    return "Doc Processed."

# Retrieve Relevant HR Policy Passages
def retrieve_relevant_passage(query, top_k=3):
    if not authenticated:
        return "Access Denied!", ""  # keep the (passage, reference) shape
    query_embedding = get_embedding(query)
    D, I = doc_index.search(np.array([query_embedding]), top_k)
    # FAISS pads missing results with -1, so filter to valid ids.
    valid_indices = [i for i in I[0] if 0 <= i < len(doc_texts)]
    if valid_indices:
        return "\n".join(f"- {doc_texts[i]}" for i in valid_indices), "\n".join(doc_texts[i] for i in valid_indices)
    return "No relevant document found.", ""

# Retrieve Chat Context
def retrieve_chat_context(user_input, top_k=3):
    if not authenticated:
        return ""
    query_embedding = get_embedding(user_input)
    retrieved_texts = []
    if chat_index.ntotal > 0:
        D, I = chat_index.search(np.array([query_embedding]), top_k)
        # FAISS ids index into chat_texts (not the bounded per-session deque);
        # -1 marks padding when fewer than top_k vectors exist.
        retrieved_texts = [chat_texts[i] for i in I[0] if 0 <= i < len(chat_texts)]
    return f"{get_recent_chat_history()}\n" + "\n".join(retrieved_texts)

# AI Chatbot. Note: the pipeline call generates the complete answer before the
# single yield below; true token streaming would need e.g. TextIteratorStreamer.
def chat_with_pdf(user_input, chat_history=None):
    if chat_history is None:  # avoid the mutable-default-argument pitfall
        chat_history = []
    if not authenticated:
        return iter(["Access Denied!"]), chat_history
    relevant_passage, reference = retrieve_relevant_passage(user_input)
    past_chat_context = get_recent_chat_history()
    prompt = (
        "You are an HR assistant. Provide responses based on company policies. If unsure, say 'Please contact HR'.\n\n"
        f"Recent Chat:\n{past_chat_context}\nHR Policy Context:\n{relevant_passage}\nUser Inquiry: {user_input}\nAI Response:"
    )

    def response_generator():
        response = text_generator(
            prompt, max_new_tokens=1024, do_sample=True, temperature=0.3, top_p=0.85, repetition_penalty=1.2,
            return_full_text=False
        )
        # return_full_text=False already strips the prompt; the split is a safety
        # net in case the model echoes the "AI Response:" marker.
        answer = response[0]['generated_text'].split("AI Response:")[-1].strip()
        store_chat_in_session(user_input, answer, reference)
        yield f"{answer}\n\n*Reference:* _{reference}_"

    return response_generator(), chat_history

# Gradio Interface
with gr.Blocks() as chat_ui:
    gr.Markdown("# πŸ“„ HR-Talk")
    with gr.Accordion("Authenticator", open=False):
        password_input = gr.Textbox(placeholder="Enter Password", type="password", interactive=True, scale=3, show_label=False)
        verify_button = gr.Button("βœ… Verify", variant="primary", scale=1)
        access_status = gr.Label(value="Status", scale=2)
        verify_button.click(verify_password, inputs=[password_input], outputs=[access_status])
    
    with gr.Accordion("Document Feeder", open=False):
        file_upload = gr.File(label="πŸ“‚ Upload PDF", file_types=[".pdf"], interactive=True, scale=5)
        upload_btn = gr.Button("πŸ“€ Process PDF", variant="primary", scale=2)
        status = gr.Label(value="Waiting for upload...", scale=3)
        upload_btn.click(process_pdf, inputs=[file_upload], outputs=[status])
    
    chatbot = gr.Chatbot()
    user_input = gr.Textbox(placeholder="Type your message...", show_label=False, scale=8)
    send_btn = gr.Button("Send", scale=2)
    
    def stream_response(user_input, chat_history):
        response_generator, chat_history = chat_with_pdf(user_input, chat_history)
        full_response = ""
        for chunk in response_generator:
            full_response += chunk
            # Append the in-progress turn without dropping earlier exchanges.
            yield chat_history + [(user_input, full_response)]
        chat_history.append((user_input, full_response))
        yield chat_history
    
    send_btn.click(stream_response, inputs=[user_input, chatbot], outputs=[chatbot])
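    # user_input.submit(...) could be wired to the same handler so Enter also sends.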

if __name__ == "__main__":
    chat_ui.launch()