pavanmutha committed on
Commit
8f39adc
·
verified ·
1 Parent(s): 15a30cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -251
app.py CHANGED
@@ -8,16 +8,13 @@ import time
8
  import psutil
9
  import optuna
10
  import ast
11
- import shap
12
- import lime
13
- import lime.lime_tabular
14
  import pandas as pd
15
- import numpy as np
16
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
17
  from sklearn.model_selection import train_test_split
18
  from sklearn.ensemble import RandomForestClassifier
19
- from sklearn.preprocessing import StandardScaler, PolynomialFeatures
20
- from sklearn.impute import SimpleImputer
 
 
21
  import matplotlib.pyplot as plt
22
 
23
  # Authenticate Hugging Face
@@ -27,23 +24,39 @@ login(token=hf_token, add_to_git_credential=True)
27
  # Initialize Model
28
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def format_observations(observations):
31
- if not isinstance(observations, dict):
32
- return f"<pre>{str(observations)}</pre>"
33
-
34
  return '\n'.join([
35
  f"""
36
  <div style="margin: 15px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
37
  <h3 style="margin: 0 0 10px 0; color: #4A708B;">{key.replace('_', ' ').title()}</h3>
38
  <pre style="margin: 0; padding: 10px; background: #f8f9fa; border-radius: 4px;">{value}</pre>
39
  </div>
40
- """ for key, value in observations.items()
41
  ])
42
 
43
  def format_insights(insights, visuals):
44
- if not isinstance(insights, dict):
45
- return f"<pre>{str(insights)}</pre>"
46
-
47
  return '\n'.join([
48
  f"""
49
  <div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
@@ -56,214 +69,7 @@ def format_insights(insights, visuals):
56
  """ for idx, (key, insight) in enumerate(insights.items())
57
  ])
58
 
