Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,27 +1,12 @@
|
|
1 |
import streamlit as st
|
2 |
from transformers import pipeline
|
3 |
import re
|
4 |
-
from collections import Counter
|
5 |
-
import string
|
6 |
import docx2txt
|
7 |
from io import BytesIO
|
8 |
|
9 |
@st.cache_resource
|
10 |
def load_qa_pipeline():
|
11 |
-
return pipeline("question-answering", model="tareeb23/Roberta_SQUAD_V2")
|
12 |
-
|
13 |
-
def normalize_answer(s):
|
14 |
-
"""Lower text and remove punctuation, articles and extra whitespace."""
|
15 |
-
def remove_articles(text):
|
16 |
-
return re.sub(r'\b(a|an|the)\b', ' ', text)
|
17 |
-
def white_space_fix(text):
|
18 |
-
return ' '.join(text.split())
|
19 |
-
def remove_punc(text):
|
20 |
-
exclude = set(string.punctuation)
|
21 |
-
return ''.join(ch for ch in text if ch not in exclude)
|
22 |
-
def lower(text):
|
23 |
-
return text.lower()
|
24 |
-
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
25 |
|
26 |
def chunk_text(text, chunk_size=1000):
|
27 |
sentences = re.split(r'(?<=[.!?])\s+', text)
|
@@ -40,22 +25,24 @@ def chunk_text(text, chunk_size=1000):
|
|
40 |
|
41 |
return chunks
|
42 |
|
43 |
-
def
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
59 |
|
60 |
def main():
|
61 |
st.title("Document Search Engine")
|
@@ -76,36 +63,34 @@ def main():
|
|
76 |
st.session_state['context'] = context
|
77 |
|
78 |
# Search input and button
|
79 |
-
col1, col2 = st.columns([3, 1])
|
80 |
with col1:
|
81 |
question = st.text_input("Enter your search query:")
|
82 |
with col2:
|
83 |
-
|
|
|
|
|
84 |
|
85 |
-
if
|
86 |
if context and question:
|
87 |
-
|
88 |
-
results = []
|
89 |
-
for i, chunk in enumerate(chunks):
|
90 |
-
result = qa_pipeline(question=question, context=chunk)
|
91 |
-
result['chunk_index'] = i
|
92 |
-
results.append(result)
|
93 |
-
|
94 |
-
# Sort results by score and get top 3
|
95 |
-
top_results = sorted(results, key=lambda x: x['score'], reverse=True)[:3]
|
96 |
-
|
97 |
-
st.subheader("Top 3 Results:")
|
98 |
-
for i, result in enumerate(top_results, 1):
|
99 |
-
st.write(f"{i}. Answer: {result['answer']}")
|
100 |
-
st.write(f" Confidence: {result['score']:.2f}")
|
101 |
-
|
102 |
-
# Highlight answers in the context
|
103 |
-
chunk_size = 1000 # Make sure this matches the chunk_size in chunk_text function
|
104 |
-
start_indices = [result['start'] + (result['chunk_index'] * chunk_size) for result in top_results]
|
105 |
-
highlighted_context = highlight_text(context, start_indices, chunk_size)
|
106 |
|
107 |
-
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
else:
|
110 |
st.warning("Please provide both context and search query.")
|
111 |
|
|
|
1 |
import streamlit as st
|
2 |
from transformers import pipeline
|
3 |
import re
|
|
|
|
|
4 |
import docx2txt
|
5 |
from io import BytesIO
|
6 |
|
7 |
@st.cache_resource
def load_qa_pipeline():
    """Build and cache the extractive question-answering pipeline.

    Cached via st.cache_resource so the model is loaded once per
    Streamlit server process rather than on every rerun.
    """
    model_id = "tareeb23/Roberta_SQUAD_V2"
    return pipeline("question-answering", model=model_id, tokenizer=model_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
def chunk_text(text, chunk_size=1000):
|
12 |
sentences = re.split(r'(?<=[.!?])\s+', text)
|
|
|
25 |
|
26 |
return chunks
|
27 |
|
28 |
+
def get_top_answers(qa_pipeline, question, context, top_k=3, score_limit=0.1, chunk_size=1000):
    """Run a QA pipeline over a chunked context and return the best answers.

    Args:
        qa_pipeline: Callable accepting question=/context= keyword arguments
            and returning a dict with at least 'answer', 'score' and 'start'.
        question: The search query string.
        context: The full document text to search.
        top_k: Maximum number of answers to return.
        score_limit: Minimum confidence score an answer must reach to be kept.
        chunk_size: Target chunk length forwarded to chunk_text(); also used
            to estimate each chunk's offset in the original context. Defaults
            to 1000, matching the previous hard-coded value.

    Returns:
        Up to top_k result dicts sorted by descending 'score', each augmented
        with 'chunk_index' and 'chunk_start'.
    """
    results = []
    for index, chunk in enumerate(chunk_text(context, chunk_size=chunk_size)):
        result = qa_pipeline(question=question, context=chunk)
        result['chunk_index'] = index
        # NOTE(review): chunk_text splits on sentence boundaries, so chunks are
        # rarely exactly chunk_size characters long — this offset is only an
        # approximation of the answer's position in the original context.
        result['chunk_start'] = index * chunk_size
        results.append(result)

    # Drop low-confidence answers, then keep the top_k highest-scoring ones.
    confident = [r for r in results if r['score'] >= score_limit]
    return sorted(confident, key=lambda r: r['score'], reverse=True)[:top_k]
|
43 |
+
|
44 |
+
def highlight_answer(text, answer, start):
    """Wrap the answer span of *text* in Markdown bold markers.

    *start* is the index in *text* where *answer* begins; the slice of
    length len(answer) starting there is surrounded with '**'.
    """
    end = start + len(answer)
    return f"{text[:start]}**{answer}**{text[end:]}"
|
46 |
|
47 |
def main():
|
48 |
st.title("Document Search Engine")
|
|
|
63 |
st.session_state['context'] = context
|
64 |
|
65 |
# Search input and button
|
66 |
+
col1, col2, col3 = st.columns([3, 1, 1])
|
67 |
with col1:
|
68 |
question = st.text_input("Enter your search query:")
|
69 |
with col2:
|
70 |
+
top_k = st.number_input("Top K results", min_value=1, max_value=10, value=3)
|
71 |
+
with col3:
|
72 |
+
score_limit = st.number_input("Score limit", min_value=0.0, max_value=1.0, value=0.1, step=0.05)
|
73 |
|
74 |
+
if st.button("Search"):
|
75 |
if context and question:
|
76 |
+
top_results = get_top_answers(qa_pipeline, question, context, top_k=top_k, score_limit=score_limit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
+
if top_results:
|
79 |
+
st.subheader(f"Top {len(top_results)} Results:")
|
80 |
+
for i, result in enumerate(top_results, 1):
|
81 |
+
st.write(f"{i}. Answer: {result['answer']}")
|
82 |
+
st.write(f" Confidence: {result['score']:.4f}")
|
83 |
+
st.write(f" Start Index in Original Context: {result['chunk_start'] + result['start']}")
|
84 |
+
st.write(f" Chunk Index: {result['chunk_index']}")
|
85 |
+
|
86 |
+
st.subheader("Context with Highlighted Answers:")
|
87 |
+
highlighted_context = context
|
88 |
+
for result in reversed(top_results): # Reverse to avoid messing up indices
|
89 |
+
start = result['chunk_start'] + result['start']
|
90 |
+
highlighted_context = highlight_answer(highlighted_context, result['answer'], start)
|
91 |
+
st.markdown(highlighted_context)
|
92 |
+
else:
|
93 |
+
st.warning("No results found above the score limit.")
|
94 |
else:
|
95 |
st.warning("Please provide both context and search query.")
|
96 |
|