ugmSorcero commited on
Commit
39503cb
β€’
1 Parent(s): 8d3aacc

Adds linter and fixes linting

Browse files
app.py CHANGED
@@ -5,9 +5,7 @@ st.set_page_config(
5
  page_icon="πŸ”Ž",
6
  layout="wide",
7
  initial_sidebar_state="expanded",
8
- menu_items={
9
- 'About': "https://github.com/ugm2/neural-search-demo"
10
- }
11
  )
12
 
13
  from streamlit_option_menu import option_menu
 
5
  page_icon="πŸ”Ž",
6
  layout="wide",
7
  initial_sidebar_state="expanded",
8
+ menu_items={"About": "https://github.com/ugm2/neural-search-demo"},
 
 
9
  )
10
 
11
  from streamlit_option_menu import option_menu
core/pipelines.py CHANGED
@@ -9,9 +9,10 @@ from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
9
  from haystack.nodes.preprocessor import PreProcessor
10
  import streamlit as st
11
 
 
12
  @st.cache(allow_output_mutation=True)
13
  def keyword_search(
14
- index='documents',
15
  ):
16
  document_store = InMemoryDocumentStore(index=index)
17
  keyword_retriever = TfidfRetriever(document_store=(document_store))
@@ -31,16 +32,25 @@ def keyword_search(
31
  # INDEXING PIPELINE
32
  index_pipeline = Pipeline()
33
  index_pipeline.add_node(processor, name="Preprocessor", inputs=["File"])
34
- index_pipeline.add_node(keyword_retriever, name="TfidfRetriever", inputs=["Preprocessor"])
 
 
35
  index_pipeline.add_node(
36
  document_store, name="DocumentStore", inputs=["TfidfRetriever"]
37
  )
38
 
39
  return search_pipeline, index_pipeline
40
 
41
- @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None}, allow_output_mutation=True)
 
 
 
 
 
 
 
42
  def dense_passage_retrieval(
43
- index='documents',
44
  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
45
  passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
46
  ):
 
9
  from haystack.nodes.preprocessor import PreProcessor
10
  import streamlit as st
11
 
12
+
13
  @st.cache(allow_output_mutation=True)
14
  def keyword_search(
15
+ index="documents",
16
  ):
17
  document_store = InMemoryDocumentStore(index=index)
18
  keyword_retriever = TfidfRetriever(document_store=(document_store))
 
32
  # INDEXING PIPELINE
33
  index_pipeline = Pipeline()
34
  index_pipeline.add_node(processor, name="Preprocessor", inputs=["File"])
35
+ index_pipeline.add_node(
36
+ keyword_retriever, name="TfidfRetriever", inputs=["Preprocessor"]
37
+ )
38
  index_pipeline.add_node(
39
  document_store, name="DocumentStore", inputs=["TfidfRetriever"]
40
  )
41
 
42
  return search_pipeline, index_pipeline
43
 
44
+
45
+ @st.cache(
46
+ hash_funcs={
47
+ tokenizers.Tokenizer: lambda _: None,
48
+ tokenizers.AddedToken: lambda _: None,
49
+ },
50
+ allow_output_mutation=True,
51
+ )
52
  def dense_passage_retrieval(
53
+ index="documents",
54
  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
55
  passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
56
  ):
core/search_index.py CHANGED
@@ -6,9 +6,9 @@ def format_docs(documents):
6
  """Given a list of documents, format the documents and return the documents and doc ids."""
7
  db_docs: list = []
8
  for doc in documents:
9
- doc_id = doc['id'] if doc['id'] is not None else str(uuid.uuid4())
10
  db_doc = {
11
- "content": doc['text'],
12
  "content_type": "text",
13
  "id": str(uuid.uuid4()),
14
  "meta": {"id": doc_id},
@@ -16,11 +16,13 @@ def format_docs(documents):
16
  db_docs.append(Document(**db_doc))
17
  return db_docs, [doc.meta["id"] for doc in db_docs]
18
 
 
19
  def index(documents, pipeline):
20
  documents, doc_ids = format_docs(documents)
21
  pipeline.run(documents=documents)
22
  return doc_ids
23
 
 
24
  def search(queries, pipeline):
25
  results = []
26
  matches_queries = pipeline.run_batch(queries=queries)
@@ -35,10 +37,12 @@ def search(queries, pipeline):
35
  "text": res.content,
36
  "score": res.score,
37
  "id": res.meta["id"],
38
- "fragment_id": res.id
39
  }
40
  )
41
  if not score_is_empty:
