tareeb23 commited on
Commit
bf71408
·
verified ·
1 Parent(s): c9a0222

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -56
app.py CHANGED
@@ -1,27 +1,12 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
  import re
4
- from collections import Counter
5
- import string
6
  import docx2txt
7
  from io import BytesIO
8
 
9
  @st.cache_resource
10
  def load_qa_pipeline():
11
- return pipeline("question-answering", model="tareeb23/Roberta_SQUAD_V2")
12
-
13
- def normalize_answer(s):
14
- """Lower text and remove punctuation, articles and extra whitespace."""
15
- def remove_articles(text):
16
- return re.sub(r'\b(a|an|the)\b', ' ', text)
17
- def white_space_fix(text):
18
- return ' '.join(text.split())
19
- def remove_punc(text):
20
- exclude = set(string.punctuation)
21
- return ''.join(ch for ch in text if ch not in exclude)
22
- def lower(text):
23
- return text.lower()
24
- return white_space_fix(remove_articles(remove_punc(lower(s))))
25
 
26
  def chunk_text(text, chunk_size=1000):
27
  sentences = re.split(r'(?<=[.!?])\s+', text)
@@ -40,22 +25,24 @@ def chunk_text(text, chunk_size=1000):
40
 
41
  return chunks
42
 
43
- def highlight_text(text, start_indices, chunk_size):
44
- highlighted_text = text
45
- offset = 0
46
- for i, start in enumerate(start_indices):
47
- actual_start = start + (i * 7) # 7 is the length of the highlight tag
48
- chunk_index = start // chunk_size
49
- actual_start += chunk_index * chunk_size
50
- highlighted_text = (
51
- highlighted_text[:actual_start + offset] +
52
- "<mark>" +
53
- highlighted_text[actual_start + offset:actual_start + offset + 10] +
54
- "</mark>" +
55
- highlighted_text[actual_start + offset + 10:]
56
- )
57
- offset += 13 # Length of "<mark></mark>"
58
- return highlighted_text
 
 
59
 
60
  def main():
61
  st.title("Document Search Engine")
@@ -76,36 +63,34 @@ def main():
76
  st.session_state['context'] = context
77
 
78
  # Search input and button
79
- col1, col2 = st.columns([3, 1])
80
  with col1:
81
  question = st.text_input("Enter your search query:")
82
  with col2:
83
- search_button = st.button("Search")
 
 
84
 
85
- if search_button:
86
  if context and question:
87
- chunks = chunk_text(context)
88
- results = []
89
- for i, chunk in enumerate(chunks):
90
- result = qa_pipeline(question=question, context=chunk)
91
- result['chunk_index'] = i
92
- results.append(result)
93
-
94
- # Sort results by score and get top 3
95
- top_results = sorted(results, key=lambda x: x['score'], reverse=True)[:3]
96
-
97
- st.subheader("Top 3 Results:")
98
- for i, result in enumerate(top_results, 1):
99
- st.write(f"{i}. Answer: {result['answer']}")
100
- st.write(f" Confidence: {result['score']:.2f}")
101
-
102
- # Highlight answers in the context
103
- chunk_size = 1000 # Make sure this matches the chunk_size in chunk_text function
104
- start_indices = [result['start'] + (result['chunk_index'] * chunk_size) for result in top_results]
105
- highlighted_context = highlight_text(context, start_indices, chunk_size)
106
 
