ugmSorcero commited on
Commit
dd7488f
1 Parent(s): 2ce16ea

Starts working in new drawing pipeline method

Browse files
interface/components.py CHANGED
@@ -1,16 +1,10 @@
1
  import streamlit as st
2
- import core.pipelines as pipelines_functions
3
- from inspect import getmembers, isfunction
4
- from networkx.drawing.nx_agraph import to_agraph
5
 
6
 
7
  def component_select_pipeline(container):
8
- pipeline_names, pipeline_funcs = list(
9
- zip(*getmembers(pipelines_functions, isfunction))
10
- )
11
- pipeline_names = [
12
- " ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
13
- ]
14
  with container:
15
  selected_pipeline = st.selectbox(
16
  "Select pipeline",
@@ -25,12 +19,11 @@ def component_select_pipeline(container):
25
  ) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
26
 
27
 
28
- def component_show_pipeline(container, pipeline):
29
  """Draw the pipeline"""
30
  with st.expander("Show pipeline"):
31
- graphviz = to_agraph(pipeline.graph)
32
- graphviz.layout("dot")
33
- st.graphviz_chart(graphviz.string())
34
 
35
 
36
  def component_show_search_result(container, results):
 
1
  import streamlit as st
2
+ from interface.utils import get_pipelines
3
+ from interface.draw_pipelines import get_pipeline_graph
 
4
 
5
 
6
  def component_select_pipeline(container):
7
+ pipeline_names, pipeline_funcs = get_pipelines()
 
 
 
 
 
8
  with container:
9
  selected_pipeline = st.selectbox(
10
  "Select pipeline",
 
19
  ) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
20
 
21
 
22
+ def component_show_pipeline(pipeline):
23
  """Draw the pipeline"""
24
  with st.expander("Show pipeline"):
25
+ fig = get_pipeline_graph(pipeline)
26
+ st.plotly_chart(fig, use_container_width=True)
 
27
 
28
 
29
  def component_show_search_result(container, results):