42
- query_results = sorted(query_results, key=lambda x: x["score"], reverse=True)
 
 
43
  results.append(query_results)
44
- return results
 
6
  """Given a list of documents, format the documents and return the documents and doc ids."""
7
  db_docs: list = []
8
  for doc in documents:
9
+ doc_id = doc["id"] if doc["id"] is not None else str(uuid.uuid4())
10
  db_doc = {
11
+ "content": doc["text"],
12
  "content_type": "text",
13
  "id": str(uuid.uuid4()),
14
  "meta": {"id": doc_id},
 
16
  db_docs.append(Document(**db_doc))
17
  return db_docs, [doc.meta["id"] for doc in db_docs]
18
 
19
+
20
  def index(documents, pipeline):
21
  documents, doc_ids = format_docs(documents)
22
  pipeline.run(documents=documents)
23
  return doc_ids
24
 
25
+
26
  def search(queries, pipeline):
27
  results = []
28
  matches_queries = pipeline.run_batch(queries=queries)
 
37
  "text": res.content,
38
  "score": res.score,
39
  "id": res.meta["id"],
40
+ "fragment_id": res.id,
41
  }
42
  )
43
  if not score_is_empty:
44
+ query_results = sorted(
45
+ query_results, key=lambda x: x["score"], reverse=True
46
+ )
47
  results.append(query_results)
48
+ return results
interface/components.py CHANGED
@@ -3,36 +3,47 @@ import core.pipelines as pipelines_functions
3
  from inspect import getmembers, isfunction
4
  from networkx.drawing.nx_agraph import to_agraph
5
 
 
6
  def component_select_pipeline(container):
7
- pipeline_names, pipeline_funcs = list(zip(*getmembers(pipelines_functions, isfunction)))
8
- pipeline_names = [' '.join([n.capitalize() for n in name.split('_')]) for name in pipeline_names]
 
 
 
 
9
  with container:
10
  selected_pipeline = st.selectbox(
11
- 'Select pipeline',
12
  pipeline_names,
13
- index=pipeline_names.index('Keyword Search') if 'Keyword Search' in pipeline_names else 0
 
 
14
  )
15
- st.session_state['search_pipeline'], \
16
- st.session_state['index_pipeline'] = \
17
- pipeline_funcs[pipeline_names.index(selected_pipeline)]()
 
 
18
 
19
  def component_show_pipeline(container, pipeline):
20
  """Draw the pipeline"""
21
- with st.expander('Show pipeline'):
22
  graphviz = to_agraph(pipeline.graph)
23
  graphviz.layout("dot")
24
  st.graphviz_chart(graphviz.string())
25
-
 
26
  def component_show_search_result(container, results):
27
  with container:
28
  for idx, document in enumerate(results):
29
  st.markdown(f"### Match {idx+1}")
30
  st.markdown(f"**Text**: {document['text']}")
31
  st.markdown(f"**Document**: {document['id']}")
32
- if document['score'] is not None:
33
  st.markdown(f"**Score**: {document['score']:.3f}")
34
  st.markdown("---")
35
 
 
36
  def component_text_input(container):
37
  """Draw the Text Input widget"""
38
  with container:
@@ -48,7 +59,6 @@ def component_text_input(container):
48
  else:
49
  break
50
  corpus = [
51
- {"text": doc["text"], "id": doc_id}
52
- for doc_id, doc in enumerate(texts)
53
  ]
54
- return corpus
 
3
  from inspect import getmembers, isfunction
4
  from networkx.drawing.nx_agraph import to_agraph
5
 
6
+
7
  def component_select_pipeline(container):
8
+ pipeline_names, pipeline_funcs = list(
9
+ zip(*getmembers(pipelines_functions, isfunction))
10
+ )
11
+ pipeline_names = [
12
+ " ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
13
+ ]
14
  with container:
15
  selected_pipeline = st.selectbox(
16
+ "Select pipeline",
17
  pipeline_names,
18
+ index=pipeline_names.index("Keyword Search")
19
+ if "Keyword Search" in pipeline_names
20
+ else 0,
21
  )
22
+ (
23
+ st.session_state["search_pipeline"],
24
+ st.session_state["index_pipeline"],
25
+ ) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
26
+
27
 
28
  def component_show_pipeline(container, pipeline):
29
  """Draw the pipeline"""
30
+ with st.expander("Show pipeline"):
31
  graphviz = to_agraph(pipeline.graph)
32
  graphviz.layout("dot")
33
  st.graphviz_chart(graphviz.string())
34
+
35
+
36
  def component_show_search_result(container, results):
37
  with container:
38
  for idx, document in enumerate(results):
39
  st.markdown(f"### Match {idx+1}")
40
  st.markdown(f"**Text**: {document['text']}")
41
  st.markdown(f"**Document**: {document['id']}")
42
+ if document["score"] is not None:
43
  st.markdown(f"**Score**: {document['score']:.3f}")
44
  st.markdown("---")
45
 
46
+
47
  def component_text_input(container):
48
  """Draw the Text Input widget"""
49
  with container:
 
59
  else:
60
  break
61
  corpus = [
62
+ {"text": doc["text"], "id": doc_id} for doc_id, doc in enumerate(texts)
 
63
  ]
64
+ return corpus
interface/pages.py CHANGED
@@ -1,7 +1,12 @@
1
  import streamlit as st
2
  from streamlit_option_menu import option_menu
3
  from core.search_index import index, search
4
- from interface.components import component_show_pipeline, component_show_search_result, component_text_input
 
 
 
 
 
5
 
6
  def page_landing_page(container):
7
  with container:
@@ -22,33 +27,34 @@ def page_landing_page(container):
22
  "\n - Include file/url indexing"
23
  "\n - [Optional] Include text to audio to read responses"
24
  )
25
-
 
26
  def page_search(container):
27
  with container:
28
  st.title("Query me!")
29
-
30
  ## SEARCH ##
31
  query = st.text_input("Query")
32
-
33
- component_show_pipeline(container, st.session_state['search_pipeline'])
34
-
35
  if st.button("Search"):
36
- st.session_state['search_results'] = search(
37
  queries=[query],
38
- pipeline=st.session_state['search_pipeline'],
39
  )
40
- if 'search_results' in st.session_state:
41
  component_show_search_result(
42
- container=container,
43
- results=st.session_state['search_results'][0]
44
  )
45
-
 
46
  def page_index(container):
47
  with container:
48
  st.title("Index time!")
49
-
50
- component_show_pipeline(container, st.session_state['index_pipeline'])
51
-
52
  input_funcs = {
53
  "Raw Text": (component_text_input, "card-text"),
54
  }
@@ -60,15 +66,15 @@ def page_index(container):
60
  default_index=0,
61
  orientation="horizontal",
62
  )
63
-
64
  corpus = input_funcs[selected_input][0](container)
65
-
66
  if len(corpus) > 0:
67
  index_results = None
68
  if st.button("Index"):
69
  index_results = index(
70
  corpus,
71
- st.session_state['index_pipeline'],
72
  )
73
  if index_results:
74
- st.write(index_results)
 
1
  import streamlit as st
2
  from streamlit_option_menu import option_menu
3
  from core.search_index import index, search
4
+ from interface.components import (
5
+ component_show_pipeline,
6
+ component_show_search_result,
7
+ component_text_input,
8
+ )
9
+
10
 
11
  def page_landing_page(container):
12
  with container:
 
27
  "\n - Include file/url indexing"
28
  "\n - [Optional] Include text to audio to read responses"
29
  )
30
+
31
+
32
  def page_search(container):
33
  with container:
34
  st.title("Query me!")
35
+
36
  ## SEARCH ##
37
  query = st.text_input("Query")
38
+
39
+ component_show_pipeline(container, st.session_state["search_pipeline"])
40
+
41
  if st.button("Search"):
42
+ st.session_state["search_results"] = search(
43
  queries=[query],
44
+ pipeline=st.session_state["search_pipeline"],
45
  )
46
+ if "search_results" in st.session_state:
47
  component_show_search_result(
48
+ container=container, results=st.session_state["search_results"][0]
 
49
  )
50
+
51
+
52
  def page_index(container):
53
  with container:
54
  st.title("Index time!")
55
+
56
+ component_show_pipeline(container, st.session_state["index_pipeline"])
57
+
58
  input_funcs = {
59
  "Raw Text": (component_text_input, "card-text"),
60
  }
 
66
  default_index=0,
67
  orientation="horizontal",
68
  )
69
+
70
  corpus = input_funcs[selected_input][0](container)
71
+
72
  if len(corpus) > 0:
73
  index_results = None
74
  if st.button("Index"):
75
  index_results = index(
76
  corpus,
77
+ st.session_state["index_pipeline"],
78
  )
79
  if index_results:
80
+ st.write(index_results)
linter.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ python -m black app.py interface core
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  streamlit
2
  streamlit_option_menu
3
  farm-haystack
4
- pygraphviz
 
 
1
  streamlit
2
  streamlit_option_menu
3
  farm-haystack
4
+ pygraphviz
5
+ black