107
- st.subheader("Context with Highlighted Answers:")
108
- st.markdown(highlighted_context, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  else:
110
  st.warning("Please provide both context and search query.")
111
 
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  import re
 
 
4
  import docx2txt
5
  from io import BytesIO
6
 
7
@st.cache_resource
def load_qa_pipeline():
    """Build the extractive question-answering pipeline exactly once.

    The result is cached by Streamlit (`st.cache_resource`) so the model is
    not reloaded on every script rerun.
    """
    qa = pipeline(
        "question-answering",
        model="tareeb23/Roberta_SQUAD_V2",
        tokenizer="tareeb23/Roberta_SQUAD_V2",
    )
    return qa
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def chunk_text(text, chunk_size=1000):
12
  sentences = re.split(r'(?<=[.!?])\s+', text)
 
25
 
26
  return chunks
27
 
28
def get_top_answers(qa_pipeline, question, context, top_k=3, score_limit=0.1):
    """Run extractive QA over the context chunk-by-chunk and rank the answers.

    Args:
        qa_pipeline: A Hugging Face question-answering pipeline.
        question: The query string.
        context: The full document text to search.
        top_k: Maximum number of answers to return (default 3).
        score_limit: Minimum confidence score an answer must reach (default 0.1).

    Returns:
        At most ``top_k`` pipeline result dicts, highest score first, each
        augmented with ``chunk_index`` and ``chunk_start`` (the chunk's
        character offset within the original context).
    """
    results = []
    offset = 0  # running character offset of the current chunk in `context`
    for index, chunk in enumerate(chunk_text(context)):
        result = qa_pipeline(question=question, context=chunk)
        result['chunk_index'] = index
        # Use the actual cumulative offset instead of assuming fixed
        # 1000-char chunks: chunk_text splits on sentence boundaries, so
        # chunk lengths vary and `index * 1000` drifts.  NOTE(review): still
        # approximate if chunk_text drops inter-sentence whitespace when
        # joining — confirm against chunk_text's join behavior.
        result['chunk_start'] = offset
        offset += len(chunk)
        results.append(result)

    # Drop low-confidence answers, then keep the top_k best by score.
    qualifying = [r for r in results if r['score'] >= score_limit]
    return sorted(qualifying, key=lambda r: r['score'], reverse=True)[:top_k]
43
+
44
def highlight_answer(text, answer, start):
    """Wrap one occurrence of `answer` inside `text` in Markdown bold markers.

    `start` is the expected character offset of the answer.  Because the
    chunk offsets computed upstream are approximate, the slice at `start`
    may not actually contain the answer; blindly splicing there would
    overwrite unrelated text.  Instead, the nearest real occurrence is
    located with str.find/str.rfind, and if the answer is absent the text
    is returned unchanged.
    """
    if not answer:
        return text
    end = start + len(answer)
    # Fast path: the offset is exact — identical to the naive splice.
    if text[start:end] == answer:
        return text[:start] + "**" + answer + "**" + text[end:]
    # Offset was approximate — find the real occurrence closest to `start`.
    after = text.find(answer, start)
    before = text.rfind(answer, 0, end)  # highest match starting at or before `start`
    candidates = [i for i in (after, before) if i != -1]
    if not candidates:
        return text  # answer not present; do not corrupt the text
    pos = min(candidates, key=lambda i: abs(i - start))
    return text[:pos] + "**" + answer + "**" + text[pos + len(answer):]
46
 
47
  def main():
48
  st.title("Document Search Engine")
 
63
  st.session_state['context'] = context
64
 
65
  # Search input and button
66
+ col1, col2, col3 = st.columns([3, 1, 1])
67
  with col1:
68
  question = st.text_input("Enter your search query:")
69
  with col2:
70
+ top_k = st.number_input("Top K results", min_value=1, max_value=10, value=3)
71
+ with col3:
72
+ score_limit = st.number_input("Score limit", min_value=0.0, max_value=1.0, value=0.1, step=0.05)
73
 
74
+ if st.button("Search"):
75
  if context and question:
76
+ top_results = get_top_answers(qa_pipeline, question, context, top_k=top_k, score_limit=score_limit)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+ if top_results:
79
+ st.subheader(f"Top {len(top_results)} Results:")
80
+ for i, result in enumerate(top_results, 1):
81
+ st.write(f"{i}. Answer: {result['answer']}")
82
+ st.write(f" Confidence: {result['score']:.4f}")
83
+ st.write(f" Start Index in Original Context: {result['chunk_start'] + result['start']}")
84
+ st.write(f" Chunk Index: {result['chunk_index']}")
85
+
86
+ st.subheader("Context with Highlighted Answers:")
87
+ highlighted_context = context
88
+ for result in reversed(top_results): # Reverse to avoid messing up indices
89
+ start = result['chunk_start'] + result['start']
90
+ highlighted_context = highlight_answer(highlighted_context, result['answer'], start)
91
+ st.markdown(highlighted_context)
92
+ else:
93
+ st.warning("No results found above the score limit.")
94
  else:
95
  st.warning("Please provide both context and search query.")
96