# NOTE: "Spaces: Sleeping" banner text captured from the Hugging Face Spaces
# page during export; kept here as a comment so the module remains valid Python.
import streamlit as st | |
import pandas as pd | |
import plotly.express as px | |
from astrapy import DataAPIClient | |
import plotly.graph_objects as go | |
from plotly.subplots import make_subplots | |
import numpy as np | |
from openai import OpenAI | |
from typing import Dict, List | |
from dotenv import load_dotenv | |
import os | |
# Load environment variables (ASTRA_DB_TOKEN, ASTRA_DB_ENDPOINT, OPENAI_API_KEY)
# from a local .env file at import time so the initializers below can read them.
load_dotenv()
def initialize_client():
    """Connect to AstraDB using credentials from the environment.

    Reads ASTRA_DB_TOKEN and ASTRA_DB_ENDPOINT; on success returns the
    Database handle, otherwise shows a Streamlit error and returns None.
    """
    try:
        astra_token = os.getenv("ASTRA_DB_TOKEN")
        astra_endpoint = os.getenv("ASTRA_DB_ENDPOINT")
        if not (astra_token and astra_endpoint):
            raise ValueError("AstraDB token or endpoint not found in environment variables.")
        api_client = DataAPIClient(astra_token)
        return api_client.get_database_by_api_endpoint(astra_endpoint)
    except Exception as exc:
        st.error(f"Error initializing AstraDB client: {exc}")
        return None
def fetch_collection_data(db, collection_name):
    """Return every document in *collection_name* as a list, or None on error."""
    try:
        cursor = db[collection_name].find({})
        return [doc for doc in cursor]
    except Exception as exc:
        st.error(f"Error fetching data from collection {collection_name}: {exc}")
        return None
def process_dataframe(data):
    """Build a DataFrame from raw documents, converting numeric-looking columns.

    ``DataFrame.apply(pd.to_numeric, errors="ignore")`` is deprecated and the
    ``"ignore"`` option was removed in pandas >= 2.2, so each column is
    converted individually and left untouched when conversion fails.

    NOTE(review): the old docstring claimed this function was cached; no
    ``st.cache_data`` decorator is present, so it recomputes on every rerun.

    Args:
        data: Iterable of document dicts fetched from AstraDB.

    Returns:
        A DataFrame with numeric columns coerced to numeric dtypes.
    """
    df = pd.DataFrame(data)
    for col in df.columns:
        try:
            df[col] = pd.to_numeric(df[col])
        except (ValueError, TypeError):
            # Non-numeric column (e.g. post_type, ids) — keep as-is.
            pass
    return df
def create_basic_visualization(df, viz_type, x_col, y_col, color_col=None):
    """Create one of the four basic Plotly Express charts.

    Args:
        df: Source DataFrame.
        viz_type: One of "Line Chart", "Bar Chart", "Scatter Plot", "Box Plot".
        x_col: Column name for the x-axis.
        y_col: Column name for the y-axis.
        color_col: Optional column used for color grouping.

    Returns:
        A plotly Figure.

    Raises:
        ValueError: If ``viz_type`` is not a supported chart type (previously
            an unknown type fell through and raised UnboundLocalError on
            ``return fig``).
    """
    if viz_type == "Line Chart":
        return px.line(df, x=x_col, y=y_col, color=color_col, markers=True)
    if viz_type == "Bar Chart":
        return px.bar(df, x=x_col, y=y_col, color=color_col, text=y_col)
    if viz_type == "Scatter Plot":
        # hover_data=[None] is rejected by plotly, so only pass the hover
        # column when a color column was actually selected.
        hover = [color_col] if color_col is not None else None
        return px.scatter(df, x=x_col, y=y_col, color=color_col, size=y_col,
                          hover_data=hover)
    if viz_type == "Box Plot":
        return px.box(df, x=x_col, y=y_col, color=color_col, points="all")
    raise ValueError(f"Unsupported visualization type: {viz_type}")
