Sephfox commited on
Commit
ba908ff
Β·
verified Β·
1 Parent(s): ee8e3ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +306 -234
app.py CHANGED
@@ -9,293 +9,365 @@ from huggingface_hub import HfApi
9
  import os
10
  import traceback
11
  from contextlib import contextmanager
 
 
 
 
 
 
12
 
13
- # Error Handling Context Manager
14
- @contextmanager
15
- def error_handling(operation_name):
16
- try:
17
- yield
18
- except Exception as e:
19
- error_msg = f"Error during {operation_name}: {str(e)}\n{traceback.format_exc()}"
20
- st.error(error_msg)
21
- with open("error_log.txt", "a") as f:
22
- f.write(f"\n{error_msg}")
23
-
24
- # Cyberpunk Styling
25
- def setup_cyberpunk_style():
26
  st.markdown("""
27
  <style>
28
  @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
 
29
 
30
  .stApp {
31
- background: linear-gradient(45deg, #000428, #004e92);
 
 
 
 
 
 
32
  }
33
 
34
  .main-title {
35
  font-family: 'Orbitron', sans-serif;
36
- color: #00ff9d;
 
 
37
  text-align: center;
38
- text-shadow: 0 0 10px #00ff9d;
39
- padding: 20px;
40
- font-size: 2.5em;
41
  margin-bottom: 30px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  }
43
 
44
  .stButton>button {
 
45
  background: linear-gradient(45deg, #00ff9d, #00b8ff);
46
  color: black;
47
- font-family: 'Orbitron', sans-serif;
48
  border: none;
49
- padding: 10px 20px;
50
  border-radius: 5px;
51
  text-transform: uppercase;
52
  font-weight: bold;
 
53
  transition: all 0.3s ease;
 
 
54
  }
55
 
56
  .stButton>button:hover {
57
  transform: scale(1.05);
58
- box-shadow: 0 0 15px #00ff9d;
59
  }
60
 
61
- .metric-container {
62
- background: rgba(0, 0, 0, 0.5);
63
- border: 2px solid #00ff9d;
64
- border-radius: 10px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  padding: 15px;
66
  margin: 10px 0;
 
67
  }
68
 
69
- .status-text {
70
- color: #00ff9d;
71
- font-family: 'Orbitron', sans-serif;
72
- font-size: 1.2em;
 
 
 
 
73
  }
74
 
75
- .sidebar .stSelectbox, .sidebar .stSlider {
76
- background-color: rgba(0, 0, 0, 0.3);
77
- border-radius: 5px;
78
- padding: 10px;
79
- margin: 5px 0;
80
  }
81
  </style>
82
  """, unsafe_allow_html=True)
83
 
84
- # Your existing functions with error handling
85
- def generate_demo_data(num_samples=60):
86
- with error_handling("demo data generation"):
87
- # Your existing generate_demo_data code
88
- subjects = [
89
- 'Artificial intelligence', 'Climate change', 'Renewable energy',
90
- 'Space exploration', 'Quantum computing', 'Genetic engineering',
91
- 'Blockchain technology', 'Virtual reality', 'Cybersecurity',
92
- 'Biotechnology', 'Nanotechnology', 'Astrophysics'
93
- ]
94
- verbs = [
95
- 'is transforming', 'is influencing', 'is revolutionizing',
96
- 'is challenging', 'is advancing', 'is reshaping', 'is impacting',
97
- 'is enhancing', 'is disrupting', 'is redefining'
98
- ]
99
- objects = [
100
- 'modern science', 'global economies', 'healthcare systems',
101
- 'communication methods', 'educational approaches',
102
- 'environmental policies', 'social interactions', 'the job market',
103
- 'data security', 'the entertainment industry'
104
- ]
105
- data = []
106
- for i in range(num_samples):
107
- subject = random.choice(subjects)
108
- verb = random.choice(verbs)
109
- obj = random.choice(objects)
110
- sentence = f"{subject} {verb} {obj}."
111
- data.append(sentence)
112
- return data
113
 
114
- def upload_to_huggingface(model_path, token, repo_name):
115
- with error_handling("HuggingFace upload"):
116
- api = HfApi()
117
- api.create_repo(repo_name, token=token, private=True)
118
- api.upload_folder(
119
- folder_path=model_path,
120
- repo_id=repo_name,
121
- token=token
122
  )
