|
import html
from collections import Counter

import pandas as pd
import plotly.graph_objs as go
|
|
|
|
|
def visualize(cluster_data, *, color="brown", title="Categories"):
    """Build a bar chart of how many papers fall into each category.

    Args:
        cluster_data: Iterable of paper records. Each record is read via
            ``paper["categories"]`` (an iterable of hashable category labels)
            and ``paper["paper"]`` (the title listed in the hover text).
        color: Bar fill color (keyword-only; default ``"brown"``).
        title: Figure title (keyword-only; default ``"Categories"``).

    Returns:
        A ``plotly.graph_objs.Figure`` with one bar per category, ordered by
        count descending (ties keep first-seen order). Hovering a bar lists
        the titles of the papers tagged with that category.
    """
    category_counts = Counter()
    titles_by_category = {}

    # Single pass over the papers: tally categories and collect hover titles
    # together, so cluster_data may be a one-shot iterator and no per-category
    # rescan of all papers is needed.
    for paper in cluster_data:
        categories = list(paper["categories"])

        # Every occurrence counts toward the bar height...
        category_counts.update(categories)

        # Titles are escaped because plotly interprets hover text as
        # pseudo-HTML; a raw "<" or "&" in a title would corrupt the markup.
        # NOTE(review): every title links to the same placeholder URL
        # ("https://plot.ly/"); confirm whether a per-paper link was intended.
        title_html = (
            '<a href="https://plot.ly/">'
            f'{html.escape(str(paper["paper"]), quote=False)}</a>'
        )

        # ...but a paper is listed at most once per category in the hover.
        for category in dict.fromkeys(categories):
            titles_by_category.setdefault(category, []).append(title_html)

    category_df = pd.DataFrame(
        {
            "Category": list(category_counts.keys()),
            "Count": list(category_counts.values()),
        }
    )

    # Stable sort: categories with equal counts keep their first-seen order,
    # making the bar order deterministic.
    category_df = category_df.sort_values(by="Count", ascending=False, kind="stable")

    # Hover entries must follow the sorted bar order.
    hover_text = [
        "<br>Papers:<br>" + "<br>".join(titles_by_category.get(category, []))
        for category in category_df["Category"]
    ]

    fig = go.Figure(
        data=[
            go.Bar(
                x=category_df["Category"],
                y=category_df["Count"],
                hovertext=hover_text,
                marker=dict(color=color),
            )
        ]
    )

    fig.update_layout(title=title, xaxis_title="Category", yaxis_title="Count")

    return fig
|
|