interface/draw_pipelines.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+ from itertools import chain
4
+ import networkx as nx
5
+ import plotly.graph_objs as go
6
+ import streamlit as st
7
+
8
+ # Start and end are lists defining start and end points
9
+ # Edge x and y are lists used to construct the graph
10
+ # arrowAngle and arrowLength define properties of the arrowhead
11
+ # arrowPos is None, 'middle' or 'end' based on where on the edge you want the arrow to appear
12
+ # arrowLength is the length of the arrowhead
13
+ # arrowAngle is the angle in degrees that the arrowhead makes with the edge
14
+ # dotSize is the plotly scatter dot size you are using (used to even out line spacing when you have a mix of edge lengths)
15
+ def addEdge(start, end, edge_x, edge_y, lengthFrac=1, arrowPos = None, arrowLength=0.025, arrowAngle = 30, dotSize=20):
16
+
17
+ # Get start and end cartesian coordinates
18
+ x0, y0 = start
19
+ x1, y1 = end
20
+
21
+ # Incorporate the fraction of this segment covered by a dot into total reduction
22
+ length = math.sqrt( (x1-x0)**2 + (y1-y0)**2 )
23
+ dotSizeConversion = .0565/20 # length units per dot size
24
+ convertedDotDiameter = dotSize * dotSizeConversion
25
+ lengthFracReduction = convertedDotDiameter / length
26
+ lengthFrac = lengthFrac - lengthFracReduction
27
+
28
+ # If the line segment should not cover the entire distance, get actual start and end coords
29
+ skipX = (x1-x0)*(1-lengthFrac)
30
+ skipY = (y1-y0)*(1-lengthFrac)
31
+ x0 = x0 + skipX/2
32
+ x1 = x1 - skipX/2
33
+ y0 = y0 + skipY/2
34
+ y1 = y1 - skipY/2
35
+
36
+ # Append line corresponding to the edge
37
+ edge_x.append(x0)
38
+ edge_x.append(x1)
39
+ edge_x.append(None) # Prevents a line being drawn from end of this edge to start of next edge
40
+ edge_y.append(y0)
41
+ edge_y.append(y1)
42
+ edge_y.append(None)
43
+
44
+ # Draw arrow
45
+ if not arrowPos == None:
46
+
47
+ # Find the point of the arrow; assume is at end unless told middle
48
+ pointx = x1
49
+ pointy = y1
50
+
51
+ eta = math.degrees(math.atan((x1-x0)/(y1-y0))) if y1!=y0 else 90.0
52
+
53
+ if arrowPos == 'middle' or arrowPos == 'mid':
54
+ pointx = x0 + (x1-x0)/2
55
+ pointy = y0 + (y1-y0)/2
56
+
57
+ # Find the directions the arrows are pointing
58
+ signx = (x1-x0)/abs(x1-x0) if x1!=x0 else +1 #verify this once
59
+ signy = (y1-y0)/abs(y1-y0) if y1!=y0 else +1 #verified
60
+
61
+ # Append first arrowhead
62
+ dx = arrowLength * math.sin(math.radians(eta + arrowAngle))
63
+ dy = arrowLength * math.cos(math.radians(eta + arrowAngle))
64
+ edge_x.append(pointx)
65
+ edge_x.append(pointx - signx**2 * signy * dx)
66
+ edge_x.append(None)
67
+ edge_y.append(pointy)
68
+ edge_y.append(pointy - signx**2 * signy * dy)
69
+ edge_y.append(None)
70
+
71
+ # And second arrowhead
72
+ dx = arrowLength * math.sin(math.radians(eta - arrowAngle))
73
+ dy = arrowLength * math.cos(math.radians(eta - arrowAngle))
74
+ edge_x.append(pointx)
75
+ edge_x.append(pointx - signx**2 * signy * dx)
76
+ edge_x.append(None)
77
+ edge_y.append(pointy)
78
+ edge_y.append(pointy - signx**2 * signy * dy)
79
+ edge_y.append(None)
80
+
81
+
82
+ return edge_x, edge_y
83
+
84
+ def add_arrows(source_x: List[float], target_x: List[float], source_y: List[float], target_y: List[float],
85
+ arrowLength=0.025, arrowAngle=30):
86
+ pointx = list(map(lambda x: x[0] + (x[1] - x[0]) / 2, zip(source_x, target_x)))
87
+ pointy = list(map(lambda x: x[0] + (x[1] - x[0]) / 2, zip(source_y, target_y)))
88
+ etas = list(map(lambda x: math.degrees(math.atan((x[1] - x[0]) / (x[3] - x[2]))),
89
+ zip(source_x, target_x, source_y, target_y)))
90
+
91
+ signx = list(map(lambda x: (x[1] - x[0]) / abs(x[1] - x[0]), zip(source_x, target_x)))
92
+ signy = list(map(lambda x: (x[1] - x[0]) / abs(x[1] - x[0]), zip(source_y, target_y)))
93
+
94
+ dx = list(map(lambda x: arrowLength * math.sin(math.radians(x + arrowAngle)), etas))
95
+ dy = list(map(lambda x: arrowLength * math.cos(math.radians(x + arrowAngle)), etas))
96
+ none_spacer = [None for _ in range(len(pointx))]
97
+ arrow_line_x = list(map(lambda x: x[0] - x[1] ** 2 * x[2] * x[3], zip(pointx, signx, signy, dx)))
98
+ arrow_line_y = list(map(lambda x: x[0] - x[1] ** 2 * x[2] * x[3], zip(pointy, signx, signy, dy)))
99
+
100
+ arrow_line_1x_coords = list(chain(*zip(pointx, arrow_line_x, none_spacer)))
101
+ arrow_line_1y_coords = list(chain(*zip(pointy, arrow_line_y, none_spacer)))
102
+
103
+ dx = list(map(lambda x: arrowLength * math.sin(math.radians(x - arrowAngle)), etas))
104
+ dy = list(map(lambda x: arrowLength * math.cos(math.radians(x - arrowAngle)), etas))
105
+ none_spacer = [None for _ in range(len(pointx))]
106
+ arrow_line_x = list(map(lambda x: x[0] - x[1] ** 2 * x[2] * x[3], zip(pointx, signx, signy, dx)))
107
+ arrow_line_y = list(map(lambda x: x[0] - x[1] ** 2 * x[2] * x[3], zip(pointy, signx, signy, dy)))
108
+
109
+ arrow_line_2x_coords = list(chain(*zip(pointx, arrow_line_x, none_spacer)))
110
+ arrow_line_2y_coords = list(chain(*zip(pointy, arrow_line_y, none_spacer)))
111
+
112
+ x_arrows = arrow_line_1x_coords + arrow_line_2x_coords
113
+ y_arrows = arrow_line_1y_coords + arrow_line_2y_coords
114
+
115
+ return x_arrows, y_arrows
116
+
117
+ @st.cache(allow_output_mutation=True)
118
+ def get_pipeline_graph(pipeline):
119
+ # Controls for how the graph is drawn
120
+ nodeColor = 'Blue'
121
+ nodeSize = 20
122
+ lineWidth = 2
123
+ lineColor = '#000000'
124
+
125
+ G = pipeline.graph
126
+
127
+ pos = nx.spring_layout(G)
128
+
129
+ for node in G.nodes:
130
+ G.nodes[node]['pos'] = list(pos[node])
131
+
132
+ # Make list of nodes for plotly
133
+ node_x = []
134
+ node_y = []
135
+ for node in G.nodes():
136
+ x, y = G.nodes[node]['pos']
137
+ node_x.append(x)
138
+ node_y.append(y)
139
+
140
+ # Make a list of edges for plotly, including line segments that result in arrowheads
141
+ edge_x = []
142
+ edge_y = []
143
+ for edge in G.edges():
144
+ start = G.nodes[edge[0]]['pos']
145
+ end = G.nodes[edge[1]]['pos']
146
+ # addEdge(start, end, edge_x, edge_y, lengthFrac=1, arrowPos = None, arrowLength=0.025, arrowAngle = 30, dotSize=20)
147
+ edge_x, edge_y = addEdge(start, end, edge_x, edge_y, lengthFrac=.8, arrowPos='end', arrowLength=.04, arrowAngle=30, dotSize=nodeSize)
148
+
149
+
150
+ edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=lineWidth, color=lineColor), hoverinfo='none', mode='lines')
151
+
152
+
153
+ node_trace = go.Scatter(x=node_x, y=node_y, mode='markers', hoverinfo='text', marker=dict(showscale=False, color = nodeColor, size=nodeSize))
154
+
155
+ fig = go.Figure(data=[edge_trace, node_trace],
156
+ layout=go.Layout(
157
+ showlegend=False,
158
+ hovermode='closest',
159
+ margin=dict(b=20,l=5,r=5,t=40),
160
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
161
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
162
+ )
163
+
164
+ # Note: if you don't use fixed ratio axes, the arrows won't be symmetrical
165
+ fig.update_layout(yaxis = dict(scaleanchor = "x", scaleratio = 1), plot_bgcolor='rgb(255,255,255)')
166
+
167
+ return fig
interface/pages.py CHANGED
@@ -36,7 +36,7 @@ def page_search(container):
36
  ## SEARCH ##
37
  query = st.text_input("Query")
38
 
39
- # component_show_pipeline(container, st.session_state["search_pipeline"])
40
 
41
  if st.button("Search"):
42
  st.session_state["search_results"] = search(
@@ -53,7 +53,7 @@ def page_index(container):
53
  with container:
54
  st.title("Index time!")
55
 
56
- # component_show_pipeline(container, st.session_state["index_pipeline"])
57
 
58
  input_funcs = {
59
  "Raw Text": (component_text_input, "card-text"),
 
36
  ## SEARCH ##
37
  query = st.text_input("Query")
38
 
39
+ component_show_pipeline(st.session_state["search_pipeline"])
40
 
41
  if st.button("Search"):
42
  st.session_state["search_results"] = search(
 
53
  with container:
54
  st.title("Index time!")
55
 
56
+ component_show_pipeline(st.session_state["index_pipeline"])
57
 
58
  input_funcs = {
59
  "Raw Text": (component_text_input, "card-text"),
interface/utils.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import core.pipelines as pipelines_functions
2
+ from inspect import getmembers, isfunction
3
+
4
+ def get_pipelines():
5
+ pipeline_names, pipeline_funcs = list(
6
+ zip(*getmembers(pipelines_functions, isfunction))
7
+ )
8
+ pipeline_names = [
9
+ " ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
10
+ ]
11
+ return pipeline_names, pipeline_funcs
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  streamlit
2
  streamlit_option_menu
3
  farm-haystack
4
- black
 
 
1
  streamlit
2
  streamlit_option_menu
3
  farm-haystack
4
+ black
5
+ plotly