|
import sys |
|
|
|
import gradio as gr |
|
|
|
sys.path.append(".") |
|
sys.path.append("..") |
|
sys.path.append("../..") |
|
|
|
from cluster import cluster |
|
from extract import extract_endpoint |
|
from generate_answers import generate_relevant_chunks |
|
|
|
# Upper bound on how many category buttons the UI pre-allocates; the
# category-button row and generate_category_btn() are both sized by it.
MAX_CATEGORIES = 10

# Example queries offered as one-click buttons in step 2 of the app;
# queries[1] is also the query textbox's initial value.
queries = [
    # Battery form factor and energy / capacity.
    "What is the size, shape, and energy (watt hour) or capacity (Amp hour) of battery discussed in the paper?",
    # Mechanical testing.
    "What specific mechanical testing methods were used to quantify strength?",
    # Benefit metrics.
    "What parameters they used to quantify the benefit of their individual design (mass saving, increased run time, etc.)?",
    # Cell chemistry.
    "What material chemistry combination (on the anode, cathode, separator, and electrolyte) was used in these papers?",
    # Target application.
    "What kind of end use application they targeted?",
]
|
|
|
|
|
def change_button(text):
    """Return a Button update that is interactive only when *text* is non-empty.

    Used as a ``.change`` handler: the upstream component's value decides
    whether the next step's button can be clicked.

    Args:
        text: the upstream component's value (uploaded files, or a JSON
            payload). ``None`` is accepted and treated as empty; Gradio may
            pass it when the upstream component is cleared.

    Returns:
        gr.Button: an update that enables or disables the target button.
    """
    # Truthiness instead of ``len(text) > 0``: a cleared component (None)
    # now disables the button instead of raising TypeError, while strings,
    # lists and dicts behave exactly as before.
    return gr.Button(interactive=bool(text))
