Spaces:
Runtime error
Runtime error
ugmSorcero
commited on
Commit
β’
101be32
1
Parent(s):
213d365
Adds keyword search and print pipeline
Browse files- app.py +3 -0
- core/pipelines.py +31 -2
- core/search_index.py +6 -4
- interface/components.py +9 -4
- interface/pages.py +5 -2
- requirements.txt +2 -1
app.py
CHANGED
@@ -5,6 +5,9 @@ st.set_page_config(
|
|
5 |
page_icon="π",
|
6 |
layout="wide",
|
7 |
initial_sidebar_state="expanded",
|
|
|
|
|
|
|
8 |
)
|
9 |
|
10 |
from streamlit_option_menu import option_menu
|
|
|
5 |
page_icon="π",
|
6 |
layout="wide",
|
7 |
initial_sidebar_state="expanded",
|
8 |
+
menu_items={
|
9 |
+
'About': "https://github.com/ugm2/neural-search-demo"
|
10 |
+
}
|
11 |
)
|
12 |
|
13 |
from streamlit_option_menu import option_menu
|
core/pipelines.py
CHANGED
@@ -5,10 +5,39 @@ Haystack Pipelines
|
|
5 |
import tokenizers
|
6 |
from haystack import Pipeline
|
7 |
from haystack.document_stores import InMemoryDocumentStore
|
8 |
-
from haystack.nodes.retriever import DensePassageRetriever
|
9 |
from haystack.nodes.preprocessor import PreProcessor
|
10 |
import streamlit as st
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None}, allow_output_mutation=True)
|
13 |
def dense_passage_retrieval(
|
14 |
index='documents',
|
@@ -42,4 +71,4 @@ def dense_passage_retrieval(
|
|
42 |
document_store, name="DocumentStore", inputs=["DPRRetriever"]
|
43 |
)
|
44 |
|
45 |
-
return search_pipeline, index_pipeline
|
|
|
5 |
import tokenizers
|
6 |
from haystack import Pipeline
|
7 |
from haystack.document_stores import InMemoryDocumentStore
|
8 |
+
from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
|
9 |
from haystack.nodes.preprocessor import PreProcessor
|
10 |
import streamlit as st
|
11 |
|
12 |
+
@st.cache(allow_output_mutation=True)
|
13 |
+
def keyword_search(
|
14 |
+
index='documents',
|
15 |
+
):
|
16 |
+
document_store = InMemoryDocumentStore(index=index)
|
17 |
+
keyword_retriever = TfidfRetriever(document_store=(document_store))
|
18 |
+
processor = PreProcessor(
|
19 |
+
clean_empty_lines=True,
|
20 |
+
clean_whitespace=True,
|
21 |
+
clean_header_footer=True,
|
22 |
+
split_by="word",
|
23 |
+
split_length=100,
|
24 |
+
split_respect_sentence_boundary=True,
|
25 |
+
split_overlap=0,
|
26 |
+
)
|
27 |
+
# SEARCH PIPELINE
|
28 |
+
search_pipeline = Pipeline()
|
29 |
+
search_pipeline.add_node(keyword_retriever, name="TfidfRetriever", inputs=["Query"])
|
30 |
+
|
31 |
+
# INDEXING PIPELINE
|
32 |
+
index_pipeline = Pipeline()
|
33 |
+
index_pipeline.add_node(processor, name="Preprocessor", inputs=["File"])
|
34 |
+
index_pipeline.add_node(keyword_retriever, name="TfidfRetriever", inputs=["Preprocessor"])
|
35 |
+
index_pipeline.add_node(
|
36 |
+
document_store, name="DocumentStore", inputs=["TfidfRetriever"]
|
37 |
+
)
|
38 |
+
|
39 |
+
return search_pipeline, index_pipeline
|
40 |
+
|
41 |
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None}, allow_output_mutation=True)
|
42 |
def dense_passage_retrieval(
|
43 |
index='documents',
|
|
|
71 |
document_store, name="DocumentStore", inputs=["DPRRetriever"]
|
72 |
)
|
73 |
|
74 |
+
return search_pipeline, index_pipeline
|
core/search_index.py
CHANGED
@@ -26,8 +26,10 @@ def search(queries, pipeline):
|
|
26 |
matches_queries = pipeline.run_batch(queries=queries)
|
27 |
for matches in matches_queries["documents"]:
|
28 |
query_results = []
|
|
|
29 |
for res in matches:
|
30 |
-
|
|
|
31 |
query_results.append(
|
32 |
{
|
33 |
"text": res.content,
|
@@ -36,7 +38,7 @@ def search(queries, pipeline):
|
|
36 |
"fragment_id": res.id
|
37 |
}
|
38 |
)
|
39 |
-
|
40 |
-
sorted(query_results, key=lambda x: x["score"], reverse=True)
|
41 |
-
)
|
42 |
return results
|
|
|
26 |
matches_queries = pipeline.run_batch(queries=queries)
|
27 |
for matches in matches_queries["documents"]:
|
28 |
query_results = []
|
29 |
+
score_is_empty = False
|
30 |
for res in matches:
|
31 |
+
if not score_is_empty:
|
32 |
+
score_is_empty = True if res.score is None else False
|
33 |
query_results.append(
|
34 |
{
|
35 |
"text": res.content,
|
|
|
38 |
"fragment_id": res.id
|
39 |
}
|
40 |
)
|
41 |
+
if not score_is_empty:
|
42 |
+
query_results = sorted(query_results, key=lambda x: x["score"], reverse=True)
|
43 |
+
results.append(query_results)
|
44 |
return results
|
interface/components.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
import core.pipelines as pipelines_functions
|
3 |
from inspect import getmembers, isfunction
|
|
|
4 |
|
5 |
def component_select_pipeline(container):
|
6 |
pipeline_names, pipeline_funcs = list(zip(*getmembers(pipelines_functions, isfunction)))
|
@@ -8,7 +9,8 @@ def component_select_pipeline(container):
|
|
8 |
with container:
|
9 |
selected_pipeline = st.selectbox(
|
10 |
'Select pipeline',
|
11 |
-
pipeline_names
|
|
|
12 |
)
|
13 |
st.session_state['search_pipeline'], \
|
14 |
st.session_state['index_pipeline'] = \
|
@@ -16,8 +18,10 @@ def component_select_pipeline(container):
|
|
16 |
|
17 |
def component_show_pipeline(container, pipeline):
|
18 |
"""Draw the pipeline"""
|
19 |
-
with
|
20 |
-
|
|
|
|
|
21 |
|
22 |
def component_show_search_result(container, results):
|
23 |
with container:
|
@@ -25,7 +29,8 @@ def component_show_search_result(container, results):
|
|
25 |
st.markdown(f"### Match {idx+1}")
|
26 |
st.markdown(f"**Text**: {document['text']}")
|
27 |
st.markdown(f"**Document**: {document['id']}")
|
28 |
-
|
|
|
29 |
st.markdown("---")
|
30 |
|
31 |
def component_text_input(container):
|
|
|
1 |
import streamlit as st
|
2 |
import core.pipelines as pipelines_functions
|
3 |
from inspect import getmembers, isfunction
|
4 |
+
from networkx.drawing.nx_agraph import to_agraph
|
5 |
|
6 |
def component_select_pipeline(container):
|
7 |
pipeline_names, pipeline_funcs = list(zip(*getmembers(pipelines_functions, isfunction)))
|
|
|
9 |
with container:
|
10 |
selected_pipeline = st.selectbox(
|
11 |
'Select pipeline',
|
12 |
+
pipeline_names,
|
13 |
+
index=pipeline_names.index('Keyword Search') if 'Keyword Search' in pipeline_names else 0
|
14 |
)
|
15 |
st.session_state['search_pipeline'], \
|
16 |
st.session_state['index_pipeline'] = \
|
|
|
18 |
|
19 |
def component_show_pipeline(container, pipeline):
|
20 |
"""Draw the pipeline"""
|
21 |
+
with st.expander('Show pipeline'):
|
22 |
+
graphviz = to_agraph(pipeline.graph)
|
23 |
+
graphviz.layout("dot")
|
24 |
+
st.graphviz_chart(graphviz.string())
|
25 |
|
26 |
def component_show_search_result(container, results):
|
27 |
with container:
|
|
|
29 |
st.markdown(f"### Match {idx+1}")
|
30 |
st.markdown(f"**Text**: {document['text']}")
|
31 |
st.markdown(f"**Document**: {document['id']}")
|
32 |
+
if document['score'] is not None:
|
33 |
+
st.markdown(f"**Score**: {document['score']:.3f}")
|
34 |
st.markdown("---")
|
35 |
|
36 |
def component_text_input(container):
|
interface/pages.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
from streamlit_option_menu import option_menu
|
3 |
from core.search_index import index, search
|
4 |
-
from interface.components import component_show_search_result, component_text_input
|
5 |
|
6 |
def page_landing_page(container):
|
7 |
with container:
|
@@ -18,7 +18,6 @@ def page_landing_page(container):
|
|
18 |
)
|
19 |
st.markdown(
|
20 |
"TODO list:"
|
21 |
-
"\n - Option to print pipeline structure on page"
|
22 |
"\n - Build other pipelines"
|
23 |
"\n - Include file/url indexing"
|
24 |
"\n - [Optional] Include text to audio to read responses"
|
@@ -31,6 +30,8 @@ def page_search(container):
|
|
31 |
## SEARCH ##
|
32 |
query = st.text_input("Query")
|
33 |
|
|
|
|
|
34 |
if st.button("Search"):
|
35 |
st.session_state['search_results'] = search(
|
36 |
queries=[query],
|
@@ -46,6 +47,8 @@ def page_index(container):
|
|
46 |
with container:
|
47 |
st.title("Index time!")
|
48 |
|
|
|
|
|
49 |
input_funcs = {
|
50 |
"Raw Text": (component_text_input, "card-text"),
|
51 |
}
|
|
|
1 |
import streamlit as st
|
2 |
from streamlit_option_menu import option_menu
|
3 |
from core.search_index import index, search
|
4 |
+
from interface.components import component_show_pipeline, component_show_search_result, component_text_input
|
5 |
|
6 |
def page_landing_page(container):
|
7 |
with container:
|
|
|
18 |
)
|
19 |
st.markdown(
|
20 |
"TODO list:"
|
|
|
21 |
"\n - Build other pipelines"
|
22 |
"\n - Include file/url indexing"
|
23 |
"\n - [Optional] Include text to audio to read responses"
|
|
|
30 |
## SEARCH ##
|
31 |
query = st.text_input("Query")
|
32 |
|
33 |
+
component_show_pipeline(container, st.session_state['search_pipeline'])
|
34 |
+
|
35 |
if st.button("Search"):
|
36 |
st.session_state['search_results'] = search(
|
37 |
queries=[query],
|
|
|
47 |
with container:
|
48 |
st.title("Index time!")
|
49 |
|
50 |
+
component_show_pipeline(container, st.session_state['index_pipeline'])
|
51 |
+
|
52 |
input_funcs = {
|
53 |
"Raw Text": (component_text_input, "card-text"),
|
54 |
}
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
streamlit
|
2 |
streamlit_option_menu
|
3 |
-
farm-haystack
|
|
|
|
1 |
streamlit
|
2 |
streamlit_option_menu
|
3 |
+
farm-haystack
|
4 |
+
pygraphviz
|