123
- return True
 
124
 
125
- def fitness_function(individual, train_dataset, model, tokenizer):
126
- with error_handling("fitness evaluation"):
127
- training_args = TrainingArguments(
128
- output_dir='./results',
129
- overwrite_output_dir=True,
130
- num_train_epochs=individual['epochs'],
131
- per_device_train_batch_size=individual['batch_size'],
132
- learning_rate=individual['learning_rate'],
133
- logging_steps=10,
134
- save_steps=10,
135
- save_total_limit=2,
136
- report_to='none',
137
- )
138
-
139
- data_collator = DataCollatorForLanguageModeling(
140
- tokenizer=tokenizer, mlm=False
141
- )
142
-
143
- trainer = Trainer(
144
- model=model,
145
- args=training_args,
146
- data_collator=data_collator,
147
- train_dataset=train_dataset,
148
- eval_dataset=None,
149
- )
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- trainer.train()
152
- logs = [log for log in trainer.state.log_history if 'loss' in log]
153
- return logs[-1]['loss'] if logs else float('inf')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  def main():
156
- setup_cyberpunk_style()
157
 
158
  st.markdown('<h1 class="main-title">Neural Evolution GPT-2 Training Hub</h1>', unsafe_allow_html=True)
159
-
160
- # Sidebar Configuration
 
 
 
161
  with st.sidebar:
162
- st.markdown("### 🌐 Configuration")
 
 
 
 
163
 
164
- hf_token = st.text_input("πŸ”‘ HuggingFace Token", type="password")
165
- repo_name = st.text_input("πŸ“ Repository Name", "my-gpt2-model")
166
 
167
- data_source = st.selectbox(
168
- 'πŸ“Š Data Source',
169
- ('DEMO', 'Upload Text File')
170
- )
171
 
172
- st.markdown("### βš™οΈ Evolution Parameters")
173
- population_size = st.slider("Population Size", 4, 20, 6)
174
- num_generations = st.slider("Generations", 1, 10, 3)
175
- num_parents = st.slider("Parents", 2, population_size, 2)
176
- mutation_rate = st.slider("Mutation Rate", 0.0, 1.0, 0.1)
177
-
178
- # Hyperparameter bounds
179
- param_bounds = {
180
- 'learning_rate': (1e-5, 5e-5),
181
- 'epochs': (1, 3),
182
- 'batch_size': [2, 4, 8]
183
- }
184
-
185
- # Main Content Area
186
- with error_handling("main application flow"):
187
- if data_source == 'DEMO':
188
- st.info("πŸ€– Using demo data...")
189
- data = generate_demo_data()
190
- else:
191
- uploaded_file = st.file_uploader("πŸ“‚ Upload Training Data", type="txt")
192
- if uploaded_file:
193
- data = load_data(uploaded_file)
194
- else:
195
- st.warning("⚠️ Please upload a text file")
196
- st.stop()
197
-
198
- # Model Setup
199
- with st.spinner("πŸ”§ Loading GPT-2..."):
200
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
201
- model = GPT2LMHeadModel.from_pretrained('gpt2')
202
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
203
- model.to(device)
204
- tokenizer.pad_token = tokenizer.eos_token
205
- model.config.pad_token_id = model.config.eos_token_id
206
-
207
- # Dataset Preparation
208
- with st.spinner("πŸ“Š Preparing dataset..."):
209
- train_dataset = prepare_dataset(data, tokenizer)
210
-
211
- if st.button("πŸš€ Start Training", key="start_training"):
212
- progress_bar = st.progress(0)
213
- status_text = st.empty()
214
 
