vvinayakkk commited on
Commit
1d88eef
Β·
1 Parent(s): 4590034

all committed

Browse files
Files changed (3) hide show
  1. .gitignore +3 -0
  2. requirements.txt +6 -0
  3. tp.py +376 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .env
2
+ *pyc
3
+ /myenv/
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit==1.20.0
2
+ pandas==1.5.3
3
+ plotly==5.11.0
4
+ numpy==1.24.1
5
+ astrapy==0.6.1
6
+ openai==0.27.0
tp.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import plotly.express as px
4
+ from astrapy import DataAPIClient
5
+ import plotly.graph_objects as go
6
+ from plotly.subplots import make_subplots
7
+ import numpy as np
8
+ from openai import OpenAI
9
+ from typing import Dict, List
10
+ from dotenv import load_dotenv
11
+ import os
12
+
13
+ # Load environment variables
14
+ load_dotenv()
15
+
16
+ def initialize_client():
17
+ try:
18
+ token = os.getenv("ASTRA_DB_TOKEN")
19
+ endpoint = os.getenv("ASTRA_DB_ENDPOINT")
20
+
21
+ if not token or not endpoint:
22
+ raise ValueError("AstraDB token or endpoint not found in environment variables.")
23
+
24
+ client = DataAPIClient(token)
25
+ db = client.get_database_by_api_endpoint(endpoint)
26
+ return db
27
+ except Exception as e:
28
+ st.error(f"Error initializing AstraDB client: {e}")
29
+ return None
30
+
31
+ def fetch_collection_data(db, collection_name):
32
+ try:
33
+ collection = db[collection_name]
34
+ documents = collection.find({})
35
+ return list(documents)
36
+ except Exception as e:
37
+ st.error(f"Error fetching data from collection {collection_name}: {e}")
38
+ return None
39
+
40
+ @st.cache_data
41
+ def process_dataframe(data):
42
+ """Cache the dataframe processing to prevent unnecessary recomputation"""
43
+ df = pd.DataFrame(data)
44
+ df = df.apply(pd.to_numeric, errors="ignore")
45
+ return df
46
+
47
+ def create_basic_visualization(df, viz_type, x_col, y_col, color_col=None):
48
+ """Handle basic visualization types"""
49
+ if viz_type == "Line Chart":
50
+ fig = px.line(df, x=x_col, y=y_col, color=color_col, markers=True)
51
+ elif viz_type == "Bar Chart":
52
+ fig = px.bar(df, x=x_col, y=y_col, color=color_col, text=y_col)
53
+ elif viz_type == "Scatter Plot":
54
+ fig = px.scatter(df, x=x_col, y=y_col, color=color_col, size=y_col, hover_data=[color_col])
55
+ elif viz_type == "Box Plot":
56
+ fig = px.box(df, x=x_col, y=y_col, color=color_col, points="all")
57
+ return fig
58
+
59
+ def create_advanced_visualization(df, viz_type, x_col, y_col, color_col=None):
60
+ if viz_type in ["Line Chart", "Bar Chart", "Scatter Plot", "Box Plot"]:
61
+ fig = create_basic_visualization(df, viz_type, x_col, y_col, color_col)
62
+
63
+ elif viz_type == "Engagement Sunburst":
64
+ total_engagement = df['likes'] + df['shares'] + df['comments']
65
+ engagement_labels = pd.qcut(total_engagement, q=4, labels=['Low', 'Medium', 'High', 'Viral'])
66
+ temp_df = pd.DataFrame({
67
+ 'engagement_level': engagement_labels,
68
+ 'post_type': df['post_type'],
69
+ 'likes': df['likes'],
70
+ 'sentiment': df['avg_sentiment_score']
71
+ })
72
+
73
+ fig = px.sunburst(
74
+ temp_df,
75
+ path=['engagement_level', 'post_type'],
76
+ values='likes',
77
+ color='sentiment',
78
+ color_continuous_scale='RdYlBu',
79
+ title="Engagement Distribution by Post Type and Sentiment"
80
+ )
81
+
82
+ elif viz_type == "Sentiment Heat Calendar":
83
+ # Create dummy datetime for visualization
84
+ hour_data = []
85
+ days = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
86
+
87
+ for day in days:
88
+ for hour in range(24):
89
+ avg_sentiment = df['avg_sentiment_score'].mean() + np.random.normal(0, 0.1)
90
+ hour_data.append({
91
+ 'day': day,
92
+ 'hour': hour,
93
+ 'sentiment': avg_sentiment
94
+ })
95
+
96
+ temp_df = pd.DataFrame(hour_data)
97
+ fig = px.density_heatmap(
98
+ temp_df,
99
+ x='day',
100
+ y='hour',
101
+ z='sentiment',
102
+ title="Sentiment Distribution by Day and Hour",
103
+ labels={'sentiment': 'Average Sentiment'},
104
+ color_continuous_scale="RdYlBu"
105
+ )
106
+
107
+ elif viz_type == "Engagement Spider":
108
+ metrics = ['likes', 'shares', 'comments']
109
+ df_norm = df[metrics].apply(lambda x: (x - x.min()) / (x.max() - x.min()))
110
+
111
+ fig = go.Figure()
112
+ for ptype in df['post_type'].unique():
113
+ values = df_norm[df['post_type'] == ptype].mean()
114
+ fig.add_trace(go.Scatterpolar(
115
+ r=values.tolist() + [values.iloc[0]],
116
+ theta=metrics + [metrics[0]],
117
+ name=ptype,
118
+ fill='toself'
119
+ ))
120
+
121
+ fig.update_layout(
122
+ polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
123
+ showlegend=True,
124
+ title="Engagement Pattern by Post Type"
125
+ )
126
+
127
+ elif viz_type == "Sentiment Flow":
128
+ # Group by post type and calculate rolling average
129
+ fig = go.Figure()
130
+ for ptype in df['post_type'].unique():
131
+ mask = df['post_type'] == ptype
132
+ sentiment_series = df[mask]['avg_sentiment_score']
133
+ rolling_avg = sentiment_series.rolling(window=min(7, len(sentiment_series))).mean()
134
+
135
+ fig.add_trace(go.Scatter(
136
+ x=list(range(len(rolling_avg))), # Use index instead of dates
137
+ y=rolling_avg,
138
+ name=ptype,
139
+ mode='lines',
140
+ fill='tonexty'
141
+ ))
142
+
143
+ fig.update_layout(
144
+ title="Sentiment Flow by Post Type",
145
+ xaxis_title="Post Sequence",
146
+ yaxis_title="Average Sentiment"
147
+ )
148
+
149
+ elif viz_type == "Engagement Matrix":
150
+ corr_matrix = df[['likes', 'shares', 'comments', 'avg_sentiment_score']].corr()
151
+
152
+ fig = px.imshow(
153
+ corr_matrix,
154
+ color_continuous_scale='RdBu',
155
+ aspect='auto',
156
+ title="Engagement Metrics Correlation Matrix"
157
+ )
158
+
159
+ # Apply theme
160
+ fig.update_layout(
161
+ template="plotly_dark" if st.session_state.dark_mode else "plotly_white",
162
+ title_x=0.5,
163
+ font=dict(size=14),
164
+ margin=dict(l=20, r=20, t=50, b=20),
165
+ paper_bgcolor="#1e1e1e" if st.session_state.dark_mode else "#f9f9f9",
166
+ plot_bgcolor="#1e1e1e" if st.session_state.dark_mode else "#f9f9f9",
167
+ )
168
+ return fig
169
+
170
+ def initialize_openai():
171
+ """Initialize OpenAI client"""
172
+ try:
173
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
174
+ return client
175
+ except Exception as e:
176
+ st.error(f"Error initializing OpenAI: {e}")
177
+ return None
178
+
179
+ def generate_prompt(metrics: Dict) -> str:
180
+ """Generate a prompt for GPT based on the metrics"""
181
+ return f"""Analyze the following social media metrics and provide 3-5 clear, specific insights about post performance:
182
+
183
+ Post Type Metrics:
184
+ {metrics}
185
+
186
+ Please focus on:
187
+ 1. Comparative performance between post types
188
+ 2. Engagement patterns
189
+ 3. Notable trends or anomalies
190
+ 4. Actionable recommendations
191
+
192
+ Format your response in clear bullet points with percentage comparisons where relevant.
193
+ Keep each insight concise but specific, including numerical comparisons.
194
+ """
195
+
196
+ def calculate_metrics(df: pd.DataFrame) -> Dict:
197
+ """Calculate comprehensive metrics for GPT analysis"""
198
+ metrics = {}
199
+
200
+ # Calculate per post type metrics
201
+ for post_type in df['post_type'].unique():
202
+ post_data = df[df['post_type'] == post_type]
203
+ metrics[post_type] = {
204
+ 'avg_likes': post_data['likes'].mean(),
205
+ 'avg_shares': post_data['shares'].mean(),
206
+ 'avg_comments': post_data['comments'].mean(),
207
+ 'avg_sentiment': post_data['avg_sentiment_score'].mean(),
208
+ 'engagement_rate': (post_data['likes'] + post_data['shares'] + post_data['comments']).mean(),
209
+ 'post_count': len(post_data)
210
+ }
211
+
212
+ # Calculate comparative metrics
213
+ total_posts = len(df)
214
+ total_engagement = df['likes'].sum() + df['shares'].sum() + df['comments'].sum()
215
+
216
+ metrics['overall'] = {
217
+ 'total_posts': total_posts,
218
+ 'total_engagement': total_engagement,
219
+ 'avg_sentiment_overall': df['avg_sentiment_score'].mean()
220
+ }
221
+
222
+ return metrics
223
+
224
+ def get_gpt_insights(client: OpenAI, metrics: Dict, user_query: str) -> str:
225
+ """Get insights from GPT based on the metrics and user query"""
226
+ try:
227
+ prompt = generate_prompt(metrics) + f"\n\nUser Query: {user_query}"
228
+
229
+ response = client.chat.completions.create(
230
+ model="gpt-3.5-turbo",
231
+ messages=[
232
+ {"role": "system", "content": "You are a social media analytics expert. Provide clear, specific insights based on the data."},
233
+ {"role": "user", "content": prompt}
234
+ ],
235
+ temperature=0.7,
236
+ max_tokens=500
237
+ )
238
+
239
+ # Extract and clean insights
240
+ insights_text = response.choices[0].message.content
241
+ return insights_text.strip()
242
+
243
+ except Exception as e:
244
+ return f"Error generating insights: {e}"
245
+
246
+ def main():
247
+ st.set_page_config(
248
+ page_title="Advanced Social Media Analytics Dashboard",
249
+ page_icon="πŸ“Š",
250
+ layout="wide",
251
+ )
252
+ openai_client = initialize_openai()
253
+
254
+ # Sidebar Settings
255
+ with st.sidebar:
256
+ st.title("Dashboard Settings")
257
+ if "dark_mode" not in st.session_state:
258
+ st.session_state.dark_mode = False
259
+ st.checkbox("Dark Mode", value=st.session_state.dark_mode, key="dark_mode")
260
+
261
+ st.write("### Data Source")
262
+ st.info("Initializing connection to AstraDB...")
263
+ db = initialize_client()
264
+ if not db:
265
+ return
266
+
267
+ collections = db.list_collection_names()
268
+ st.success("Connected to AstraDB")
269
+ selected_collection = st.selectbox("Select Collection", collections)
270
+
271
+ if selected_collection:
272
+ data = fetch_collection_data(db, selected_collection)
273
+ if data:
274
+ # Use cached data processing
275
+ df = process_dataframe(data)
276
+
277
+ # Create tabs for different analysis views
278
+ tab1, tab2, tab3 = st.tabs(["πŸ“Š Visualizations", "πŸ“ˆ Metrics", "πŸ€– AI Insights"])
279
+ with tab1:
280
+ col1, col2 = st.columns([1, 3])
281
+
282
+ with col1:
283
+ st.write("### Visualization Options")
284
+ viz_type = st.selectbox(
285
+ "Select Analysis Type",
286
+ [
287
+ "Engagement Sunburst",
288
+ "Sentiment Heat Calendar",
289
+ "Engagement Spider",
290
+ "Sentiment Flow",
291
+ "Engagement Matrix",
292
+ "Line Chart",
293
+ "Bar Chart",
294
+ "Scatter Plot",
295
+ "Box Plot"
296
+ ]
297
+ )
298
+
299
+ if viz_type in ["Line Chart", "Bar Chart", "Scatter Plot", "Box Plot"]:
300
+ x_col = st.selectbox("Select X-axis", df.columns)
301
+ y_col = st.selectbox("Select Y-axis", df.select_dtypes(include=["number"]).columns)
302
+ color_col = st.selectbox("Select Color Column (Optional)", [None] + list(df.columns), index=0)
303
+ else:
304
+ x_col = y_col = color_col = None
305
+
306
+ with col2:
307
+ try:
308
+ fig = create_advanced_visualization(df, viz_type, x_col, y_col, color_col)
309
+ st.plotly_chart(fig, use_container_width=True)
310
+ except Exception as e:
311
+ st.error(f"Error creating visualization: {e}")
312
+
313
+ with tab2:
314
+ # Display key metrics and insights
315
+ col1, col2, col3 = st.columns(3)
316
+
317
+ with col1:
318
+ st.metric("Average Engagement Rate",
319
+ f"{((df['likes'] + df['shares'] + df['comments']).mean() / len(df)):.2f}")
320
+ st.metric("Likes Mean", f"{df['likes'].mean():.2f}")
321
+ st.metric("Shares Mean", f"{df['shares'].mean():.2f}")
322
+ st.metric("Comments Mean", f"{df['comments'].mean():.2f}")
323
+ st.metric("Max Likes", f"{df['likes'].max():.2f}")
324
+ st.metric("Min Likes", f"{df['likes'].min():.2f}")
325
+
326
+ with col2:
327
+ st.metric("Sentiment Trend",
328
+ f"{df['avg_sentiment_score'].mean():.2f}",
329
+ f"{df['avg_sentiment_score'].std():.2f}")
330
+ st.metric("Max Shares", f"{df['shares'].max():.2f}")
331
+ st.metric("Min Shares", f"{df['shares'].min():.2f}")
332
+ st.metric("Max Comments", f"{df['comments'].max():.2f}")
333
+ st.metric("Min Comments", f"{df['comments'].min():.2f}")
334
+ st.metric("Median Sentiment", f"{df['avg_sentiment_score'].median():.2f}")
335
+
336
+ with col3:
337
+ top_type = df.groupby('post_type')['likes'].sum().idxmax()
338
+ st.metric("Most Engaging Post Type", top_type)
339
+
340
+ with st.expander("Detailed Post Overview"):
341
+ st.markdown("**Detailed metrics for each post (ID, likes, shares, comments, sentiment):**")
342
+ if 'post_id' in df.columns:
343
+ st.dataframe(df[['post_id','likes','shares','comments','avg_sentiment_score']])
344
+ else:
345
+ st.warning("No 'post_id' column found in the data.")
346
+
347
+ with tab3:
348
+ st.write("## AI Chatbot Insights")
349
+ if not openai_client:
350
+ st.error("OpenAI API not configured. Please add your API key to access AI insights.")
351
+ else:
352
+ if 'chat_history' not in st.session_state:
353
+ st.session_state.chat_history = []
354
+
355
+ user_input = st.text_input("Ask about data or insights:")
356
+ if st.button("Send"):
357
+ st.session_state.chat_history.append({"role": "user", "content": user_input})
358
+
359
+ # Use the modified get_gpt_insights function to generate response
360
+ metrics = calculate_metrics(df)
361
+ reply = get_gpt_insights(openai_client, metrics, user_input)
362
+ st.session_state.chat_history.append({"role": "assistant", "content": reply})
363
+
364
+ for msg in st.session_state.chat_history:
365
+ if msg["role"] == "user":
366
+ st.markdown(f"**You:** {msg['content']}")
367
+ else:
368
+ st.markdown(f"**Assistant:** {msg['content']}")
369
+
370
+ else:
371
+ st.error("Failed to fetch data from the selected collection.")
372
+ else:
373
+ st.error("Please select a valid collection.")
374
+
375
+ if __name__ == "__main__":
376
+ main()