Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		ugmSorcero
		
	commited on
		
		
					Commit 
							
							Β·
						
						39503cb
	
1
								Parent(s):
							
							8d3aacc
								
Adds linter and fixes linting
Browse files- app.py +1 -3
- core/pipelines.py +14 -4
- core/search_index.py +9 -5
- interface/components.py +23 -13
- interface/pages.py +25 -19
- linter.sh +1 -0
- requirements.txt +2 -1
    	
        app.py
    CHANGED
    
    | @@ -5,9 +5,7 @@ st.set_page_config( | |
| 5 | 
             
                page_icon="π",
         | 
| 6 | 
             
                layout="wide",
         | 
| 7 | 
             
                initial_sidebar_state="expanded",
         | 
| 8 | 
            -
                menu_items={
         | 
| 9 | 
            -
                    'About': "https://github.com/ugm2/neural-search-demo"
         | 
| 10 | 
            -
                }
         | 
| 11 | 
             
            )
         | 
| 12 |  | 
| 13 | 
             
            from streamlit_option_menu import option_menu
         | 
|  | |
| 5 | 
             
                page_icon="π",
         | 
| 6 | 
             
                layout="wide",
         | 
| 7 | 
             
                initial_sidebar_state="expanded",
         | 
| 8 | 
            +
                menu_items={"About": "https://github.com/ugm2/neural-search-demo"},
         | 
|  | |
|  | |
| 9 | 
             
            )
         | 
| 10 |  | 
| 11 | 
             
            from streamlit_option_menu import option_menu
         | 
    	
        core/pipelines.py
    CHANGED
    
    | @@ -9,9 +9,10 @@ from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever | |
| 9 | 
             
            from haystack.nodes.preprocessor import PreProcessor
         | 
| 10 | 
             
            import streamlit as st
         | 
| 11 |  | 
|  | |
| 12 | 
             
            @st.cache(allow_output_mutation=True)
         | 
| 13 | 
             
            def keyword_search(
         | 
| 14 | 
            -
                index= | 
| 15 | 
             
            ):
         | 
| 16 | 
             
                document_store = InMemoryDocumentStore(index=index)
         | 
| 17 | 
             
                keyword_retriever = TfidfRetriever(document_store=(document_store))
         | 