59
- def format_analysis_report(raw_output, visuals, metrics=None, explainability_plots=None, hyperparams=None):
60
- try:
61
- # Ensure we have a dictionary to work with
62
- if isinstance(raw_output, str):
63
- try:
64
- analysis_dict = ast.literal_eval(raw_output)
65
- except:
66
- analysis_dict = {'observations': {'raw_output': raw_output}, 'insights': {}}
67
- elif isinstance(raw_output, dict):
68
- analysis_dict = raw_output
69
- else:
70
- analysis_dict = {'observations': {'raw_output': str(raw_output)}, 'insights': {}}
71
-
72
- # Metrics section
73
- metrics_section = ""
74
- if metrics:
75
- metrics_section = f"""
76
- <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
77
- <h2 style="color: #2B547E;">📈 Model Performance Metrics</h2>
78
- <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
79
- <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
80
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">Accuracy</h3>
81
- <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics.get('accuracy', 0):.2f}</p>
82
- </div>
83
- <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
84
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">Precision</h3>
85
- <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics.get('precision', 0):.2f}</p>
86
- </div>
87
- <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
88
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">Recall</h3>
89
- <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics.get('recall', 0):.2f}</p>
90
- </div>
91
- <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
92
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">F1 Score</h3>
93
- <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics.get('f1', 0):.2f}</p>
94
- </div>
95
- </div>
96
- </div>
97
- """
98
-
99
- # Hyperparameters section
100
- hyperparams_section = ""
101
- if hyperparams:
102
- hyperparams_items = ''.join([
103
- f"""
104
- <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
105
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">{key.replace('_', ' ').title()}</h3>
106
- <p style="font-size: 18px; margin: 0;">{value}</p>
107
- </div>
108
- """ for key, value in hyperparams.items()
109
- ])
110
-
111
- hyperparams_section = f"""
112
- <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
113
- <h2 style="color: #2B547E;">⚙️ Model Hyperparameters</h2>
114
- <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
115
- {hyperparams_items}
116
- </div>
117
- </div>
118
- """
119
-
120
- # Explainability section
121
- explainability_section = ""
122
- if explainability_plots:
123
- explainability_section = f"""
124
- <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
125
- <h2 style="color: #2B547E;">🔍 Model Explainability</h2>
126
- <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
127
- {''.join([f'<img src="/file={plot}" style="max-width: 100%; height: auto; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">' for plot in explainability_plots])}
128
- </div>
129
- </div>
130
- """
131
-
132
- # Observations section
133
- observations_section = ""
134
- if 'observations' in analysis_dict:
135
- observations_section = f"""
136
- <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
137
- <h2 style="color: #2B547E;">🔍 Key Observations</h2>
138
- {format_observations(analysis_dict['observations'])}
139
- </div>
140
- """
141
-
142
- # Insights section
143
- insights_section = ""
144
- if 'insights' in analysis_dict:
145
- insights_section = f"""
146
- <div style="margin-top: 30px;">
147
- <h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>
148
- {format_insights(analysis_dict.get('insights', {}), visuals)}
149
- </div>
150
- """
151
-
152
- # Build the complete report
153
- report = f"""
154
- <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
155
- <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
156
- {hyperparams_section}
157
- {metrics_section}
158
- {explainability_section}
159
- {observations_section}
160
- {insights_section}
161
- </div>
162
- """
163
-
164
- return report, visuals
165
-
166
- except Exception as e:
167
- error_report = f"""
168
- <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
169
- <h1 style="color: #B22222;">⚠️ Error Generating Report</h1>
170
- <p>An error occurred while generating the report:</p>
171
- <pre style="background: #f8f9fa; padding: 10px; border-radius: 4px;">{str(e)}</pre>
172
- <p>Raw output:</p>
173
- <pre style="background: #f8f9fa; padding: 10px; border-radius: 4px;">{str(raw_output)}</pre>
174
- </div>
175
- """
176
- return error_report, visuals
177
-
178
- def preprocess_data(df, feature_engineering=True):
179
- """Handle missing values, categorical encoding, and feature engineering"""
180
- # Make a copy to avoid modifying the original
181
- df = df.copy()
182
-
183
- # Basic preprocessing - handle missing values
184
- numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
185
- if len(numeric_cols) > 0:
186
- imputer = SimpleImputer(strategy='median')
187
- df[numeric_cols] = imputer.fit_transform(df[numeric_cols])
188
-
189
- # Convert categorical variables if any
190
- categorical_cols = df.select_dtypes(include=['object']).columns
191
- for col in categorical_cols:
192
- if len(df[col].unique()) <= 10: # One-hot encode if few categories
193
- df = pd.concat([df, pd.get_dummies(df[col], prefix=col)], axis=1)
194
- df = df.drop(col, axis=1)
195
- else: # Otherwise just drop (or could use target encoding)
196
- df = df.drop(col, axis=1)
197
-
198
- # Feature engineering
199
- if feature_engineering and len(numeric_cols) > 0:
200
- # Create polynomial features for numerical columns
201
- poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)
202
- poly_features = poly.fit_transform(df[numeric_cols])
203
- poly_cols = [f"poly_{i}" for i in range(poly_features.shape[1])]
204
- poly_df = pd.DataFrame(poly_features, columns=poly_cols)
205
- df = pd.concat([df, poly_df], axis=1)
206
-
207
- return df
208
-
209
- def evaluate_model(X, y, model, test_size=0.2):
210
- """Evaluate model performance with various metrics"""
211
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)
212
-
213
- # Standardize features
214
- scaler = StandardScaler()
215
- X_train = scaler.fit_transform(X_train)
216
- X_test = scaler.transform(X_test)
217
-
218
- model.fit(X_train, y_train)
219
- y_pred = model.predict(X_test)
220
-
221
- return {
222
- 'accuracy': accuracy_score(y_test, y_pred),
223
- 'precision': precision_score(y_test, y_pred, average='weighted'),
224
- 'recall': recall_score(y_test, y_pred, average='weighted'),
225
- 'f1': f1_score(y_test, y_pred, average='weighted')
226
- }
227
-
228
- def generate_explainability_plots(X, model, feature_names, output_dir='./figures'):
229
- """Generate SHAP and LIME explainability plots"""
230
- os.makedirs(output_dir, exist_ok=True)
231
- plot_paths = []
232
-
233
- try:
234
- # SHAP Analysis
235
- explainer = shap.Explainer(model)
236
- shap_values = explainer(X[:100]) # Use first 100 samples for speed
237
-
238
- plt.figure()
239
- shap.summary_plot(shap_values, X[:100], feature_names=feature_names, show=False)
240
- shap_path = os.path.join(output_dir, 'shap_summary.png')
241
- plt.savefig(shap_path, bbox_inches='tight')
242
- plt.close()
243
- plot_paths.append(shap_path)
244
-
245
- # LIME Analysis
246
- explainer = lime.lime_tabular.LimeTabularExplainer(
247
- X,
248
- feature_names=feature_names,
249
- class_names=[str(x) for x in np.unique(model.classes_)],
250
- verbose=False,
251
- mode='classification'
252
- )
253
-
254
- # Explain a random instance
255
- exp = explainer.explain_instance(X[0], model.predict_proba, num_features=5)
256
- lime_path = os.path.join(output_dir, 'lime_explanation.png')
257
- exp.as_pyplot_figure().savefig(lime_path, bbox_inches='tight')
258
- plt.close()
259
- plot_paths.append(lime_path)
260
-
261
- except Exception as e:
262
- print(f"Explainability failed: {str(e)}")
263
-
264
- return plot_paths
265
-
266
- def analyze_data(csv_file, additional_notes="", perform_ml=True):
267
  start_time = time.time()