def create_advanced_visualization(df, viz_type, x_col, y_col, color_col=None):
    """Build the selected Plotly figure and apply the dashboard theme.

    Basic chart types are delegated to create_basic_visualization(); the
    remaining types are bespoke engagement/sentiment views that assume the
    DataFrame has 'likes', 'shares', 'comments', 'post_type' and
    'avg_sentiment_score' columns — TODO confirm against the collections
    actually stored in AstraDB.

    Reads st.session_state.dark_mode to choose the light/dark template.
    """
    if viz_type in ["Line Chart", "Bar Chart", "Scatter Plot", "Box Plot"]:
        fig = create_basic_visualization(df, viz_type, x_col, y_col, color_col)
    elif viz_type == "Engagement Sunburst":
        # Bucket total engagement into quartiles and nest post types inside
        # each bucket; slice size = likes, slice color = sentiment.
        total_engagement = df['likes'] + df['shares'] + df['comments']
        engagement_labels = pd.qcut(total_engagement, q=4, labels=['Low', 'Medium', 'High', 'Viral'])
        temp_df = pd.DataFrame({
            'engagement_level': engagement_labels,
            'post_type': df['post_type'],
            'likes': df['likes'],
            'sentiment': df['avg_sentiment_score']
        })
        fig = px.sunburst(
            temp_df,
            path=['engagement_level', 'post_type'],
            values='likes',
            color='sentiment',
            color_continuous_scale='RdYlBu',
            title="Engagement Distribution by Post Type and Sentiment"
        )
    elif viz_type == "Sentiment Heat Calendar":
        # Create dummy datetime for visualization
        # NOTE(review): each cell is the overall mean sentiment plus random
        # noise, not real per-day/per-hour data — this chart is illustrative
        # only and is not deterministic between reruns.
        hour_data = []
        days = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
        for day in days:
            for hour in range(24):
                avg_sentiment = df['avg_sentiment_score'].mean() + np.random.normal(0, 0.1)
                hour_data.append({
                    'day': day,
                    'hour': hour,
                    'sentiment': avg_sentiment
                })
        temp_df = pd.DataFrame(hour_data)
        fig = px.density_heatmap(
            temp_df,
            x='day',
            y='hour',
            z='sentiment',
            title="Sentiment Distribution by Day and Hour",
            labels={'sentiment': 'Average Sentiment'},
            color_continuous_scale="RdYlBu"
        )
    elif viz_type == "Engagement Spider":
        # Min-max normalize each metric to [0, 1] so post types are
        # comparable on one radar chart; the first value is appended again
        # to close each polygon.
        metrics = ['likes', 'shares', 'comments']
        df_norm = df[metrics].apply(lambda x: (x - x.min()) / (x.max() - x.min()))
        fig = go.Figure()
        for ptype in df['post_type'].unique():
            values = df_norm[df['post_type'] == ptype].mean()
            fig.add_trace(go.Scatterpolar(
                r=values.tolist() + [values.iloc[0]],
                theta=metrics + [metrics[0]],
                name=ptype,
                fill='toself'
            ))
        fig.update_layout(
            polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
            showlegend=True,
            title="Engagement Pattern by Post Type"
        )
    elif viz_type == "Sentiment Flow":
        # Group by post type and calculate rolling average
        # (window shrinks to the series length for small groups so the
        # rolling mean is still defined).
        fig = go.Figure()
        for ptype in df['post_type'].unique():
            mask = df['post_type'] == ptype
            sentiment_series = df[mask]['avg_sentiment_score']
            rolling_avg = sentiment_series.rolling(window=min(7, len(sentiment_series))).mean()
            fig.add_trace(go.Scatter(
                x=list(range(len(rolling_avg))),  # Use index instead of dates
                y=rolling_avg,
                name=ptype,
                mode='lines',
                fill='tonexty'
            ))
        fig.update_layout(
            title="Sentiment Flow by Post Type",
            xaxis_title="Post Sequence",
            yaxis_title="Average Sentiment"
        )
    elif viz_type == "Engagement Matrix":
        # Pearson correlation heatmap across the engagement metrics.
        corr_matrix = df[['likes', 'shares', 'comments', 'avg_sentiment_score']].corr()
        fig = px.imshow(
            corr_matrix,
            color_continuous_scale='RdBu',
            aspect='auto',
            title="Engagement Metrics Correlation Matrix"
        )
    # Apply theme
    # NOTE(review): an unrecognized viz_type would leave `fig` unbound here;
    # main() restricts the choices, but this is fragile.
    fig.update_layout(
        template="plotly_dark" if st.session_state.dark_mode else "plotly_white",
        title_x=0.5,
        font=dict(size=14),
        margin=dict(l=20, r=20, t=50, b=20),
        paper_bgcolor="#1e1e1e" if st.session_state.dark_mode else "#f9f9f9",
        plot_bgcolor="#1e1e1e" if st.session_state.dark_mode else "#f9f9f9",
    )
    return fig