| @@ -31,16 +32,25 @@ def keyword_search( | |
| 31 | 
             
                # INDEXING PIPELINE
         | 
| 32 | 
             
                index_pipeline = Pipeline()
         | 
| 33 | 
             
                index_pipeline.add_node(processor, name="Preprocessor", inputs=["File"])
         | 
| 34 | 
            -
                index_pipeline.add_node( | 
|  | |
|  | |
| 35 | 
             
                index_pipeline.add_node(
         | 
| 36 | 
             
                    document_store, name="DocumentStore", inputs=["TfidfRetriever"]
         | 
| 37 | 
             
                )
         | 
| 38 |  | 
| 39 | 
             
                return search_pipeline, index_pipeline
         | 
| 40 |  | 
| 41 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 42 | 
             
            def dense_passage_retrieval(
         | 
| 43 | 
            -
                index= | 
| 44 | 
             
                query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
         | 
| 45 | 
             
                passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
         | 
| 46 | 
             
            ):
         | 
|  | |
| 9 | 
             
            from haystack.nodes.preprocessor import PreProcessor
         | 
| 10 | 
             
            import streamlit as st
         | 
| 11 |  | 
| 12 | 
            +
             | 
| 13 | 
             
            @st.cache(allow_output_mutation=True)
         | 
| 14 | 
             
            def keyword_search(
         | 
| 15 | 
            +
                index="documents",
         | 
| 16 | 
             
            ):
         | 
| 17 | 
             
                document_store = InMemoryDocumentStore(index=index)
         | 
| 18 | 
             
                keyword_retriever = TfidfRetriever(document_store=(document_store))
         | 
|  | |
| 32 | 
             
                # INDEXING PIPELINE
         | 
| 33 | 
             
                index_pipeline = Pipeline()
         | 
| 34 | 
             
                index_pipeline.add_node(processor, name="Preprocessor", inputs=["File"])
         | 
| 35 | 
            +
                index_pipeline.add_node(
         | 
| 36 | 
            +
                    keyword_retriever, name="TfidfRetriever", inputs=["Preprocessor"]
         | 
| 37 | 
            +
                )
         | 
| 38 | 
             
                index_pipeline.add_node(
         | 
| 39 | 
             
                    document_store, name="DocumentStore", inputs=["TfidfRetriever"]
         | 
| 40 | 
             
                )
         | 
| 41 |  | 
| 42 | 
             
                return search_pipeline, index_pipeline
         | 
| 43 |  | 
| 44 | 
            +
             | 
| 45 | 
            +
            @st.cache(
         | 
| 46 | 
            +
                hash_funcs={
         | 
| 47 | 
            +
                    tokenizers.Tokenizer: lambda _: None,
         | 
| 48 | 
            +
                    tokenizers.AddedToken: lambda _: None,
         | 
| 49 | 
            +
                },
         | 
| 50 | 
            +
                allow_output_mutation=True,
         | 
| 51 | 
            +
            )
         | 
| 52 | 
             
            def dense_passage_retrieval(
         | 
| 53 | 
            +
                index="documents",
         | 
| 54 | 
             
                query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
         | 
| 55 | 
             
                passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
         | 
| 56 | 
             
            ):
         | 
    	
        core/search_index.py
    CHANGED
    
    | @@ -6,9 +6,9 @@ def format_docs(documents): | |
| 6 | 
             
                """Given a list of documents, format the documents and return the documents and doc ids."""
         | 
| 7 | 
             
                db_docs: list = []
         | 
| 8 | 
             
                for doc in documents:
         | 
| 9 | 
            -
                    doc_id = doc[ | 
| 10 | 
             
                    db_doc = {
         | 
| 11 | 
            -
                        "content": doc[ | 
| 12 | 
             
                        "content_type": "text",
         | 
| 13 | 
             
                        "id": str(uuid.uuid4()),
         | 
| 14 | 
             
                        "meta": {"id": doc_id},
         | 
| @@ -16,11 +16,13 @@ def format_docs(documents): | |
| 16 | 
             
                    db_docs.append(Document(**db_doc))
         | 
| 17 | 
             
                return db_docs, [doc.meta["id"] for doc in db_docs]
         | 
| 18 |  | 
|  | |
| 19 | 
             
            def index(documents, pipeline):
         | 
| 20 | 
             
                documents, doc_ids = format_docs(documents)
         | 
| 21 | 
             
                pipeline.run(documents=documents)
         | 
| 22 | 
             
                return doc_ids
         | 
| 23 |  | 
|  | |
| 24 | 
             
            def search(queries, pipeline):
         | 
| 25 | 
             
                results = []
         | 
| 26 | 
             
                matches_queries = pipeline.run_batch(queries=queries)
         | 
| @@ -35,10 +37,12 @@ def search(queries, pipeline): | |
| 35 | 
             
                                "text": res.content,
         | 
| 36 | 
             
                                "score": res.score,
         | 
| 37 | 
             
                                "id": res.meta["id"],
         | 
| 38 | 
            -
                                "fragment_id": res.id
         | 
| 39 | 
             
                            }
         | 
| 40 | 
             
                        )
         | 
| 41 | 
             
                    if not score_is_empty:
         | 
| 42 | 
            -
                        query_results = sorted( | 
|  | |
|  | |
| 43 | 
             
                    results.append(query_results)
         | 
| 44 | 
            -
                return results
         | 
|  | |
| 6 | 
             
                """Given a list of documents, format the documents and return the documents and doc ids."""
         | 
| 7 | 
             
                db_docs: list = []
         | 
| 8 | 
             
                for doc in documents:
         | 
| 9 | 
            +
                    doc_id = doc["id"] if doc["id"] is not None else str(uuid.uuid4())
         | 
| 10 | 
             
                    db_doc = {
         | 
| 11 | 
            +
                        "content": doc["text"],
         | 
| 12 | 
             
                        "content_type": "text",
         | 
| 13 | 
             
                        "id": str(uuid.uuid4()),
         | 
| 14 | 
             
                        "meta": {"id": doc_id},
         | 
|  | |
| 16 | 
             
                    db_docs.append(Document(**db_doc))
         | 
| 17 | 
             
                return db_docs, [doc.meta["id"] for doc in db_docs]
         | 
| 18 |  | 
| 19 | 
            +
             | 
| 20 | 
             
            def index(documents, pipeline):
         | 
| 21 | 
             
                documents, doc_ids = format_docs(documents)
         | 
| 22 | 
             
                pipeline.run(documents=documents)
         | 
| 23 | 
             
                return doc_ids
         | 
| 24 |  | 
| 25 | 
            +
             | 
| 26 | 
             
            def search(queries, pipeline):
         | 
| 27 | 
             
                results = []
         | 
| 28 | 
             
                matches_queries = pipeline.run_batch(queries=queries)
         | 
|  | |
| 37 | 
             
                                "text": res.content,
         | 
| 38 | 
             
                                "score": res.score,
         | 
| 39 | 
             
                                "id": res.meta["id"],
         | 
| 40 | 
            +
                                "fragment_id": res.id,
         | 
| 41 | 
             
                            }
         | 
| 42 | 
             
                        )
         | 
| 43 | 
             
                    if not score_is_empty:
         | 
| 44 | 
            +
                        query_results = sorted(
         | 
| 45 | 
            +
                            query_results, key=lambda x: x["score"], reverse=True
         | 
| 46 | 
            +
                        )
         | 
| 47 | 
             
                    results.append(query_results)
         | 
| 48 | 
            +
                return results
         | 
    	
        interface/components.py
    CHANGED
    
    | @@ -3,36 +3,47 @@ import core.pipelines as pipelines_functions | |
| 3 | 
             
            from inspect import getmembers, isfunction
         | 
| 4 | 
             
            from networkx.drawing.nx_agraph import to_agraph
         | 
| 5 |  | 
|  | |
| 6 | 
             
            def component_select_pipeline(container):
         | 
| 7 | 
            -
                pipeline_names, pipeline_funcs = list( | 
| 8 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
| 9 | 
             
                with container:
         | 
| 10 | 
             
                    selected_pipeline = st.selectbox(
         | 
| 11 | 
            -
                         | 
| 12 | 
             
                        pipeline_names,
         | 
| 13 | 
            -
                        index=pipeline_names.index( | 
|  | |
|  | |
| 14 | 
             
                    )
         | 
| 15 | 
            -
                     | 
| 16 | 
            -
                        st.session_state[ | 
| 17 | 
            -
             | 
|  | |
|  | |
| 18 |  | 
| 19 | 
             
            def component_show_pipeline(container, pipeline):
         | 
| 20 | 
             
                """Draw the pipeline"""
         | 
| 21 | 
            -
                with st.expander( | 
| 22 | 
             
                    graphviz = to_agraph(pipeline.graph)
         | 
| 23 | 
             
                    graphviz.layout("dot")
         | 
| 24 | 
             
                    st.graphviz_chart(graphviz.string())
         | 
| 25 | 
            -
             | 
|  | |
| 26 | 
             
            def component_show_search_result(container, results):
         | 
| 27 | 
             
                with container:
         | 
| 28 | 
             
                    for idx, document in enumerate(results):
         | 
| 29 | 
             
                        st.markdown(f"### Match {idx+1}")
         | 
| 30 | 
             
                        st.markdown(f"**Text**: {document['text']}")
         | 
| 31 | 
             
                        st.markdown(f"**Document**: {document['id']}")
         | 
| 32 | 
            -
                        if document[ | 
| 33 | 
             
                            st.markdown(f"**Score**: {document['score']:.3f}")
         | 
| 34 | 
             
                        st.markdown("---")
         | 
| 35 |  | 
|  | |
| 36 | 
             
            def component_text_input(container):
         | 
| 37 | 
             
                """Draw the Text Input widget"""
         | 
| 38 | 
             
                with container:
         | 
| @@ -48,7 +59,6 @@ def component_text_input(container): | |
| 48 | 
             
                            else:
         | 
| 49 | 
             
                                break
         | 
| 50 | 
             
                    corpus = [
         | 
| 51 | 
            -
                        {"text": doc["text"], "id": doc_id}
         | 
| 52 | 
            -
                        for doc_id, doc in enumerate(texts)
         | 
| 53 | 
             
                    ]
         | 
| 54 | 
            -
                    return corpus
         | 
|  | |
| 3 | 
             
            from inspect import getmembers, isfunction
         | 
| 4 | 
             
            from networkx.drawing.nx_agraph import to_agraph
         | 
| 5 |  | 
| 6 | 
            +
             | 
| 7 | 
             
            def component_select_pipeline(container):
         | 
| 8 | 
            +
                pipeline_names, pipeline_funcs = list(
         | 
| 9 | 
            +
                    zip(*getmembers(pipelines_functions, isfunction))
         | 
| 10 | 
            +
                )
         | 
| 11 | 
            +
                pipeline_names = [
         | 
| 12 | 
            +
                    " ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
         | 
| 13 | 
            +
                ]
         | 
| 14 | 
             
                with container:
         | 
| 15 | 
             
                    selected_pipeline = st.selectbox(
         | 
| 16 | 
            +
                        "Select pipeline",
         | 
| 17 | 
             
                        pipeline_names,
         | 
| 18 | 
            +
                        index=pipeline_names.index("Keyword Search")
         | 
| 19 | 
            +
                        if "Keyword Search" in pipeline_names
         | 
| 20 | 
            +
                        else 0,
         | 
| 21 | 
             
                    )
         | 
| 22 | 
            +
                    (
         | 
| 23 | 
            +
                        st.session_state["search_pipeline"],
         | 
| 24 | 
            +
                        st.session_state["index_pipeline"],
         | 
| 25 | 
            +
                    ) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
         | 
| 26 | 
            +
             | 
| 27 |  | 
| 28 | 
             
            def component_show_pipeline(container, pipeline):
         | 
| 29 | 
             
                """Draw the pipeline"""
         | 
| 30 | 
            +
                with st.expander("Show pipeline"):
         | 
| 31 | 
             
                    graphviz = to_agraph(pipeline.graph)
         | 
| 32 | 
             
                    graphviz.layout("dot")
         | 
| 33 | 
             
                    st.graphviz_chart(graphviz.string())
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
             
            def component_show_search_result(container, results):
         | 
| 37 | 
             
                with container:
         | 
| 38 | 
             
                    for idx, document in enumerate(results):
         | 
| 39 | 
             
                        st.markdown(f"### Match {idx+1}")
         | 
| 40 | 
             
                        st.markdown(f"**Text**: {document['text']}")
         | 
| 41 | 
             
                        st.markdown(f"**Document**: {document['id']}")
         | 
| 42 | 
            +
                        if document["score"] is not None:
         | 
| 43 | 
             
                            st.markdown(f"**Score**: {document['score']:.3f}")
         | 
| 44 | 
             
                        st.markdown("---")
         | 
| 45 |  | 
| 46 | 
            +
             | 
| 47 | 
             
            def component_text_input(container):
         | 
| 48 | 
             
                """Draw the Text Input widget"""
         | 
| 49 | 
             
                with container:
         | 
|  | |
| 59 | 
             
                            else:
         | 
| 60 | 
             
                                break
         | 
| 61 | 
             
                    corpus = [
         | 
| 62 | 
            +
                        {"text": doc["text"], "id": doc_id} for doc_id, doc in enumerate(texts)
         | 
|  | |
| 63 | 
             
                    ]
         | 
| 64 | 
            +
                    return corpus
         | 
    	
        interface/pages.py
    CHANGED
    
    | @@ -1,7 +1,12 @@ | |
| 1 | 
             
            import streamlit as st
         | 
| 2 | 
             
            from streamlit_option_menu import option_menu
         | 
| 3 | 
             
            from core.search_index import index, search
         | 
| 4 | 
            -
            from interface.components import  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 5 |  | 
| 6 | 
             
            def page_landing_page(container):
         | 
| 7 | 
             
                with container:
         | 
| @@ -22,33 +27,34 @@ def page_landing_page(container): | |
| 22 | 
             
                        "\n  - Include file/url indexing"
         | 
| 23 | 
             
                        "\n  - [Optional] Include text to audio to read responses"
         | 
| 24 | 
             
                    )
         | 
| 25 | 
            -
             | 
|  | |
| 26 | 
             
            def page_search(container):
         | 
| 27 | 
             
                with container:
         | 
| 28 | 
             
                    st.title("Query me!")
         | 
| 29 | 
            -
             | 
| 30 | 
             
                    ## SEARCH ##
         | 
| 31 | 
             
                    query = st.text_input("Query")
         | 
| 32 | 
            -
             | 
| 33 | 
            -
                    component_show_pipeline(container, st.session_state[ | 
| 34 | 
            -
             | 
| 35 | 
             
                    if st.button("Search"):
         | 
| 36 | 
            -
                        st.session_state[ | 
| 37 | 
             
                            queries=[query],
         | 
| 38 | 
            -
                            pipeline=st.session_state[ | 
| 39 | 
             
                        )
         | 
| 40 | 
            -
                    if  | 
| 41 | 
             
                        component_show_search_result(
         | 
| 42 | 
            -
                            container=container,
         | 
| 43 | 
            -
                            results=st.session_state['search_results'][0]
         | 
| 44 | 
             
                        )
         | 
| 45 | 
            -
             | 
|  | |
| 46 | 
             
            def page_index(container):
         | 
| 47 | 
             
                with container:
         | 
| 48 | 
             
                    st.title("Index time!")
         | 
| 49 | 
            -
             | 
| 50 | 
            -
                    component_show_pipeline(container, st.session_state[ | 
| 51 | 
            -
             | 
| 52 | 
             
                    input_funcs = {
         | 
| 53 | 
             
                        "Raw Text": (component_text_input, "card-text"),
         | 
| 54 | 
             
                    }
         | 
| @@ -60,15 +66,15 @@ def page_index(container): | |
| 60 | 
             
                        default_index=0,
         | 
| 61 | 
             
                        orientation="horizontal",
         | 
| 62 | 
             
                    )
         | 
| 63 | 
            -
             | 
| 64 | 
             
                    corpus = input_funcs[selected_input][0](container)
         | 
| 65 | 
            -
             | 
| 66 | 
             
                    if len(corpus) > 0:
         | 
| 67 | 
             
                        index_results = None
         | 
| 68 | 
             
                        if st.button("Index"):
         | 
| 69 | 
             
                            index_results = index(
         | 
| 70 | 
             
                                corpus,
         | 
| 71 | 
            -
                                st.session_state[ | 
| 72 | 
             
                            )
         | 
| 73 | 
             
                        if index_results:
         | 
| 74 | 
            -
                            st.write(index_results)
         | 
|  | |
| 1 | 
             
            import streamlit as st
         | 
| 2 | 
             
            from streamlit_option_menu import option_menu
         | 
| 3 | 
             
            from core.search_index import index, search
         | 
| 4 | 
            +
            from interface.components import (
         | 
| 5 | 
            +
                component_show_pipeline,
         | 
| 6 | 
            +
                component_show_search_result,
         | 
| 7 | 
            +
                component_text_input,
         | 
| 8 | 
            +
            )
         | 
| 9 | 
            +
             | 
| 10 |  | 
| 11 | 
             
            def page_landing_page(container):
         | 
| 12 | 
             
                with container:
         | 
|  | |
| 27 | 
             
                        "\n  - Include file/url indexing"
         | 
| 28 | 
             
                        "\n  - [Optional] Include text to audio to read responses"
         | 
| 29 | 
             
                    )
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
             
            def page_search(container):
         | 
| 33 | 
             
                with container:
         | 
| 34 | 
             
                    st.title("Query me!")
         | 
| 35 | 
            +
             | 
| 36 | 
             
                    ## SEARCH ##
         | 
| 37 | 
             
                    query = st.text_input("Query")
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    component_show_pipeline(container, st.session_state["search_pipeline"])
         | 
| 40 | 
            +
             | 
| 41 | 
             
                    if st.button("Search"):
         | 
| 42 | 
            +
                        st.session_state["search_results"] = search(
         | 
| 43 | 
             
                            queries=[query],
         | 
| 44 | 
            +
                            pipeline=st.session_state["search_pipeline"],
         | 
| 45 | 
             
                        )
         | 
| 46 | 
            +
                    if "search_results" in st.session_state:
         | 
| 47 | 
             
                        component_show_search_result(
         | 
| 48 | 
            +
                            container=container, results=st.session_state["search_results"][0]
         | 
|  | |
| 49 | 
             
                        )
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
             
            def page_index(container):
         | 
| 53 | 
             
                with container:
         | 
| 54 | 
             
                    st.title("Index time!")
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                    component_show_pipeline(container, st.session_state["index_pipeline"])
         | 
| 57 | 
            +
             | 
| 58 | 
             
                    input_funcs = {
         | 
| 59 | 
             
                        "Raw Text": (component_text_input, "card-text"),
         | 
| 60 | 
             
                    }
         | 
|  | |
| 66 | 
             
                        default_index=0,
         | 
| 67 | 
             
                        orientation="horizontal",
         | 
| 68 | 
             
                    )
         | 
| 69 | 
            +
             | 
| 70 | 
             
                    corpus = input_funcs[selected_input][0](container)
         | 
| 71 | 
            +
             | 
| 72 | 
             
                    if len(corpus) > 0:
         | 
| 73 | 
             
                        index_results = None
         | 
| 74 | 
             
                        if st.button("Index"):
         | 
| 75 | 
             
                            index_results = index(
         | 
| 76 | 
             
                                corpus,
         | 
| 77 | 
            +
                                st.session_state["index_pipeline"],
         | 
| 78 | 
             
                            )
         | 
| 79 | 
             
                        if index_results:
         | 
| 80 | 
            +
                            st.write(index_results)
         | 
    	
        linter.sh
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            python -m black app.py interface core
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -1,4 +1,5 @@ | |
| 1 | 
             
            streamlit
         | 
| 2 | 
             
            streamlit_option_menu
         | 
| 3 | 
             
            farm-haystack
         | 
| 4 | 
            -
            pygraphviz
         | 
|  | 
|  | |
| 1 | 
             
            streamlit
         | 
| 2 | 
             
            streamlit_option_menu
         | 
| 3 | 
             
            farm-haystack
         | 
| 4 | 
            +
            pygraphviz
         | 
| 5 | 
            +
            black
         | 