268
  process = psutil.Process(os.getpid())
269
  initial_memory = process.memory_info().rss / 1024 ** 2
@@ -276,35 +82,105 @@ def analyze_data(csv_file, additional_notes="", perform_ml=True):
276
  run = wandb.init(project="huggingface-data-analysis", config={
277
  "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
278
  "additional_notes": additional_notes,
279
- "source_file": csv_file.name if csv_file else None,
280
- "perform_ml": perform_ml
281
  })
282
 
283
- metrics = None
284
- explainability_plots = None
285
- hyperparams = None
 
 
 
 
 
 
286
 
287
- try:
288
- # Load and preprocess data
289
- df = pd.read_csv(csv_file)
290
-
291
- if perform_ml and len(df.columns) > 1:
292
- try:
293
- processed_df = preprocess_data(df)
294
-
295
- # Assume last column is target for demonstration
296
- if len(processed_df.columns) > 1: # Ensure we still have features after preprocessing
297
- X = processed_df.iloc[:, :-1].values
298
- y = processed_df.iloc[:, -1].values
299
-
300
- # Convert y to numeric if needed
301
- if y.dtype == object:
302
- y = pd.factorize(y)[0]
303
-
304
- # Define model hyperparameters
305
- hyperparams = {
306
- 'n_estimators': 100,
307
- 'max_depth': None,
308
- 'min_samples_split': 2,
309
- 'min_samples_leaf': 1,
310
- 'max
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import psutil
9
  import optuna
10
  import ast
 
 
 
11
  import pandas as pd
 
 
12
  from sklearn.model_selection import train_test_split
13
  from sklearn.ensemble import RandomForestClassifier
14
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
15
+ import shap
16
+ import lime
17
+ import lime.lime_tabular
18
  import matplotlib.pyplot as plt
19
 
20
  # Authenticate Hugging Face
 
24
  # Initialize Model
25
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
26
 
27
def format_analysis_report(raw_output, visuals):
    """Render the agent's analysis result as an HTML report.

    Args:
        raw_output: Either a dict with optional ``'observations'`` and
            ``'insights'`` entries, or any object whose ``str()`` is a Python
            literal for such a dict.
        visuals: Image paths; forwarded to ``format_insights`` and returned
            unchanged.

    Returns:
        A ``(report, visuals)`` tuple. ``report`` is the HTML document, or
        ``str(raw_output)`` when the output cannot be interpreted as a dict
        or rendering fails.
    """
    if isinstance(raw_output, dict):
        analysis_dict = raw_output
    else:
        try:
            analysis_dict = ast.literal_eval(str(raw_output))
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            print(f"Error formatting analysis report: {e}")
            return str(raw_output), visuals
        if not isinstance(analysis_dict, dict):
            # literal_eval also accepts "42" or "[1, 2]"; those are valid
            # literals but not a report structure, so show the raw text.
            print(
                "Error formatting analysis report: expected a dict, got "
                f"{type(analysis_dict).__name__}"
            )
            return str(raw_output), visuals

    try:
        # `or {}` also covers keys that are present but explicitly None.
        observations_html = format_observations(analysis_dict.get('observations') or {})
        insights_html = format_insights(analysis_dict.get('insights') or {}, visuals)
    except Exception as e:
        # Best-effort boundary: a malformed section must not take the UI down.
        print(f"Error formatting analysis report: {e}")
        return str(raw_output), visuals

    report = "\n".join([
        "",
        '<div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">',
        '<h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>',
        '<div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">',
        '<h2 style="color: #2B547E;">🔍 Key Observations</h2>',
        observations_html,
        '</div>',
        '<div style="margin-top: 30px;">',
        '<h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>',
        insights_html,
        '</div>',
        '</div>',
        "",
    ])
    return report, visuals
48
+
def format_observations(observations):
    """Render the 'proportions' observations as HTML cards.

    Only entries whose key contains the substring ``'proportions'`` are
    rendered. Keys and values are HTML-escaped because they come from
    model-generated output.

    Args:
        observations: Mapping of observation name to value. Any other object
            is rendered as a single escaped ``<pre>`` block.

    Returns:
        The cards joined by newlines; an empty string when nothing matches.
    """
    import html  # local import so the block does not depend on file-level imports

    if not isinstance(observations, dict):
        return f"<pre>{html.escape(str(observations))}</pre>"

    cards = []
    for key, value in observations.items():
        label = str(key)  # tolerate non-string keys
        if 'proportions' not in label:
            continue
        title = html.escape(label.replace('_', ' ').title())
        body = html.escape(str(value))
        cards.append(
            '\n<div style="margin: 15px 0; padding: 15px; background: white; '
            'border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">\n'
            f'<h3 style="margin: 0 0 10px 0; color: #4A708B;">{title}</h3>\n'
            '<pre style="margin: 0; padding: 10px; background: #f8f9fa; '
            f'border-radius: 4px;">{body}</pre>\n'
            '</div>\n'
        )
    return '\n'.join(cards)
58
 
59
  def format_insights(insights, visuals):
 
 
 
60
  return '\n'.join([
61
  f"""
62
  <div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
 
69
  """ for idx, (key, insight) in enumerate(insights.items())
70
  ])
71
 
72
+ def analyze_data(csv_file, additional_notes=""):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  start_time = time.time()
74
  process = psutil.Process(os.getpid())
75
  initial_memory = process.memory_info().rss / 1024 ** 2
 
82
  run = wandb.init(project="huggingface-data-analysis", config={
83
  "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
84
  "additional_notes": additional_notes,
85
+ "source_file": csv_file.name if csv_file else None
 
86
  })
87
 
88
+ agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn"])
89
+ analysis_result = agent.run("""
90
+ You are an expert data analyst. Perform comprehensive analysis including:
91
+ 1. Basic statistics and data quality checks
92
+ 2. 3 insightful analytical questions about relationships in the data
93
+ 3. Visualization of key patterns and correlations
94
+ 4. Actionable real-world insights derived from findings
95
+ Generate publication-quality visualizations and save to './figures/'
96
+ """, additional_args={"additional_notes": additional_notes, "source_file": csv_file})
97
 
98
+ execution_time = time.time() - start_time
99
+ final_memory = process.memory_info().rss / 1024 ** 2
100
+ memory_usage = final_memory - initial_memory
101
+ wandb.log({"execution_time_sec": execution_time, "memory_usage_mb": memory_usage})
102
+
103
+ visuals = [os.path.join('./figures', f) for f in os.listdir('./figures') if f.endswith(('.png', '.jpg', '.jpeg'))]
104
+ for viz in visuals:
105
+ wandb.log({os.path.basename(viz): wandb.Image(viz)})
106
+
107
+ run.finish()
108
+ return format_analysis_report(analysis_result, visuals)
109
+
110
def objective(trial, X_train, y_train, X_test, y_test):
    """Optuna objective: accuracy of a random forest on the held-out split.

    Args:
        trial: Optuna trial used to sample ``n_estimators`` (50-200) and
            ``max_depth`` (3-10).
        X_train, y_train: Data the forest is fitted on.
        X_test, y_test: Data the forest is scored on.

    Returns:
        Accuracy of the fitted forest on ``X_test`` / ``y_test``.
    """
    sampled = {
        "n_estimators": trial.suggest_int("n_estimators", 50, 200),
        "max_depth": trial.suggest_int("max_depth", 3, 10),
    }
    forest = RandomForestClassifier(random_state=42, **sampled)
    forest.fit(X_train, y_train)
    return accuracy_score(y_test, forest.predict(X_test))
120
+
121
def tune_hyperparameters(csv_file, n_trials: int):
    """Tune a random forest with Optuna, then log metrics and explanations.

    The last CSV column is used as the target and all other columns as
    features. After the search, a forest is refitted with the best
    parameters, scored, and explained with SHAP and LIME; the plots are
    written to ``./figures`` and logged to Weights & Biases.

    NOTE(review): the same 20% split is used both to select hyperparameters
    and to report the final metrics, so the reported scores are optimistic.

    Args:
        csv_file: Path (or file-like object) accepted by ``pd.read_csv``.
        n_trials: Number of Optuna trials; coerced to ``int`` because UI
            number inputs may deliver floats.

    Returns:
        An HTML snippet with the best hyperparameters and the metrics, or a
        short message when the inputs are unusable.
    """
    if not csv_file:
        return "Please upload a CSV dataset before optimizing hyperparameters."
    try:
        n_trials = int(n_trials)
    except (TypeError, ValueError, OverflowError):
        n_trials = 0
    if n_trials < 1:
        return "Number of hyperparameter tuning trials must be a positive integer."

    df = pd.read_csv(csv_file)
    y = df.iloc[:, -1]
    X = df.iloc[:, :-1]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    study = optuna.create_study(direction="maximize")
    study.optimize(
        lambda trial: objective(trial, X_train, y_train, X_test, y_test),
        n_trials=n_trials,
    )
    best_params = study.best_params

    best_model = RandomForestClassifier(**best_params, random_state=42)
    best_model.fit(X_train, y_train)
    predictions = best_model.predict(X_test)

    accuracy = accuracy_score(y_test, predictions)
    precision = precision_score(y_test, predictions, average='weighted', zero_division=0)
    recall = recall_score(y_test, predictions, average='weighted', zero_division=0)
    f1 = f1_score(y_test, predictions, average='weighted', zero_division=0)

    # wandb.log requires an active run; start one only if the caller has not,
    # and finish only the run started here.
    started_run = wandb.run is None
    if started_run:
        wandb.init(project="huggingface-data-analysis", job_type="hyperparameter-tuning")
    try:
        wandb.log({
            "best_params": best_params,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1,
        })

        # The figures directory is not guaranteed to exist yet.
        os.makedirs("./figures", exist_ok=True)

        shap_explainer = shap.TreeExplainer(best_model)
        shap_values = shap_explainer.shap_values(X_test)
        shap.summary_plot(shap_values, X_test, show=False)
        shap_fig_path = "./figures/shap_summary.png"
        plt.savefig(shap_fig_path)
        wandb.log({"shap_summary": wandb.Image(shap_fig_path)})
        plt.close()  # release the figure instead of only clearing it

        lime_explainer = lime.lime_tabular.LimeTabularExplainer(
            X_train.values,
            feature_names=list(X_train.columns),
            # One name per class: the explanation plot indexes this list by
            # label, so a single placeholder name raises IndexError.
            class_names=[str(label) for label in best_model.classes_],
            mode='classification',
        )
        lime_explanation = lime_explainer.explain_instance(
            X_test.iloc[0].values, best_model.predict_proba
        )
        lime_fig = lime_explanation.as_pyplot_figure()
        lime_fig_path = "./figures/lime_explanation.png"
        lime_fig.savefig(lime_fig_path)
        wandb.log({"lime_explanation": wandb.Image(lime_fig_path)})
        plt.close(lime_fig)
    finally:
        if started_run:
            wandb.finish()

    return f"Best Hyperparameters: {best_params}<br>Accuracy: {accuracy}<br>Precision: {precision}<br>Recall: {recall}<br>F1-score: {f1}"
168
+
169
# Gradio UI: the left column holds the inputs and action buttons, the right
# column shows the analysis report, tuning results and generated figures.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 📊 AI Data Analysis Agent with Hyperparameter Optimization")
    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="Upload CSV Dataset", type="filepath")
            notes_input = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
            analyze_btn = gr.Button("Analyze", variant="primary")
            # precision=0 makes the component deliver an int, which is what
            # the trial count passed to Optuna is expected to be.
            optuna_trials = gr.Number(label="Number of Hyperparameter Tuning Trials", value=10, precision=0)
            tune_btn = gr.Button("Optimize Hyperparameters", variant="secondary")
        with gr.Column():
            analysis_output = gr.Markdown("### Analysis results will appear here...")
            optuna_output = gr.HTML(label="Hyperparameter Tuning Results")
            gallery = gr.Gallery(label="Data Visualizations", columns=2)

    # Event wiring must happen inside the Blocks context.
    analyze_btn.click(fn=analyze_data, inputs=[file_input, notes_input], outputs=[analysis_output, gallery])
    tune_btn.click(fn=tune_hyperparameters, inputs=[file_input, optuna_trials], outputs=[optuna_output])

demo.launch(debug=True)