215
- # Metrics Display
216
- col1, col2, col3 = st.columns(3)
217
- with col1:
218
- metrics_loss = st.empty()
219
- with col2:
220
- metrics_generation = st.empty()
221
- with col3:
222
- metrics_status = st.empty()
223
-
224
- try:
225
- # Initialize GA
226
- population = create_population(population_size, param_bounds)
227
- best_individual = None
228
- best_fitness = float('inf')
229
- fitness_history = []
230
-
231
- total_evaluations = num_generations * len(population)
232
- current_evaluation = 0
233
-
234
- for generation in range(num_generations):
235
- metrics_generation.markdown(f"""
236
- <div class="metric-container">
237
- <p class="status-text">Generation: {generation + 1}/{num_generations}</p>
238
- </div>
239
- """, unsafe_allow_html=True)
240
-
241
- fitnesses = []
242
- for idx, individual in enumerate(population):
243
- status_text.text(f"🧬 Evaluating individual {idx+1}/{len(population)} in generation {generation+1}")
244
-
245
- # Clone model for each individual
246
- model_clone = GPT2LMHeadModel.from_pretrained('gpt2')
247
- model_clone.to(device)
248
-
249
- fitness = fitness_function(individual, train_dataset, model_clone, tokenizer)
250
- fitnesses.append(fitness)
251
-
252
- if fitness < best_fitness:
253
- best_fitness = fitness
254
- best_individual = individual.copy()
255
-
256
- metrics_loss.markdown(f"""
257
- <div class="metric-container">
258
- <p class="status-text">Best Loss: {best_fitness:.4f}</p>
259
- </div>
260
- """, unsafe_allow_html=True)
261
-
262
- current_evaluation += 1
263
- progress_bar.progress(current_evaluation / total_evaluations)
264
-
265
- # Evolution steps
266
- parents = select_mating_pool(population, fitnesses, num_parents)
267
- offspring_size = population_size - num_parents
268
- offspring = crossover(parents, offspring_size)
269
- offspring = mutation(offspring, param_bounds, mutation_rate)
270
- population = parents + offspring
271
- fitness_history.append(min(fitnesses))
272
-
273
- # Training Complete
274
- st.success("πŸŽ‰ Training completed!")
275
- st.write("Best Hyperparameters:", best_individual)
276
- st.write("Best Fitness (Loss):", best_fitness)
277
-
278
- # Plot fitness history
279
- st.line_chart(fitness_history)
280
 
281
- # Save and Upload Model
282
- if st.button("πŸ’Ύ Save & Upload Model"):
283
- with st.spinner("Saving model..."):
284
- model.save_pretrained('./fine_tuned_model')
285
- tokenizer.save_pretrained('./fine_tuned_model')
286
-
287
- if hf_token:
288
- if upload_to_huggingface('./fine_tuned_model', hf_token, repo_name):
289
- st.success(f"βœ… Model uploaded to HuggingFace: {repo_name}")
290
- else:
291
- st.error("❌ Failed to upload model")
292
- else:
293
- st.warning("⚠️ No HuggingFace token provided. Model saved locally only.")
294
 
295
- except Exception as e:
296
- st.error(f"❌ Training error: {str(e)}")
297
- with open("error_log.txt", "a") as f:
298
- f.write(f"\nTraining error: {str(e)}\n{traceback.format_exc()}")
 
 
 
 
 
 
 
 
 
 
 
299
 
300
  if __name__ == "__main__":
301
  main()
 
9
  import os
10
  import traceback
11
  from contextlib import contextmanager
12
+ import plotly.graph_objects as go
13
+ import plotly.express as px
14
+ from datetime import datetime
15
+ import time
16
+ import json
17
+ import pandas as pd
18
 
19
+ # Advanced Cyberpunk Styling
20
+ def setup_advanced_cyberpunk_style():
 
 
 
 
 
 
 
 
 
 
 