def initialize_openai():
    """Initialize the OpenAI client.

    Mirrors initialize_client()'s validation: fails fast with a clear message
    when OPENAI_API_KEY is unset instead of deferring the failure to the first
    API call (OpenAI(api_key=None) constructs without error).

    Returns:
        An OpenAI client on success, or None (with a Streamlit error shown).
    """
    try:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OpenAI API key not found in environment variables.")
        return OpenAI(api_key=api_key)
    except Exception as e:
        st.error(f"Error initializing OpenAI: {e}")
        return None
def generate_prompt(metrics: Dict) -> str:
    """Build the GPT analysis prompt embedding the metrics dict.

    The metrics mapping (as produced by calculate_metrics) is interpolated
    verbatim via its repr; the model is asked for bullet-point insights with
    numerical comparisons.
    """
    return f"""Analyze the following social media metrics and provide 3-5 clear, specific insights about post performance:
Post Type Metrics:
{metrics}
Please focus on:
1. Comparative performance between post types
2. Engagement patterns
3. Notable trends or anomalies
4. Actionable recommendations
Format your response in clear bullet points with percentage comparisons where relevant.
Keep each insight concise but specific, including numerical comparisons.
"""
def calculate_metrics(df: pd.DataFrame) -> Dict:
    """Summarize engagement statistics for GPT analysis.

    Returns a dict keyed by post type (avg_likes, avg_shares, avg_comments,
    avg_sentiment, engagement_rate, post_count) plus an 'overall' entry with
    total_posts, total_engagement and avg_sentiment_overall.
    """
    summary: Dict = {}
    # Per-post-type statistics, in first-appearance order of post_type.
    for ptype in df['post_type'].unique():
        group = df.loc[df['post_type'] == ptype]
        combined = group['likes'] + group['shares'] + group['comments']
        summary[ptype] = {
            'avg_likes': group['likes'].mean(),
            'avg_shares': group['shares'].mean(),
            'avg_comments': group['comments'].mean(),
            'avg_sentiment': group['avg_sentiment_score'].mean(),
            'engagement_rate': combined.mean(),
            'post_count': len(group),
        }
    # Whole-dataset aggregates for comparative context.
    summary['overall'] = {
        'total_posts': len(df),
        'total_engagement': df['likes'].sum() + df['shares'].sum() + df['comments'].sum(),
        'avg_sentiment_overall': df['avg_sentiment_score'].mean(),
    }
    return summary
def get_gpt_insights(client: OpenAI, metrics: Dict, user_query: str) -> str:
    """Ask GPT for insights about *metrics*, scoped by *user_query*.

    Returns the model's reply text (stripped), or an error string when the
    API call fails — callers render the result either way.
    """
    try:
        full_prompt = generate_prompt(metrics) + f"\n\nUser Query: {user_query}"
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a social media analytics expert. Provide clear, specific insights based on the data."},
                {"role": "user", "content": full_prompt},
            ],
            temperature=0.7,
            max_tokens=500,
        )
        return completion.choices[0].message.content.strip()
    except Exception as exc:
        return f"Error generating insights: {exc}"
