kidwaiaun committed
Commit 54abf6b · verified · 1 Parent(s): 162e1ec

Update app.py

Files changed (1): app.py (+12 -10)
app.py CHANGED
@@ -25,7 +25,7 @@ model = AutoModelForCausalLM.from_pretrained(
     model_name,
     quantization_config=quantization_config,
     device_map="auto",
-    max_memory={0: "22GiB", "cpu": "6GiB"} # Prevent VRAM overflow
+    max_memory={0: "22GiB", "cpu": "6GiB"} # Prevent VRAM overflow
 )
 text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
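Context for this hunk: `device_map="auto"` together with a `max_memory` cap makes accelerate fill GPU 0 up to the stated budget and offload the remainder to CPU RAM. A minimal sketch of the surrounding load, assuming a 4-bit BitsAndBytesConfig and a placeholder model name (neither is shown in this diff):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # placeholder; the real model is not in this diff

# 4-bit quantization to shrink the weight footprint; exact settings are an assumption.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
    # Cap GPU 0 at 22 GiB and allow 6 GiB of CPU offload, as in the diff.
    max_memory={0: "22GiB", "cpu": "6GiB"},
)
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
```

The per-device budget is what prevents the VRAM overflow the inline comment mentions; without it, accelerate may pack layers onto the GPU until allocation fails.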
@@ -62,15 +62,15 @@ def start_new_session():
 def get_embedding(text):
     return embedding_model.encode(text, normalize_embeddings=True)
 
-def store_chat_in_session(user_input, response):
+def store_chat_in_session(user_input, response, reference):
     if current_session_id is None:
         start_new_session()
-    chat_sessions[current_session_id].append((user_input, response))
+    chat_sessions[current_session_id].append((user_input, response, reference))
     chat_index.add(np.array([get_embedding(response)]))
 
 def get_recent_chat_history():
     if current_session_id in chat_sessions:
-        return "\n".join([f"User: {q}\nAI: {r}" for q, r in chat_sessions[current_session_id]])
+        return "\n".join([f"User: {q}\nAI: {r}\nReference: {ref}" for q, r, ref in chat_sessions[current_session_id]])
     return ""
 
 # Document Processing
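The storage change threads a third element, the document reference, through every session entry, and `get_recent_chat_history` unpacks it per turn. A self-contained sketch of the data shape, with hypothetical sample values:

```python
# Sessions map a session id to a list of (user_input, response, reference) triples.
chat_sessions = {"session-1": []}
current_session_id = "session-1"

chat_sessions[current_session_id].append(
    ("How many leave days do I get?", "20 days per year.", "Leave Policy, section 3")
)

# get_recent_chat_history() now unpacks three fields per turn:
history = "\n".join(
    f"User: {q}\nAI: {r}\nReference: {ref}"
    for q, r, ref in chat_sessions[current_session_id]
)
print(history)
```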
@@ -93,8 +93,8 @@ def retrieve_relevant_passage(query, top_k=3):
     D, I = doc_index.search(np.array([query_embedding]), top_k)
     valid_indices = [i for i in I[0] if 0 <= i < len(doc_texts)]
     if valid_indices:
-        return "\n".join([f"- {doc_texts[i]}" for i in valid_indices])
-    return "No relevant document found."
+        return "\n".join([f"- {doc_texts[i]}" for i in valid_indices]), "\n".join([doc_texts[i] for i in valid_indices])
+    return "No relevant document found.", ""
 
 # Retrieve Chat Context
 def retrieve_chat_context(user_input, top_k=3):
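`retrieve_relevant_passage` now returns a pair: a bulleted passage for the prompt and a raw concatenation used as the citation. A runnable sketch, assuming `doc_index` is a FAISS inner-product index over normalized embeddings (the index type and corpus are not visible in this diff):

```python
import numpy as np
import faiss  # assumed backend; the diff only shows doc_index.search(...)

# Toy corpus with 2-dim "embeddings", purely illustrative.
doc_texts = ["Leave policy: 20 days per year.", "Remote work: 2 days per week."]
vecs = np.array([[1.0, 0.0], [0.0, 1.0]], dtype="float32")
doc_index = faiss.IndexFlatIP(2)  # inner product == cosine on normalized vectors
doc_index.add(vecs)

def retrieve_relevant_passage(query_embedding, top_k=3):
    D, I = doc_index.search(np.array([query_embedding], dtype="float32"), top_k)
    # FAISS pads missing neighbours with -1, hence the bounds check.
    valid_indices = [i for i in I[0] if 0 <= i < len(doc_texts)]
    if valid_indices:
        # First element feeds the prompt; second becomes the displayed citation.
        return ("\n".join(f"- {doc_texts[i]}" for i in valid_indices),
                "\n".join(doc_texts[i] for i in valid_indices))
    return "No relevant document found.", ""

passage, reference = retrieve_relevant_passage([0.9, 0.1])
```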
@@ -111,8 +111,8 @@ def retrieve_chat_context(user_input, top_k=3):
 def chat_with_pdf(user_input, chat_history=[]):
     if not authenticated:
         return "Access Denied!", chat_history
-    relevant_passage = retrieve_relevant_passage(user_input)
-    past_chat_context = retrieve_chat_context(user_input)
+    relevant_passage, reference = retrieve_relevant_passage(user_input)
+    past_chat_context = get_recent_chat_history()
     prompt = (
         "You are an HR assistant. Provide responses based on company policies. If unsure, say 'Please contact HR'.\n\n"
         f"Recent Chat:\n{past_chat_context}\nHR Policy Context:\n{relevant_passage}\nUser Inquiry: {user_input}\nAI Response:"
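This hunk also swaps `retrieve_chat_context` (embedding search over stored answers) for `get_recent_chat_history` (the verbatim transcript of the current session), so the prompt always sees the full recent conversation. One pre-existing wrinkle the commit leaves untouched: `chat_history=[]` is a mutable default argument, so the same list object is shared across calls. The usual pattern, should a follow-up want it:

```python
def chat_with_pdf(user_input, chat_history=None):
    # A fresh list per call; a [] default would be shared across invocations.
    if chat_history is None:
        chat_history = []
    ...
```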
@@ -123,8 +123,10 @@ def chat_with_pdf(user_input, chat_history=[]):
             prompt, max_new_tokens=1024, do_sample=True, temperature=0.3, top_p=0.85, repetition_penalty=1.2,
             return_full_text=False
         )
-        for token in response[0]['generated_text'].split():
-            yield token + " "
+        answer = response[0]['generated_text'].split("AI Response:")[-1].strip()
+        store_chat_in_session(user_input, answer, reference)
+        formatted_response = f"{answer}\n\n*Reference:* _{reference}_"
+        yield formatted_response
 
     return response_generator(), chat_history
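The replacement yields one formatted message instead of word-by-word pseudo-streaming, and it parses the answer off the last "AI Response:" marker. Since the pipeline is called with `return_full_text=False`, the prompt is not echoed back, so the split is a defensive measure against the model repeating the marker. A sketch of the post-processing, with hypothetical strings:

```python
def format_answer(generated_text, reference):
    # Keep only what follows the last "AI Response:" marker, if any.
    answer = generated_text.split("AI Response:")[-1].strip()
    # Markdown italics for the citation line, matching the diff's formatting.
    return f"{answer}\n\n*Reference:* _{reference}_"

print(format_answer("AI Response: You get 20 leave days.", "Leave Policy, section 3"))
```

Note that the diff stores the plain `answer` via `store_chat_in_session` before formatting, which keeps the citation markup out of the saved history.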