21
  st.markdown("""
22
  <style>
23
  @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
24
+ @import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&display=swap');
25
 
26
  .stApp {
27
+ background: linear-gradient(
28
+ 45deg,
29
+ rgba(0, 0, 0, 0.9) 0%,
30
+ rgba(0, 30, 60, 0.9) 50%,
31
+ rgba(0, 0, 0, 0.9) 100%
32
+ );
33
+ color: #00ff9d;
34
  }
35
 
36
  .main-title {
37
  font-family: 'Orbitron', sans-serif;
38
+ background: linear-gradient(45deg, #00ff9d, #00b8ff);
39
+ -webkit-background-clip: text;
40
+ -webkit-text-fill-color: transparent;
41
  text-align: center;
42
+ font-size: 3.5em;
 
 
43
  margin-bottom: 30px;
44
+ text-transform: uppercase;
45
+ letter-spacing: 3px;
46
+ animation: glow 2s ease-in-out infinite alternate;
47
+ }
48
+
49
+ @keyframes glow {
50
+ from {
51
+ text-shadow: 0 0 5px #00ff9d, 0 0 10px #00ff9d, 0 0 15px #00ff9d;
52
+ }
53
+ to {
54
+ text-shadow: 0 0 10px #00b8ff, 0 0 20px #00b8ff, 0 0 30px #00b8ff;
55
+ }
56
+ }
57
+
58
+ .cyber-box {
59
+ background: rgba(0, 0, 0, 0.7);
60
+ border: 2px solid #00ff9d;
61
+ border-radius: 10px;
62
+ padding: 20px;
63
+ margin: 10px 0;
64
+ position: relative;
65
+ overflow: hidden;
66
+ }
67
+
68
+ .cyber-box::before {
69
+ content: '';
70
+ position: absolute;
71
+ top: -2px;
72
+ left: -2px;
73
+ right: -2px;
74
+ bottom: -2px;
75
+ background: linear-gradient(45deg, #00ff9d, #00b8ff);
76
+ z-index: -1;
77
+ filter: blur(10px);
78
+ opacity: 0.5;
79
+ }
80
+
81
+ .metric-container {
82
+ background: rgba(0, 0, 0, 0.8);
83
+ border: 2px solid #00ff9d;
84
+ border-radius: 10px;
85
+ padding: 20px;
86
+ margin: 10px 0;
87
+ position: relative;
88
+ overflow: hidden;
89
+ transition: all 0.3s ease;
90
+ }
91
+
92
+ .metric-container:hover {
93
+ transform: translateY(-5px);
94
+ box-shadow: 0 5px 15px rgba(0, 255, 157, 0.3);
95
+ }
96
+
97
+ .status-text {
98
+ font-family: 'Share Tech Mono', monospace;
99
+ color: #00ff9d;
100
+ font-size: 1.2em;
101
+ margin: 0;
102
+ text-shadow: 0 0 5px #00ff9d;
103
+ }
104
+
105
+ .sidebar .stSelectbox, .sidebar .stSlider {
106
+ background-color: rgba(0, 0, 0, 0.5);
107
+ border-radius: 5px;
108
+ padding: 15px;
109
+ margin: 10px 0;
110
+ border: 1px solid #00ff9d;
111
  }
112
 
113
  .stButton>button {
114
+ font-family: 'Orbitron', sans-serif;
115
  background: linear-gradient(45deg, #00ff9d, #00b8ff);
116
  color: black;
 
117
  border: none;
118
+ padding: 15px 30px;
119
  border-radius: 5px;
120
  text-transform: uppercase;
121
  font-weight: bold;
122
+ letter-spacing: 2px;
123
  transition: all 0.3s ease;
124
+ position: relative;
125
+ overflow: hidden;
126
  }
127
 
128
  .stButton>button:hover {
129
  transform: scale(1.05);
130
+ box-shadow: 0 0 20px rgba(0, 255, 157, 0.5);
131
  }
132
 
133
+ .stButton>button::after {
134
+ content: '';
135
+ position: absolute;
136
+ top: -50%;
137
+ left: -50%;
138
+ width: 200%;
139
+ height: 200%;
140
+ background: linear-gradient(
141
+ 45deg,
142
+ transparent,
143
+ rgba(255, 255, 255, 0.1),
144
+ transparent
145
+ );
146
+ transform: rotate(45deg);
147
+ animation: shine 3s infinite;
148
+ }
149
+
150
+ @keyframes shine {
151
+ 0% {
152
+ transform: translateX(-100%) rotate(45deg);
153
+ }
154
+ 100% {
155
+ transform: translateX(100%) rotate(45deg);
156
+ }
157
+ }
158
+
159
+ .custom-info-box {
160
+ background: rgba(0, 255, 157, 0.1);
161
+ border-left: 5px solid #00ff9d;
162
  padding: 15px;
163
  margin: 10px 0;
164
+ font-family: 'Share Tech Mono', monospace;
165
  }
166
 
167
+ .progress-bar-container {
168
+ width: 100%;
169
+ height: 30px;
170
+ background: rgba(0, 0, 0, 0.5);
171
+ border: 2px solid #00ff9d;
172
+ border-radius: 15px;
173
+ overflow: hidden;
174
+ position: relative;
175
  }
176
 
177
+ .progress-bar {
178
+ height: 100%;
179
+ background: linear-gradient(45deg, #00ff9d, #00b8ff);
180
+ transition: width 0.3s ease;
 
181
  }
182
  </style>
183
  """, unsafe_allow_html=True)