|
|
|
|
|
def generate_category_btn(cluster_output):
    """Build visibility updates for the pre-allocated category buttons.

    Collects the distinct categories across all clustered items, shows one
    button per category (labelled with it) and hides the remaining buttons.

    Args:
        cluster_output: iterable of items, each a mapping whose
            ``"categories"`` entry is an iterable of hashable category names
            (the clustering step's JSON output). ``None``, e.g. from a
            cleared JSON component, is treated as "no categories".

    Returns:
        list: exactly ``MAX_CATEGORIES`` ``gr.Button`` updates. Visible,
        labelled buttons come first in first-seen category order, followed
        by hidden, empty ones.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # buttons appear in a stable order. Iterating a set of strings does not
    # give a stable order across runs because of hash randomization.
    unique_categories = dict.fromkeys(
        category
        for item in (cluster_output or [])
        for category in item["categories"]
    )

    # One update per output button: never return more than MAX_CATEGORIES,
    # even if clustering produced more distinct categories.
    shown = list(unique_categories)[:MAX_CATEGORIES]

    update_show = [gr.Button(visible=True, value=name) for name in shown]
    update_hide = [
        gr.Button(visible=False, value="")
        for _ in range(MAX_CATEGORIES - len(shown))
    ]
    return update_show + update_hide
|
|
|
|
|
def get_query(btn):
    """Pass an example-query button's value through to the query textbox.

    Args:
        btn: the clicked button's value as delivered by Gradio
            (presumably its label, i.e. the example query text).

    Returns:
        The same value, unchanged.
    """
    selected_query = btn
    return selected_query
|
|
|
|
|
# Placeholder category buttons, created hidden while the layout is built.
# generate_category_btn() returns one update per entry of this list.
btn_list = []

with gr.Blocks() as app:
    gr.Markdown(
        """
        # Paper Query Clustering + Visualization
        This app extracts text from papers and then searches for relevant excerpts based on a query. It then clusters and visualizes the relevant excerpts to find common themes across the papers.

        ### Input
        1. A group of research papers that you want to run the query on.
        1. Query that you would like to know about these papers.

        ### Output
        Clustering and visualization of the relevant excerpts which answer the query across the papers.

        # 1. Upload + Extract
        First, upload the papers you want to analyze. Currently, we only support PDFs. Once they're uploaded, you can extract the text data from the papers.
        """
    )
    file_upload = gr.Files()
    extract_btn = gr.Button("Extract", interactive=False)
    with gr.Tab(label="Table"):
        extract_df = gr.Dataframe(
            datatype="markdown", column_widths=[100, 400], wrap=True
        )
    with gr.Tab(label="JSON"):
        extract_output = gr.JSON(label="Extract Output")

    gr.Markdown(
        """
        ----------------
        # 2. Extract Relevant Excerpts
        Enter a query about these papers. This will search the papers to find the most relevant excerpts.
        """
    )

    gr.Markdown(
        """
        ### Input
        """
    )
    query = gr.Textbox(
        label="Query", value=queries[1], lines=3, placeholder="Enter a query"
    )
    gr.Markdown(
        """
        You can also select some example queries below.
        """
    )
    with gr.Row():
        # One button per example query; each button's label is the query text.
        example_query_btns = [gr.Button(example) for example in queries]
    gr.Markdown(
        """
        ----
        """
    )
    relevant_btn = gr.Button("Extract Excerpts", interactive=False)
    gr.Markdown(
        """
        ### Output
        """
    )
    with gr.Tab(label="Output Table"):
        relevant_df = gr.Dataframe(
            datatype="markdown", column_widths=[100, 100, 300], wrap=True
        )
    with gr.Tab(label="JSON"):
        relevant_output = gr.JSON(label="Relevant Chunks Output")

    gr.Markdown(
        """
        ----------------
        # 3. Cluster & Visualize
        Cluster the relevant excerpts to find common themes and visualize the results.
        """
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                ### Input
                """
            )
            cluster_btn = gr.Button("Cluster", interactive=False)
            # Hidden: holds the raw clustering result that drives the
            # category buttons below.
            cluster_output = gr.JSON(label="Cluster Output", visible=False)

    gr.Markdown(
        """
        ### Visualization
        """
    )
    visualize_output = gr.Plot()
    with gr.Row():
        # Created while the Row context is active, so they render inside it.
        btn_list.extend(gr.Button(visible=False) for _ in range(MAX_CATEGORIES))
    with gr.Tab(label="By Paper"):
        cluster_df = gr.Dataframe(
            datatype="markdown", column_widths=[100, 100, 300], wrap=True
        )

    with gr.Tab(label="By Excerpt"):
        cluster_granular_df = gr.Dataframe(
            datatype="markdown", column_widths=[100, 100, 300], wrap=True
        )

    # Step 1: enable "Extract" once files are present, then run extraction.
    file_upload.change(fn=change_button, inputs=[file_upload], outputs=[extract_btn])

    extract_btn.click(
        fn=extract_endpoint,
        inputs=[file_upload],
        outputs=[extract_output, extract_df],
    )

    # Step 2: enable "Extract Excerpts" once extraction has produced output.
    extract_output.change(
        fn=change_button,
        inputs=[extract_output],
        outputs=[relevant_btn],
    )

    # Clicking an example button copies its value into the query textbox.
    # Registered in the same order as the buttons were created.
    for example_btn in example_query_btns:
        example_btn.click(
            fn=get_query,
            inputs=[example_btn],
            outputs=[query],
        )

    relevant_btn.click(
        fn=generate_relevant_chunks,
        inputs=[query, extract_output],
        outputs=[relevant_output, relevant_df],
        api_name="relevant_chunks",
    )

    # Step 3: enable "Cluster" once relevant excerpts exist, then cluster.
    relevant_output.change(
        fn=change_button, inputs=[relevant_output], outputs=[cluster_btn]
    )

    cluster_btn.click(
        fn=cluster,
        inputs=[query, relevant_output],
        outputs=[cluster_output, cluster_df, visualize_output, cluster_granular_df],
        api_name="cluster",
    )

    # Refresh the category buttons from the new clustering result.
    cluster_output.change(
        fn=generate_category_btn,
        inputs=[cluster_output],
        outputs=btn_list,
    )

if __name__ == "__main__":
    app.launch()
|
|