def main():
    """Streamlit entry point: data source selection, charts, metrics and AI chat."""
    st.set_page_config(
        page_title="Advanced Social Media Analytics Dashboard",
        page_icon="π",
        layout="wide",
    )
    # May be None when OPENAI_API_KEY is unset; the AI tab degrades gracefully.
    openai_client = initialize_openai()
    # Sidebar Settings
    with st.sidebar:
        st.title("Dashboard Settings")
        if "dark_mode" not in st.session_state:
            st.session_state.dark_mode = False
        # key="dark_mode" binds the widget directly to session state, which
        # create_advanced_visualization reads for theming.
        st.checkbox("Dark Mode", value=st.session_state.dark_mode, key="dark_mode")
        st.write("### Data Source")
        st.info("Initializing connection to AstraDB...")
        db = initialize_client()
        if not db:
            # initialize_client already surfaced the error; stop rendering.
            return
        collections = db.list_collection_names()
        st.success("Connected to AstraDB")
        selected_collection = st.selectbox("Select Collection", collections)
    if selected_collection:
        data = fetch_collection_data(db, selected_collection)
        if data:
            # Use cached data processing
            # NOTE(review): process_dataframe is not actually cached — confirm
            # whether @st.cache_data was intended.
            df = process_dataframe(data)
            # Create tabs for different analysis views
            tab1, tab2, tab3 = st.tabs(["π Visualizations", "π Metrics", "π€ AI Insights"])
            with tab1:
                col1, col2 = st.columns([1, 3])
                with col1:
                    st.write("### Visualization Options")
                    viz_type = st.selectbox(
                        "Select Analysis Type",
                        [
                            "Engagement Sunburst",
                            "Sentiment Heat Calendar",
                            "Engagement Spider",
                            "Sentiment Flow",
                            "Engagement Matrix",
                            "Line Chart",
                            "Bar Chart",
                            "Scatter Plot",
                            "Box Plot"
                        ]
                    )
                    # Axis pickers only apply to the basic chart types; the
                    # advanced views derive everything from fixed columns.
                    if viz_type in ["Line Chart", "Bar Chart", "Scatter Plot", "Box Plot"]:
                        x_col = st.selectbox("Select X-axis", df.columns)
                        y_col = st.selectbox("Select Y-axis", df.select_dtypes(include=["number"]).columns)
                        color_col = st.selectbox("Select Color Column (Optional)", [None] + list(df.columns), index=0)
                    else:
                        x_col = y_col = color_col = None
                with col2:
                    try:
                        fig = create_advanced_visualization(df, viz_type, x_col, y_col, color_col)
                        st.plotly_chart(fig, use_container_width=True)
                    except Exception as e:
                        st.error(f"Error creating visualization: {e}")
            with tab2:
                # Display key metrics and insights
                col1, col2, col3 = st.columns(3)
                with col1:
                    # NOTE(review): this divides the mean per-post engagement
                    # by the row count again — verify the intended formula.
                    st.metric("Average Engagement Rate",
                              f"{((df['likes'] + df['shares'] + df['comments']).mean() / len(df)):.2f}")
                    st.metric("Likes Mean", f"{df['likes'].mean():.2f}")
                    st.metric("Shares Mean", f"{df['shares'].mean():.2f}")
                    st.metric("Comments Mean", f"{df['comments'].mean():.2f}")
                    st.metric("Max Likes", f"{df['likes'].max():.2f}")
                    st.metric("Min Likes", f"{df['likes'].min():.2f}")
                with col2:
                    # Delta slot shows the standard deviation of sentiment.
                    st.metric("Sentiment Trend",
                              f"{df['avg_sentiment_score'].mean():.2f}",
                              f"{df['avg_sentiment_score'].std():.2f}")
                    st.metric("Max Shares", f"{df['shares'].max():.2f}")
                    st.metric("Min Shares", f"{df['shares'].min():.2f}")
                    st.metric("Max Comments", f"{df['comments'].max():.2f}")
                    st.metric("Min Comments", f"{df['comments'].min():.2f}")
                    st.metric("Median Sentiment", f"{df['avg_sentiment_score'].median():.2f}")
                with col3:
                    # "Most engaging" here means highest total likes only.
                    top_type = df.groupby('post_type')['likes'].sum().idxmax()
                    st.metric("Most Engaging Post Type", top_type)
                with st.expander("Detailed Post Overview"):
                    st.markdown("**Detailed metrics for each post (ID, likes, shares, comments, sentiment):**")
                    if 'post_id' in df.columns:
                        st.dataframe(df[['post_id','likes','shares','comments','avg_sentiment_score']])
                    else:
                        st.warning("No 'post_id' column found in the data.")
            with tab3:
                st.write("## AI Chatbot Insights")
                if not openai_client:
                    st.error("OpenAI API not configured. Please add your API key to access AI insights.")
                else:
                    if 'chat_history' not in st.session_state:
                        st.session_state.chat_history = []
                    user_input = st.text_input("Ask about data or insights:")
                    if st.button("Send"):
                        st.session_state.chat_history.append({"role": "user", "content": user_input})
                        # Use the modified get_gpt_insights function to generate response
                        metrics = calculate_metrics(df)
                        reply = get_gpt_insights(openai_client, metrics, user_input)
                        st.session_state.chat_history.append({"role": "assistant", "content": reply})
                    # Re-render the full conversation on every rerun.
                    for msg in st.session_state.chat_history:
                        if msg["role"] == "user":
                            st.markdown(f"**You:** {msg['content']}")
                        else:
                            st.markdown(f"**Assistant:** {msg['content']}")
        else:
            st.error("Failed to fetch data from the selected collection.")
    else:
        st.error("Please select a valid collection.")
if __name__ == "__main__":
    main()