184
 
185
+ # Fixed prepare_dataset function
186
+ def prepare_dataset(data, tokenizer, block_size=128):
187
+ with error_handling("dataset preparation"):
188
+ def tokenize_function(examples):
189
+ return tokenizer(examples['text'], truncation=True, max_length=block_size, padding='max_length')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
+ raw_dataset = Dataset.from_dict({'text': data})
192
+ tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=['text'])
193
+ tokenized_dataset = tokenized_dataset.map(
194
+ lambda examples: {'labels': examples['input_ids']},
195
+ batched=True
 
 
 
196
  )
197
+ tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
198
+ return tokenized_dataset
199
 
200
+ # Advanced Metrics Visualization
201
+ def create_training_metrics_plot(fitness_history):
202
+ fig = go.Figure()
203
+ fig.add_trace(go.Scatter(
204
+ y=fitness_history,
205
+ mode='lines+markers',
206
+ name='Loss',
207
+ line=dict(color='#00ff9d', width=2),
208
+ marker=dict(size=8, symbol='diamond'),
209
+ ))
210
+
211
+ fig.update_layout(
212
+ title={
213
+ 'text': 'Training Progress',
214
+ 'y':0.95,
215
+ 'x':0.5,
216
+ 'xanchor': 'center',
217
+ 'yanchor': 'top',
218
+ 'font': {'family': 'Orbitron', 'size': 24, 'color': '#00ff9d'}
219
+ },
220
+ paper_bgcolor='rgba(0,0,0,0.5)',
221
+ plot_bgcolor='rgba(0,0,0,0.3)',
222
+ font=dict(family='Share Tech Mono', color='#00ff9d'),
223
+ xaxis=dict(
224
+ title='Generation',
225
+ gridcolor='rgba(0,255,157,0.1)',
226
+ zerolinecolor='#00ff9d'
227
+ ),
228
+ yaxis=dict(
229
+ title='Loss',
230
+ gridcolor='rgba(0,255,157,0.1)',
231
+ zerolinecolor='#00ff9d'
232
+ ),
233
+ hovermode='x unified'
234
+ )
235
+ return fig
236
 
237
+ # Advanced Training Dashboard
238
+ class TrainingDashboard:
239
+ def __init__(self):
240
+ self.metrics = {
241
+ 'current_loss': 0,
242
+ 'best_loss': float('inf'),
243
+ 'generation': 0,
244
+ 'individual': 0,
245
+ 'start_time': time.time(),
246
+ 'training_speed': 0
247
+ }
248
+ self.history = []
249
+
250
+ def update(self, loss, generation, individual):
251
+ self.metrics['current_loss'] = loss
252
+ self.metrics['generation'] = generation
253
+ self.metrics['individual'] = individual
254
+ if loss < self.metrics['best_loss']:
255
+ self.metrics['best_loss'] = loss
256
+
257
+ elapsed_time = time.time() - self.metrics['start_time']
258
+ self.metrics['training_speed'] = (generation * individual) / elapsed_time
259
+ self.history.append({
260
+ 'loss': loss,
261
+ 'timestamp': datetime.now().strftime('%H:%M:%S')
262
+ })
263
+
264
+ def display(self):
265
+ col1, col2, col3 = st.columns(3)
266
+
267
+ with col1:
268
+ st.markdown("""
269
+ <div class="metric-container">
270
+ <h3 style="color: #00ff9d;">Current Status</h3>
271
+ <p class="status-text">Generation: {}/{}</p>
272
+ <p class="status-text">Individual: {}/{}</p>
273
+ </div>
274
+ """.format(
275
+ self.metrics['generation'],
276
+ self.metrics['total_generations'],
277
+ self.metrics['individual'],
278
+ self.metrics['population_size']
279
+ ), unsafe_allow_html=True)
280
+
281
+ with col2:
282
+ st.markdown("""
283
+ <div class="metric-container">
284
+ <h3 style="color: #00ff9d;">Performance</h3>
285
+ <p class="status-text">Current Loss: {:.4f}</p>
286
+ <p class="status-text">Best Loss: {:.4f}</p>
287
+ </div>
288
+ """.format(
289
+ self.metrics['current_loss'],
290
+ self.metrics['best_loss']
291
+ ), unsafe_allow_html=True)
292
+
293
+ with col3:
294
+ st.markdown("""
295
+ <div class="metric-container">
296
+ <h3 style="color: #00ff9d;">Training Metrics</h3>
297
+ <p class="status-text">Speed: {:.2f} iter/s</p>
298
+ <p class="status-text">Runtime: {:.2f}m</p>
299
+ </div>
300
+ """.format(
301
+ self.metrics['training_speed'],
302
+ (time.time() - self.metrics['start_time']) / 60
303
+ ), unsafe_allow_html=True)
304
 
