from typing import Dict
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import hypernetx as hnx
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from io import BytesIO
import time

from utils.data_processor import load_data, process_data, build_hyperedges
from utils.visualizer import visualize_gmm, visualize_ratings
from utils.streamlit_hypergraph import hypergraph_visualization_component


def main():
    """Render the Streamlit page for the Gaussian-mixture clustering demo.

    The page consists of:

    * a play/pause button and an iteration slider in the main area,
    * a sidebar with playback speed, figure size, sample count, the paper
      attribute used for display, and the top-k / top-p cluster selection,
    * an optional hypergraph view and an optional GMM plot,
    * a table with details of the sampled papers.

    While playback is active, each script run advances the iteration counter
    stored in ``st.session_state`` by one, sleeps for ``1 / speed`` seconds,
    and triggers a rerun until ``slider_max`` is reached.
    """
    st.title("NeurIPS 2024 Bench Paper 高斯混合聚类分析")

    # Bounds of the iteration slider. The lower bound equals the value the
    # counter is initialised/reset to, so the slider default never falls
    # outside the declared range.
    slider_min = 0
    slider_max = 10

    # Playback state survives reruns through the session state.
    if 'play_state' not in st.session_state:
        st.session_state.play_state = False
    if 'iteration' not in st.session_state:
        st.session_state.iteration = slider_min

    def toggle_play():
        """Toggle play/pause; restart from the first step if at the end."""
        finished = st.session_state.iteration == slider_max
        if not st.session_state.play_state and finished:
            st.session_state.iteration = slider_min
        st.session_state.play_state = not st.session_state.play_state

    # Play / pause button; the label reflects the current playback state.
    button_label = "暂停" if st.session_state.play_state else "开始拟合"
    st.button(button_label, on_click=toggle_play, key="play_button")

    # Iteration slider; its default follows the session-state counter.
    iteration = st.slider(
        "迭代步骤",
        min_value=slider_min,
        max_value=slider_max,
        value=st.session_state.iteration,
        step=1,
        key="iteration_slider",
    )

    # The data is loaded first so the sample-count slider can be capped by
    # the number of available rows.
    df = load_data()

    # Sidebar: all tunable parameters.
    with st.sidebar:
        st.header("控制面板")
        speed = st.slider(
            "拟合速度",
            min_value=0.1,
            max_value=2.0,
            value=1.0,
            step=0.1,
            key="speed_slider",
        )
        draw_width = st.slider(
            "绘图宽度", min_value=3, max_value=20, value=6, step=1,
            key="draw_width",
        )
        draw_height = st.slider(
            "绘图高度", min_value=3, max_value=20, value=6, step=1,
            key="draw_height",
        )

        max_samples = len(df)
        num_samples = st.slider(
            "选择采样论文数量",
            min_value=1,
            max_value=min(100, max_samples),
            value=min(10, max_samples),
            step=1,
        )

        # Which paper attribute is used when displaying papers.
        display_attribute = st.selectbox(
            "选择显示 paper 的属性",
            ["order", "index", "id", "title", "keywords", "author"],
        )

        # Cluster selection mode: top-k clusters or cumulative probability p.
        display_option = st.selectbox(
            "选择显示的选项",
            ["Top K Clusters", "Clusters Up To Probability P"],
        )
        if display_option == "Top K Clusters":
            max_k = 5
            top_k = st.slider(
                "选择 K 值", min_value=1, max_value=max_k, value=1, step=1
            )
            top_p = None
        else:
            top_k = None
            top_p = st.slider(
                "选择 P 值", min_value=0.0, max_value=1.0, value=0.5, step=0.01
            )

    # NOTE(review): the original indentation was lost; everything from here
    # on is assumed to live in the main area, not the sidebar -- confirm.

    # Process the data for the selected iteration and sample count.
    sampled_df, probabilities, paper_attributes = process_data(
        df, iteration, num_samples
    )
    hyperedges = build_hyperedges(
        probabilities,
        paper_attributes,
        display_attribute,
        top_k=top_k,
        top_p=top_p,
    )
    hypergraph = hnx.Hypergraph(hyperedges)

    show_hypergraph = st.checkbox("显示超图", value=True, key="show_hyperedges")
    show_gaussian = st.checkbox(
        "显示高斯分布", value=False, key="show_gaussian"
    )

    if show_hypergraph:
        hypergraph_visualization_component(hypergraph, draw_width, draw_height)

    if show_gaussian:
        st.header("高斯混合分布聚类结果")
        fig_gmm = visualize_gmm(sampled_df, iteration)
        st.plotly_chart(fig_gmm, use_container_width=True)

    # Details of the sampled papers.
    st.header("采样论文详细信息")
    st.dataframe(
        sampled_df[
            ["index", "title", "keywords", "rating_avg", "confidence_avg",
             "site"]
        ]
        # .style.highlight_max(axis=0)
    )

    # A second visualisation (rating distribution), currently disabled:
    # st.header("论文评分分布")
    # fig_bar = visualize_ratings(sampled_df)
    # st.plotly_chart(fig_bar, use_container_width=True)

    # Auto-play: advance one step per script run until the end is reached.
    if st.session_state.play_state:
        with st.spinner("正在播放..."):
            if st.session_state.iteration < slider_max:
                st.session_state.iteration += 1
                st.write(f"当前迭代次数: {st.session_state.iteration}")
                # Higher speed means a shorter pause between steps.
                time.sleep(1 / speed)
            else:
                # End reached: stop playback.
                st.session_state.play_state = False
            # Rerun in both cases. After a step it renders the new iteration;
            # after stopping it refreshes the button label, which was drawn
            # earlier in this run while playback was still active.
            st.rerun()


# if __name__ == "__main__":
#     # Configure the page layout
#     st.set_page_config(layout="wide")
#     # Run the main function
main()