Spaces:
Runtime error
Runtime error
ugmSorcero
committed on
Commit
•
6c3736e
1
Parent(s):
e4aa90a
final touches to draw pipelines & manual cache
Browse files
- core/pipelines.py +0 -10
- interface/components.py +10 -4
- interface/config.py +3 -1
- interface/draw_pipelines.py +31 -9
- interface/pages.py +4 -4
core/pipelines.py
CHANGED
@@ -2,15 +2,12 @@
|
|
2 |
Haystack Pipelines
|
3 |
"""
|
4 |
|
5 |
-
import tokenizers
|
6 |
from haystack import Pipeline
|
7 |
from haystack.document_stores import InMemoryDocumentStore
|
8 |
from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
|
9 |
from haystack.nodes.preprocessor import PreProcessor
|
10 |
-
import streamlit as st
|
11 |
|
12 |
|
13 |
-
@st.cache(allow_output_mutation=True)
|
14 |
def keyword_search(
|
15 |
index="documents",
|
16 |
):
|
@@ -42,13 +39,6 @@ def keyword_search(
|
|
42 |
return search_pipeline, index_pipeline
|
43 |
|
44 |
|
45 |
-
@st.cache(
|
46 |
-
hash_funcs={
|
47 |
-
tokenizers.Tokenizer: lambda _: None,
|
48 |
-
tokenizers.AddedToken: lambda _: None,
|
49 |
-
},
|
50 |
-
allow_output_mutation=True,
|
51 |
-
)
|
52 |
def dense_passage_retrieval(
|
53 |
index="documents",
|
54 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
|
|
2 |
Haystack Pipelines
|
3 |
"""
|
4 |
|
|
|
5 |
from haystack import Pipeline
|
6 |
from haystack.document_stores import InMemoryDocumentStore
|
7 |
from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
|
8 |
from haystack.nodes.preprocessor import PreProcessor
|
|
|
9 |
|
10 |
|
|
|
11 |
def keyword_search(
|
12 |
index="documents",
|
13 |
):
|
|
|
39 |
return search_pipeline, index_pipeline
|
40 |
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
def dense_passage_retrieval(
|
43 |
index="documents",
|
44 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
interface/components.py
CHANGED
@@ -13,10 +13,16 @@ def component_select_pipeline(container):
|
|
13 |
if "Keyword Search" in pipeline_names
|
14 |
else 0,
|
15 |
)
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
def component_show_pipeline(pipeline):
|
|
|
13 |
if "Keyword Search" in pipeline_names
|
14 |
else 0,
|
15 |
)
|
16 |
+
if st.session_state["pipeline"] is None or st.session_state["pipeline"]["name"] != selected_pipeline:
|
17 |
+
(
|
18 |
+
search_pipeline,
|
19 |
+
index_pipeline,
|
20 |
+
) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
|
21 |
+
st.session_state["pipeline"] = {
|
22 |
+
'name': selected_pipeline,
|
23 |
+
'search_pipeline': search_pipeline,
|
24 |
+
'index_pipeline': index_pipeline,
|
25 |
+
}
|
26 |
|
27 |
|
28 |
def component_show_pipeline(pipeline):
|
interface/config.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
from interface.pages import page_landing_page, page_search, page_index
|
2 |
|
3 |
# Define default Session Variables over the whole session.
|
4 |
-
session_state_variables = {
|
|
|
|
|
5 |
|
6 |
# Define Pages for the demo
|
7 |
pages = {
|
|
|
1 |
from interface.pages import page_landing_page, page_search, page_index
|
2 |
|
3 |
# Define default Session Variables over the whole session.
|
4 |
+
session_state_variables = {
|
5 |
+
"pipeline": None
|
6 |
+
}
|
7 |
|
8 |
# Define Pages for the demo
|
9 |
pages = {
|
interface/draw_pipelines.py
CHANGED
@@ -3,11 +3,9 @@ from typing import List
|
|
3 |
from itertools import chain
|
4 |
import networkx as nx
|
5 |
import plotly.graph_objs as go
|
6 |
-
import streamlit as st
|
7 |
import numpy as np
|
8 |
|
9 |
|
10 |
-
@st.cache(allow_output_mutation=True)
|
11 |
def get_pipeline_graph(pipeline):
|
12 |
# Controls for how the graph is drawn
|
13 |
nodeColor = "#ffbf00"
|
@@ -16,13 +14,37 @@ def get_pipeline_graph(pipeline):
|
|
16 |
lineColor = "#ffffff"
|
17 |
|
18 |
G = pipeline.graph
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
for node in G.nodes:
|
27 |
G.nodes[node]["pos"] = list(pos[node])
|
28 |
|
|
|
3 |
from itertools import chain
|
4 |
import networkx as nx
|
5 |
import plotly.graph_objs as go
|
|
|
6 |
import numpy as np
|
7 |
|
8 |
|
|
|
9 |
def get_pipeline_graph(pipeline):
|
10 |
# Controls for how the graph is drawn
|
11 |
nodeColor = "#ffbf00"
|
|
|
14 |
lineColor = "#ffffff"
|
15 |
|
16 |
G = pipeline.graph
|
17 |
+
current_coordinate = (0, len(set([edge[0] for edge in G.edges()])) + 1)
|
18 |
+
# Transform G.edges into {node : all_connected_nodes} format
|
19 |
+
node_connections = {}
|
20 |
+
for in_node, out_node in G.edges():
|
21 |
+
if in_node in node_connections:
|
22 |
+
node_connections[in_node].append(out_node)
|
23 |
+
else:
|
24 |
+
node_connections[in_node] = [out_node]
|
25 |
+
# Get node coordinates/pos
|
26 |
+
fixed_pos_nodes = {}
|
27 |
+
for idx, (in_node, out_nodes) in enumerate(node_connections.items()):
|
28 |
+
if in_node not in fixed_pos_nodes:
|
29 |
+
fixed_pos_nodes[in_node] = np.array([current_coordinate[0], current_coordinate[1]])
|
30 |
+
current_coordinate = (current_coordinate[0], current_coordinate[1] - 1)
|
31 |
+
# If more than 1 out node, then branch out in X coordinate
|
32 |
+
if len(out_nodes) > 1:
|
33 |
+
# if length is odd
|
34 |
+
if (len(out_nodes) % 2) != 0:
|
35 |
+
middle_node = out_nodes[round(len(out_nodes)/2, 0) - 1]
|
36 |
+
fixed_pos_nodes[middle_node] = np.array([current_coordinate[0], current_coordinate[1]])
|
37 |
+
out_nodes = [n for n in out_nodes if n != middle_node]
|
38 |
+
correction_coordinate = - len(out_nodes) / 2
|
39 |
+
for out_node in out_nodes:
|
40 |
+
fixed_pos_nodes[out_node] = np.array([int(current_coordinate[0] + correction_coordinate), int(current_coordinate[1])])
|
41 |
+
if correction_coordinate == -1:
|
42 |
+
correction_coordinate += 1
|
43 |
+
correction_coordinate += 1
|
44 |
+
current_coordinate = (current_coordinate[0], current_coordinate[1] - 1)
|
45 |
+
elif len(node_connections) - 1 == idx:
|
46 |
+
fixed_pos_nodes[out_nodes[0]] = np.array([current_coordinate[0], current_coordinate[1]])
|
47 |
+
pos = nx.spring_layout(G, pos=fixed_pos_nodes, fixed=G.nodes(), seed=42)
|
48 |
for node in G.nodes:
|
49 |
G.nodes[node]["pos"] = list(pos[node])
|
50 |
|
interface/pages.py
CHANGED
@@ -36,12 +36,12 @@ def page_search(container):
|
|
36 |
## SEARCH ##
|
37 |
query = st.text_input("Query")
|
38 |
|
39 |
-
component_show_pipeline(st.session_state["search_pipeline"])
|
40 |
|
41 |
if st.button("Search"):
|
42 |
st.session_state["search_results"] = search(
|
43 |
queries=[query],
|
44 |
-
pipeline=st.session_state["search_pipeline"],
|
45 |
)
|
46 |
if "search_results" in st.session_state:
|
47 |
component_show_search_result(
|
@@ -53,7 +53,7 @@ def page_index(container):
|
|
53 |
with container:
|
54 |
st.title("Index time!")
|
55 |
|
56 |
-
component_show_pipeline(st.session_state["index_pipeline"])
|
57 |
|
58 |
input_funcs = {
|
59 |
"Raw Text": (component_text_input, "card-text"),
|
@@ -74,7 +74,7 @@ def page_index(container):
|
|
74 |
if st.button("Index"):
|
75 |
index_results = index(
|
76 |
corpus,
|
77 |
-
st.session_state["index_pipeline"],
|
78 |
)
|
79 |
if index_results:
|
80 |
st.write(index_results)
|
|
|
36 |
## SEARCH ##
|
37 |
query = st.text_input("Query")
|
38 |
|
39 |
+
component_show_pipeline(st.session_state["pipeline"]["search_pipeline"])
|
40 |
|
41 |
if st.button("Search"):
|
42 |
st.session_state["search_results"] = search(
|
43 |
queries=[query],
|
44 |
+
pipeline=st.session_state["pipeline"]["search_pipeline"],
|
45 |
)
|
46 |
if "search_results" in st.session_state:
|
47 |
component_show_search_result(
|
|
|
53 |
with container:
|
54 |
st.title("Index time!")
|
55 |
|
56 |
+
component_show_pipeline(st.session_state["pipeline"]["index_pipeline"])
|
57 |
|
58 |
input_funcs = {
|
59 |
"Raw Text": (component_text_input, "card-text"),
|
|
|
74 |
if st.button("Index"):
|
75 |
index_results = index(
|
76 |
corpus,
|
77 |
+
st.session_state["pipeline"]["index_pipeline"],
|
78 |
)
|
79 |
if index_results:
|
80 |
st.write(index_results)
|