Update app.py
app.py CHANGED
@@ -25,7 +25,7 @@ model = AutoModelForCausalLM.from_pretrained(
     model_name,
     quantization_config=quantization_config,
     device_map="auto",
-    max_memory={0: "22GiB", "cpu": "6GiB"} # Prevent VRAM overflow
+    max_memory={0: "22GiB", "cpu": "6GiB"} # Prevent VRAM overflow
 )
 text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
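For context, a minimal sketch of the load path this hunk touches, assuming a 4-bit BitsAndBytesConfig and a placeholder checkpoint id (neither is shown in this commit); max_memory caps what device_map="auto" may place on GPU 0, with the remainder offloaded to CPU RAM:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

model_name = "your-base-model-id"  # placeholder; the actual checkpoint is defined earlier in app.py
quantization_config = BitsAndBytesConfig(   # assumed 4-bit config, not shown in the diff
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",                           # accelerate decides layer placement
    max_memory={0: "22GiB", "cpu": "6GiB"},      # hard per-device budget; prevents VRAM overflow
)
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)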
@@ -62,15 +62,15 @@ def start_new_session():
 def get_embedding(text):
     return embedding_model.encode(text, normalize_embeddings=True)
 
-def store_chat_in_session(user_input, response):
+def store_chat_in_session(user_input, response, reference):
     if current_session_id is None:
         start_new_session()
-    chat_sessions[current_session_id].append((user_input, response))
+    chat_sessions[current_session_id].append((user_input, response, reference))
     chat_index.add(np.array([get_embedding(response)]))
 
 def get_recent_chat_history():
     if current_session_id in chat_sessions:
-        return "\n".join([f"User: {q}\nAI: {r}" for q, r in chat_sessions[current_session_id]])
+        return "\n".join([f"User: {q}\nAI: {r}\nReference: {ref}" for q, r, ref in chat_sessions[current_session_id]])
     return ""
 
 # Document Processing
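A minimal, self-contained sketch of the session store this hunk extends: each turn now carries a (user_input, response, reference) triple, and the response embedding is added to a FAISS index for later semantic lookup. The embedding model name, index type, and start_new_session body are assumptions; the other names come from the diff.

import uuid
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # assumed model
chat_index = faiss.IndexFlatIP(embedding_model.get_sentence_embedding_dimension())
chat_sessions = {}
current_session_id = None

def start_new_session():
    global current_session_id
    current_session_id = str(uuid.uuid4())
    chat_sessions[current_session_id] = []

def get_embedding(text):
    return embedding_model.encode(text, normalize_embeddings=True)

def store_chat_in_session(user_input, response, reference):
    # Store the full triple so get_recent_chat_history() can replay the citation too.
    if current_session_id is None:
        start_new_session()
    chat_sessions[current_session_id].append((user_input, response, reference))
    chat_index.add(np.array([get_embedding(response)]))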
@@ -93,8 +93,8 @@ def retrieve_relevant_passage(query, top_k=3):
     D, I = doc_index.search(np.array([query_embedding]), top_k)
     valid_indices = [i for i in I[0] if 0 <= i < len(doc_texts)]
     if valid_indices:
-        return "\n".join([f"- {doc_texts[i]}" for i in valid_indices])
-    return "No relevant document found."
+        return "\n".join([f"- {doc_texts[i]}" for i in valid_indices]), "\n".join([doc_texts[i] for i in valid_indices])
+    return "No relevant document found.", ""
 
 # Retrieve Chat Context
 def retrieve_chat_context(user_input, top_k=3):
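The same change, unpacked for readability: the function now returns a pair, a bulleted passage for the prompt plus the raw passage text kept as a citation. This sketch assumes doc_index, doc_texts, and get_embedding are the module-level objects defined earlier in app.py and that numpy is imported as np.

def retrieve_relevant_passage(query, top_k=3):
    query_embedding = get_embedding(query)
    D, I = doc_index.search(np.array([query_embedding]), top_k)    # FAISS: distances, indices
    valid_indices = [i for i in I[0] if 0 <= i < len(doc_texts)]   # drop -1 padding / out-of-range hits
    if valid_indices:
        passage = "\n".join(f"- {doc_texts[i]}" for i in valid_indices)   # bullets for the prompt
        reference = "\n".join(doc_texts[i] for i in valid_indices)        # raw text for the citation
        return passage, reference
    return "No relevant document found.", ""

# Every caller must now unpack two values:
# relevant_passage, reference = retrieve_relevant_passage(user_input)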
@@ -111,8 +111,8 @@ def retrieve_chat_context(user_input, top_k=3):
 def chat_with_pdf(user_input, chat_history=[]):
     if not authenticated:
         return "Access Denied!", chat_history
-    relevant_passage = retrieve_relevant_passage(user_input)
-    past_chat_context =
+    relevant_passage, reference = retrieve_relevant_passage(user_input)
+    past_chat_context = get_recent_chat_history()
     prompt = (
         "You are an HR assistant. Provide responses based on company policies. If unsure, say 'Please contact HR'.\n\n"
         f"Recent Chat:\n{past_chat_context}\nHR Policy Context:\n{relevant_passage}\nUser Inquiry: {user_input}\nAI Response:"
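To make the prompt shape concrete, here is a toy assembly with hard-coded stand-ins for the retrieved passage and the chat history (in the app the real values come from retrieve_relevant_passage and get_recent_chat_history):

past_chat_context = (
    "User: How do I apply for leave?\n"
    "AI: Submit the leave form on the HR portal.\n"
    "Reference: Leave Policy, section 2"
)
relevant_passage = "- Employees must give 30 days' written notice before resigning."
user_input = "What is the notice period?"

prompt = (
    "You are an HR assistant. Provide responses based on company policies. If unsure, say 'Please contact HR'.\n\n"
    f"Recent Chat:\n{past_chat_context}\nHR Policy Context:\n{relevant_passage}\nUser Inquiry: {user_input}\nAI Response:"
)
print(prompt)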
@@ -123,8 +123,10 @@ def chat_with_pdf(user_input, chat_history=[]):
         prompt, max_new_tokens=1024, do_sample=True, temperature=0.3, top_p=0.85, repetition_penalty=1.2,
         return_full_text=False
     )
-
-
+    answer = response[0]['generated_text'].split("AI Response:")[-1].strip()
+    store_chat_in_session(user_input, answer, reference)
+    formatted_response = f"{answer}\n\n*Reference:* _{reference}_"
+    yield formatted_response
 
     return response_generator(), chat_history
 
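A sketch of the tail of chat_with_pdf as this hunk leaves it, wrapped in a hypothetical helper so it runs as a generator. With return_full_text=False the pipeline already returns only the newly generated tokens, so the split on "AI Response:" is a defensive cleanup; the reference is appended in Markdown italics for the chat UI.

def respond(user_input, prompt, reference):
    # Hypothetical wrapper mirroring the end of chat_with_pdf in this commit.
    response = text_generator(
        prompt, max_new_tokens=1024, do_sample=True, temperature=0.3,
        top_p=0.85, repetition_penalty=1.2, return_full_text=False,
    )
    answer = response[0]["generated_text"].split("AI Response:")[-1].strip()
    store_chat_in_session(user_input, answer, reference)    # persist the turn with its citation
    yield f"{answer}\n\n*Reference:* _{reference}_"          # streamed back to the UI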