import pandas as pd
import plotly.graph_objs as go
def visualize(cluster_data) -> "go.Figure":
    """Build a bar chart of how often each category occurs across papers.

    Args:
        cluster_data: Iterable of paper records. Each record is indexed with
            ``"categories"`` (an iterable of hashable category labels) and
            ``"paper"`` (the value listed for that paper in the hover text).
            The iterable is consumed in a single pass, so a one-shot
            iterator such as a generator is accepted.

    Returns:
        A Plotly ``Figure`` with one brown bar trace: categories on the
        x-axis ordered by descending count, counts on the y-axis, and a
        per-bar hover text listing the papers that carry that category.
    """
    # One pass gathers both the counts and the per-category paper lists, so
    # the input is not re-scanned once per category afterwards.
    category_counts = {}
    category_titles = {}
    for paper in cluster_data:
        listed = set()  # categories this paper has already been listed under
        for category in paper["categories"]:
            # Every occurrence counts, including a label repeated within
            # the same paper.
            category_counts[category] = category_counts.get(category, 0) + 1
            # ...but each paper appears at most once per category in the
            # hover text.
            if category not in listed:
                listed.add(category)
                category_titles.setdefault(category, []).append(
                    f'{paper["paper"]}'
                )

    category_df = pd.DataFrame(
        {
            "Category": list(category_counts.keys()),
            "Count": list(category_counts.values()),
        }
    )
    # Most frequent categories first.
    category_df = category_df.sort_values(by="Count", ascending=False)

    # Hover text lists the papers of each category, in bar order. Plotly
    # renders "<br>" as a line break inside hover labels; a raw newline
    # would not split the lines.
    hover_text = [
        "<br>Papers:<br>" + "<br>".join(category_titles.get(category, []))
        for category in category_df["Category"]
    ]

    fig = go.Figure(
        data=[
            go.Bar(
                x=category_df["Category"],
                y=category_df["Count"],
                hovertext=hover_text,
                marker=dict(color="brown"),
            )
        ]
    )
    fig.update_layout(
        title="Categories", xaxis_title="Category", yaxis_title="Count"
    )
    return fig