305
  def main():
306
+ setup_advanced_cyberpunk_style()
307
 
308
  st.markdown('<h1 class="main-title">Neural Evolution GPT-2 Training Hub</h1>', unsafe_allow_html=True)
309
+
310
+ # Initialize dashboard
311
+ dashboard = TrainingDashboard()
312
+
313
+ # Advanced Sidebar
314
  with st.sidebar:
315
+ st.markdown("""
316
+ <div style="text-align: center; padding: 20px;">
317
+ <h2 style="font-family: 'Orbitron'; color: #00ff9d;">Control Panel</h2>
318
+ </div>
319
+ """, unsafe_allow_html=True)
320
 
321
+ # Configuration Tabs
322
+ tab1, tab2, tab3 = st.tabs(["πŸ”§ Setup", "βš™οΈ Parameters", "πŸ“Š Monitoring"])
323
 
324
+ with tab1:
325
+ hf_token = st.text_input("πŸ”‘ HuggingFace Token", type="password")
326
+ repo_name = st.text_input("πŸ“ Repository Name", "my-gpt2-model")
327
+ data_source = st.selectbox('πŸ“Š Data Source', ('DEMO', 'Upload Text File'))
328
 
329
+ with tab2:
330
+ population_size = st.slider("Population Size", 4, 20, 6)
331
+ num_generations = st.slider("Generations", 1, 10, 3)
332
+ num_parents = st.slider("Parents", 2, population_size, 2)
333
+ mutation_rate = st.slider("Mutation Rate", 0.0, 1.0, 0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
+ # Advanced Parameters
336
+ with st.expander("πŸ”¬ Advanced Settings"):
337
+ learning_rate_min = st.number_input("Min Learning Rate", 1e-6, 1e-4, 1e-5)
338
+ learning_rate_max = st.number_input("Max Learning Rate", 1e-5, 1e-3, 5e-5)
339
+ batch_size_options = st.multiselect("Batch Sizes", [2, 4, 8, 16], default=[2, 4, 8])
340
+
341
+ with tab3:
342
+ st.markdown("""
343
+ <div class="cyber-box">
344
+ <h3 style="color: #00ff9d;">System Status</h3>
345
+ <p>GPU: {}</p>
346
+ <p>Memory Usage: {:.2f}GB</p>
347
+ </div>
348
+ """.format(
349
+ 'CUDA' if torch.cuda.is_available() else 'CPU',
350
+ torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
351
+ ), unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
+ # [Rest of your existing main() function code here, integrated with the dashboard]
354
+ # Make sure to update the dashboard metrics during training
 
 
 
 
 
 
 
 
 
 
 
355
 
356
+ # Example of updating dashboard during training:
357
+ for generation in range(num_generations):
358
+ for idx, individual in enumerate(population):
359
+ # Your existing training code
360
+ fitness = fitness_function(individual, train_dataset, model_clone, tokenizer)
361
+ dashboard.update(fitness, generation + 1, idx + 1)
362
+ dashboard.display()
363
+
364
+ # Update progress
365
+ progress = (generation * len(population) + idx + 1) / (num_generations * len(population))
366
+ st.markdown(f"""
367
+ <div class="progress-bar-container">
368
+ <div class="progress-bar" style="width: {progress * 100}%"></div>
369
+ </div>
370
+ """, unsafe_allow_html=True)
371
 
372
  if __name__ == "__main__":
373
  main()