import pandas as pd
import plotly.graph_objs as go


def visualize(cluster_data):
    """Build a bar chart of how many papers fall into each category.

    Args:
        cluster_data: Iterable of paper records. Each record is indexed with
            ``"categories"`` (an iterable of hashable category labels) and,
            for records that have at least one category, ``"paper"`` (the
            value shown as the paper's title in the hover text).

    Returns:
        A ``plotly.graph_objs.Figure`` holding one brown bar per category,
        ordered by paper count in descending order, whose hover text lists
        the titles of the papers in that category.
    """
    # Count category occurrences and group paper titles in a single pass, so
    # the input is traversed once (and one-shot iterators work too).
    category_counts = {}
    titles_by_category = {}
    for paper in cluster_data:
        categories = paper["categories"]
        for category in categories:
            category_counts[category] = category_counts.get(category, 0) + 1
        # A paper is listed at most once per category, even if the category
        # label is repeated in its "categories" entry.
        for category in dict.fromkeys(categories):
            titles_by_category.setdefault(category, []).append(
                str(paper["paper"])
            )

    category_df = pd.DataFrame(
        {
            "Category": list(category_counts.keys()),
            "Count": list(category_counts.values()),
        }
    )

    # Largest categories first.
    category_df = category_df.sort_values(by="Count", ascending=False)

    # Hover text must follow the sorted bar order. Plotly hover labels break
    # lines on "<br>", not on newline characters.
    hover_text = []
    for category in category_df["Category"]:
        paper_lines = "<br>".join(titles_by_category[category])
        hover_text.append(f"<br>Papers:<br>{paper_lines}")

    fig = go.Figure(
        data=[
            go.Bar(
                x=category_df["Category"],
                y=category_df["Count"],
                hovertext=hover_text,
                marker=dict(color="brown"),
            )
        ]
    )
    fig.update_layout(
        title="Categories", xaxis_title="Category", yaxis_title="Count"
    )
    return fig