diabolic6045 committed on
Commit
9aef0f6
1 Parent(s): b1d82e7

Update app.py

Browse files

- Added graphs for training loss
- Added a GA (genetic algorithm) approach; the user can now select between standard and GA training
- Added parameters for GA
- Added an inference box that is shown after training is completed

Files changed (1) hide show
  1. app.py +230 -190
app.py CHANGED
@@ -3,109 +3,21 @@ import numpy as np
3
  import torch
4
  import random
5
  from transformers import (
6
- GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling,
7
- TrainerCallback # Import TrainerCallback here
8
  )
9
  from datasets import Dataset
10
  from huggingface_hub import HfApi
11
  import plotly.graph_objects as go
12
  import time
13
  from datetime import datetime
14
- import threading
15
-
16
 
17
  # Cyberpunk and Loading Animation Styling
18
  def setup_cyberpunk_style():
19
  st.markdown("""
20
  <style>
21
- body, button, input, select, textarea {
22
- font-family: 'Orbitron', sans-serif !important;
23
- color: #00ff9d !important;
24
- }
25
- .stApp {
26
- background: radial-gradient(circle, rgba(0, 0, 0, 0.95) 20%, rgba(0, 50, 80, 0.95) 90%);
27
- color: #00ff9d;
28
- font-family: 'Orbitron', sans-serif;
29
- font-size: 16px;
30
- line-height: 1.6;
31
- padding: 20px;
32
- box-sizing: border-box;
33
- }
34
-
35
- .main-title {
36
- text-align: center;
37
- font-size: 4em;
38
- color: #00ff9d;
39
- letter-spacing: 4px;
40
- animation: glow 2s ease-in-out infinite alternate;
41
- }
42
-
43
- @keyframes glow {
44
- from {text-shadow: 0 0 5px #00ff9d, 0 0 10px #00ff9d;}
45
- to {text-shadow: 0 0 15px #00b8ff, 0 0 20px #00b8ff;}
46
- }
47
- .stButton > button {
48
- font-family: 'Orbitron', sans-serif;
49
- background: linear-gradient(45deg, #00ff9d, #00b8ff);
50
- color: #000;
51
- font-size: 1.1em;
52
- padding: 10px 20px;
53
- border: none;
54
- border-radius: 8px;
55
- transition: all 0.3s ease;
56
- }
57
-
58
- .stButton > button:hover {
59
- transform: scale(1.1);
60
- box-shadow: 0 0 20px rgba(0, 255, 157, 0.5);
61
- }
62
- .progress-bar-container {
63
- background: rgba(0, 0, 0, 0.5);
64
- border-radius: 15px;
65
- overflow: hidden;
66
- width: 100%;
67
- height: 30px;
68
- position: relative;
69
- margin: 10px 0;
70
- }
71
-
72
- .progress-bar {
73
- height: 100%;
74
- width: 0%;
75
- background: linear-gradient(45deg, #00ff9d, #00b8ff);
76
- transition: width 0.5s ease;
77
- }
78
-
79
- .go-button {
80
- font-family: 'Orbitron', sans-serif;
81
- background: linear-gradient(45deg, #00ff9d, #00b8ff);
82
- color: #000;
83
- font-size: 1.1em;
84
- padding: 10px 20px;
85
- border: none;
86
- border-radius: 8px;
87
- transition: all 0.3s ease;
88
- cursor: pointer;
89
- }
90
-
91
- .go-button:hover {
92
- transform: scale(1.1);
93
- box-shadow: 0 0 20px rgba(0, 255, 157, 0.5);
94
- }
95
-
96
- .loading-animation {
97
- display: inline-block;
98
- width: 20px;
99
- height: 20px;
100
- border: 3px solid #00ff9d;
101
- border-radius: 50%;
102
- border-top-color: transparent;
103
- animation: spin 1s ease-in-out infinite;
104
- }
105
-
106
- @keyframes spin {
107
- to {transform: rotate(360deg);}
108
- }
109
  </style>
110
  """, unsafe_allow_html=True)
111
 
@@ -121,30 +33,6 @@ def prepare_dataset(data, tokenizer, block_size=128):
121
  tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
122
  return tokenized_dataset
123
 
124
- # Training Dashboard Class with Enhanced Display
125
- class TrainingDashboard:
126
- def __init__(self):
127
- self.metrics = {
128
- 'current_loss': 0,
129
- 'best_loss': float('inf'),
130
- 'generation': 0,
131
- 'individual': 0,
132
- 'start_time': time.time(),
133
- 'training_speed': 0
134
- }
135
- self.history = []
136
-
137
- def update(self, loss, generation, individual):
138
- self.metrics['current_loss'] = loss
139
- self.metrics['generation'] = generation
140
- self.metrics['individual'] = individual
141
- if loss < self.metrics['best_loss']:
142
- self.metrics['best_loss'] = loss
143
-
144
- elapsed_time = time.time() - self.metrics['start_time']
145
- self.metrics['training_speed'] = (generation * individual) / elapsed_time
146
- self.history.append({'loss': loss, 'timestamp': datetime.now().strftime('%H:%M:%S')})
147
-
148
  # Define Model Initialization
149
  def initialize_model(model_name="gpt2"):
150
  model = GPT2LMHeadModel.from_pretrained(model_name)
@@ -155,66 +43,164 @@ def initialize_model(model_name="gpt2"):
155
  # Load Dataset Function with Uploaded File Option
156
  def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None):
157
  if data_source == "demo":
158
- data = ["In the neon-lit streets of Neo-Tokyo, a lone hacker fights against the oppressive megacorporations.",
159
- "The rain falls in sheets, washing away the bloodstains from the alleyways.",
160
- "She plugs into the matrix, seeking answers to questions that have haunted her for years."]
161
- elif uploaded_file is not None:
 
 
162
  if uploaded_file.name.endswith(".txt"):
163
  data = [uploaded_file.read().decode("utf-8")]
164
  elif uploaded_file.name.endswith(".csv"):
165
- import pandas as pd
166
  df = pd.read_csv(uploaded_file)
167
- data = df[df.columns[0]].tolist() # assuming first column is text data
 
 
168
  else:
169
  data = ["No file uploaded. Please upload a dataset."]
170
 
171
  dataset = prepare_dataset(data, tokenizer)
172
  return dataset
173
 
174
- # Train Model Function with Customized Progress Bar
175
- def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4, progress_callback=None):
176
- training_args = TrainingArguments(
177
- output_dir="./results",
178
- overwrite_output_dir=True,
179
- num_train_epochs=epochs,
180
- per_device_train_batch_size=batch_size,
181
- save_steps=10_000,
182
- save_total_limit=2,
183
- logging_dir="./logs",
184
- logging_steps=100,
185
- )
186
-
187
- data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
- trainer = Trainer(
190
- model=model,
191
- args=training_args,
192
- data_collator=data_collator,
193
- train_dataset=train_dataset,
194
- callbacks=[ProgressCallback(progress_callback)]
195
- )
196
 
197
- trainer.train()
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- class ProgressCallback(TrainerCallback):
200
- def __init__(self, progress_callback):
201
- super().__init__()
202
- self.progress_callback = progress_callback
 
203
 
204
- def on_epoch_end(self, args, state, control, **kwargs):
205
- loss = state.log_history[-1]['loss']
206
- generation = state.global_step // args.gradient_accumulation_steps + 1
207
- individual = args.gradient_accumulation_steps
208
- self.progress_callback(loss, generation, individual)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  # Main App Logic
211
  def main():
212
  setup_cyberpunk_style()
213
  st.markdown('<h1 class="main-title">Neural Training Hub</h1>', unsafe_allow_html=True)
214
 
215
- # Initialize model and tokenizer
216
- model, tokenizer = initialize_model()
217
-
218
  # Sidebar Configuration with Additional Options
219
  with st.sidebar:
220
  st.markdown("### Configuration Panel")
@@ -225,7 +211,7 @@ def main():
225
  api = HfApi()
226
  api.set_access_token(hf_token)
227
  st.success("Hugging Face token added successfully!")
228
-
229
  # Training Parameters
230
  training_epochs = st.slider("Training Epochs", min_value=1, max_value=5, value=3)
231
  batch_size = st.slider("Batch Size", min_value=2, max_value=8, value=4)
@@ -235,8 +221,8 @@ def main():
235
  data_source = st.selectbox("Data Source", ("demo", "uploaded file"))
236
  uploaded_file = st.file_uploader("Upload a text file", type=["txt", "csv"]) if data_source == "uploaded file" else None
237
 
238
- custom_learning_rate = st.slider("Learning Rate", min_value=1e-6, max_value=5e-4, value=3e-5, step=1e-6)
239
-
240
  # Advanced Settings Toggle
241
  advanced_toggle = st.checkbox("Advanced Training Settings")
242
  if advanced_toggle:
@@ -245,47 +231,101 @@ def main():
245
  else:
246
  warmup_steps = 100
247
  weight_decay = 0.01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
  # Load Dataset
250
  train_dataset = load_dataset(data_source, tokenizer, uploaded_file=uploaded_file)
251
 
252
- # Chatbot Interaction
253
- if st.checkbox("Enable Chatbot"):
254
- user_input = st.text_input("You:", placeholder="Type your message here...")
255
- if user_input:
256
- inputs = tokenizer(user_input, return_tensors="pt")
257
- outputs = model.generate(inputs['input_ids'], max_length=100, num_return_sequences=1)
258
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
259
- st.write("Bot:", response)
260
-
261
  # Go Button to Start Training
262
  if st.button("Go"):
263
- progress_placeholder = st.empty()
264
- loading_animation = st.empty()
265
  st.markdown("### Model Training Progress")
 
 
 
 
 
 
 
 
 
266
 
267
- dashboard = TrainingDashboard()
 
 
268
 
269
- def train_progress(loss, generation, individual):
270
- progress = (generation + 1) / dashboard.metrics['training_epochs'] * 100
271
- progress_placeholder.markdown(f"""
272
- <div class="progress-bar-container">
273
- <div class="progress-bar" style="width: {progress}%;"></div>
274
- </div>
275
- """, unsafe_allow_html=True)
276
- dashboard.update(loss=loss, generation=generation, individual=individual)
277
 
278
- thread = threading.Thread(target=train_model, args=(model, train_dataset, tokenizer, training_epochs, batch_size, train_progress))
279
- thread.start()
280
- loading_animation.markdown("""
281
- <div class="loading-animation"></div>
282
- """, unsafe_allow_html=True)
283
- thread.join()
284
-
285
- loading_animation.empty()
286
- st.success("Training Complete!")
287
- st.write("Training Metrics:")
288
- st.write(dashboard.metrics)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
  if __name__ == "__main__":
291
  main()
 
3
  import torch
4
  import random
5
  from transformers import (
6
+ GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
 
7
  )
8
  from datasets import Dataset
9
  from huggingface_hub import HfApi
10
  import plotly.graph_objects as go
11
  import time
12
  from datetime import datetime
13
+ from typing import Dict, List, Any
14
+ import pandas as pd # Added pandas import
15
 
16
  # Cyberpunk and Loading Animation Styling
17
  def setup_cyberpunk_style():
18
  st.markdown("""
19
  <style>
20
+ /* [Your existing CSS styles here] */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  </style>
22
  """, unsafe_allow_html=True)
23
 
 
33
  tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
34
  return tokenized_dataset
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # Define Model Initialization
37
  def initialize_model(model_name="gpt2"):
38
  model = GPT2LMHeadModel.from_pretrained(model_name)
 
43
  # Load Dataset Function with Uploaded File Option
44
  def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None):
45
  if data_source == "demo":
46
+ data = [
47
+ "In the neon-lit streets of Neo-Tokyo, a lone hacker fights against the oppressive megacorporations.",
48
+ "The rain falls in sheets, washing away the bloodstains from the alleyways.",
49
+ "She plugs into the matrix, seeking answers to questions that have haunted her for years."
50
+ ]
51
+ elif data_source == "uploaded file" and uploaded_file is not None:
52
  if uploaded_file.name.endswith(".txt"):
53
  data = [uploaded_file.read().decode("utf-8")]
54
  elif uploaded_file.name.endswith(".csv"):
 
55
  df = pd.read_csv(uploaded_file)
56
+ data = df[df.columns[0]].astype(str).tolist() # Ensure all data is string
57
+ else:
58
+ data = ["Unsupported file format."]
59
  else:
60
  data = ["No file uploaded. Please upload a dataset."]
61
 
62
  dataset = prepare_dataset(data, tokenizer)
63
  return dataset
64
 
65
+ # Train Model Function
66
+ def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4, use_ga=False, ga_params=None):
67
+ if not use_ga:
68
+ training_args = TrainingArguments(
69
+ output_dir="./results",
70
+ overwrite_output_dir=True,
71
+ num_train_epochs=epochs,
72
+ per_device_train_batch_size=batch_size,
73
+ save_steps=10_000,
74
+ save_total_limit=2,
75
+ logging_dir="./logs",
76
+ logging_steps=1,
77
+ logging_strategy='steps',
78
+ report_to=None, # Disable default logging to WandB or other services
79
+ )
80
+
81
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
82
+
83
+ trainer = Trainer(
84
+ model=model,
85
+ args=training_args,
86
+ data_collator=data_collator,
87
+ train_dataset=train_dataset,
88
+ )
89
+ trainer.train()
90
+ return trainer.state.log_history
91
+ else:
92
+ # GA training logic
93
+ param_bounds = {
94
+ 'learning_rate': (1e-5, 5e-5),
95
+ 'epochs': (1, ga_params['max_epochs']),
96
+ 'batch_size': [2, 4, 8, 16]
97
+ }
98
 
99
+ population = create_ga_population(ga_params['population_size'], param_bounds)
100
+ best_individual = None
101
+ best_fitness = float('inf')
102
+ all_losses = []
 
 
 
103
 
104
+ for generation in range(ga_params['num_generations']):
105
+ fitnesses = []
106
+ for idx, individual in enumerate(population):
107
+ model_copy = GPT2LMHeadModel.from_pretrained('gpt2')
108
+ training_args = TrainingArguments(
109
+ output_dir=f"./results/ga_{generation}_{idx}",
110
+ num_train_epochs=individual['epochs'],
111
+ per_device_train_batch_size=individual['batch_size'],
112
+ learning_rate=individual['learning_rate'],
113
+ logging_steps=1,
114
+ logging_strategy='steps',
115
+ report_to=None, # Disable default logging to WandB or other services
116
+ )
117
 
118
+ trainer = Trainer(
119
+ model=model_copy,
120
+ args=training_args,
121
+ train_dataset=train_dataset,
122
+ )
123
 
124
+ # Capture the training result
125
+ train_result = trainer.train()
126
+
127
+ # Safely retrieve the training loss
128
+ fitness = train_result.metrics.get('train_loss', None)
129
+ if fitness is None:
130
+ # If 'train_loss' is not available, try to compute it from log history
131
+ if 'loss' in trainer.state.log_history[-1]:
132
+ fitness = trainer.state.log_history[-1]['loss']
133
+ else:
134
+ fitness = float('inf') # Assign a large number if loss is not available
135
+
136
+ fitnesses.append(fitness)
137
+ all_losses.extend(trainer.state.log_history)
138
+
139
+ if fitness < best_fitness:
140
+ best_fitness = fitness
141
+ best_individual = individual
142
+ model.load_state_dict(model_copy.state_dict())
143
+
144
+ del model_copy
145
+ torch.cuda.empty_cache()
146
+
147
+ # GA operations
148
+ parents = select_ga_parents(population, fitnesses, ga_params['num_parents'])
149
+ offspring_size = ga_params['population_size'] - ga_params['num_parents']
150
+ offspring = ga_crossover(parents, offspring_size)
151
+ offspring = ga_mutation(offspring, param_bounds, ga_params['mutation_rate'])
152
+ population = parents + offspring
153
+
154
+ return all_losses
155
+
156
+ # GA-related functions
157
+ def create_ga_population(size: int, param_bounds: Dict[str, Any]) -> List[Dict[str, Any]]:
158
+ """Create initial population for genetic algorithm"""
159
+ population = []
160
+ for _ in range(size):
161
+ individual = {
162
+ 'learning_rate': random.uniform(*param_bounds['learning_rate']),
163
+ 'epochs': random.randint(*param_bounds['epochs']),
164
+ 'batch_size': random.choice(param_bounds['batch_size']),
165
+ }
166
+ population.append(individual)
167
+ return population
168
+
169
+ def select_ga_parents(population: List[Dict[str, Any]], fitnesses: List[float], num_parents: int) -> List[Dict[str, Any]]:
170
+ """Select best performing individuals as parents"""
171
+ parents = [population[i] for i in np.argsort(fitnesses)[:num_parents]]
172
+ return parents
173
+
174
+ def ga_crossover(parents: List[Dict[str, Any]], offspring_size: int) -> List[Dict[str, Any]]:
175
+ """Create offspring through crossover of parents"""
176
+ offspring = []
177
+ for _ in range(offspring_size):
178
+ parent1 = random.choice(parents)
179
+ parent2 = random.choice(parents)
180
+ child = {
181
+ 'learning_rate': random.choice([parent1['learning_rate'], parent2['learning_rate']]),
182
+ 'epochs': random.choice([parent1['epochs'], parent2['epochs']]),
183
+ 'batch_size': random.choice([parent1['batch_size'], parent2['batch_size']]),
184
+ }
185
+ offspring.append(child)
186
+ return offspring
187
+
188
+ def ga_mutation(offspring: List[Dict[str, Any]], param_bounds: Dict[str, Any], mutation_rate: float = 0.1) -> List[Dict[str, Any]]:
189
+ """Apply random mutations to offspring"""
190
+ for individual in offspring:
191
+ if random.random() < mutation_rate:
192
+ individual['learning_rate'] = random.uniform(*param_bounds['learning_rate'])
193
+ if random.random() < mutation_rate:
194
+ individual['epochs'] = random.randint(*param_bounds['epochs'])
195
+ if random.random() < mutation_rate:
196
+ individual['batch_size'] = random.choice(param_bounds['batch_size'])
197
+ return offspring
198
 
199
  # Main App Logic
200
  def main():
201
  setup_cyberpunk_style()
202
  st.markdown('<h1 class="main-title">Neural Training Hub</h1>', unsafe_allow_html=True)
203
 
 
 
 
204
  # Sidebar Configuration with Additional Options
205
  with st.sidebar:
206
  st.markdown("### Configuration Panel")
 
211
  api = HfApi()
212
  api.set_access_token(hf_token)
213
  st.success("Hugging Face token added successfully!")
214
+
215
  # Training Parameters
216
  training_epochs = st.slider("Training Epochs", min_value=1, max_value=5, value=3)
217
  batch_size = st.slider("Batch Size", min_value=2, max_value=8, value=4)
 
221
  data_source = st.selectbox("Data Source", ("demo", "uploaded file"))
222
  uploaded_file = st.file_uploader("Upload a text file", type=["txt", "csv"]) if data_source == "uploaded file" else None
223
 
224
+ custom_learning_rate = st.number_input("Learning Rate", min_value=1e-6, max_value=5e-4, value=3e-5, step=1e-6, format="%.6f")
225
+
226
  # Advanced Settings Toggle
227
  advanced_toggle = st.checkbox("Advanced Training Settings")
228
  if advanced_toggle:
 
231
  else:
232
  warmup_steps = 100
233
  weight_decay = 0.01
234
+
235
+ # Add training method selection
236
+ training_method = st.selectbox("Training Method", ("Standard", "Genetic Algorithm"))
237
+
238
+ if training_method == "Genetic Algorithm":
239
+ st.markdown("### GA Parameters")
240
+ ga_params = {
241
+ 'population_size': st.slider("Population Size", min_value=4, max_value=10, value=6),
242
+ 'num_generations': st.slider("Number of Generations", min_value=1, max_value=5, value=3),
243
+ 'num_parents': st.slider("Number of Parents", min_value=2, max_value=4, value=2),
244
+ 'mutation_rate': st.slider("Mutation Rate", min_value=0.0, max_value=1.0, value=0.1),
245
+ 'max_epochs': training_epochs
246
+ }
247
+ else:
248
+ ga_params = None
249
+
250
+ # Initialize model and tokenizer
251
+ if 'model' not in st.session_state:
252
+ model, tokenizer = initialize_model(model_name=model_choice)
253
+ st.session_state['model'] = model
254
+ st.session_state['tokenizer'] = tokenizer
255
+ st.session_state['model_name'] = model_choice
256
+ else:
257
+ if st.session_state.get('model_name') != model_choice:
258
+ model, tokenizer = initialize_model(model_name=model_choice)
259
+ st.session_state['model'] = model
260
+ st.session_state['tokenizer'] = tokenizer
261
+ st.session_state['model_name'] = model_choice
262
+ else:
263
+ model = st.session_state['model']
264
+ tokenizer = st.session_state['tokenizer']
265
 
266
  # Load Dataset
267
  train_dataset = load_dataset(data_source, tokenizer, uploaded_file=uploaded_file)
268
 
 
 
 
 
 
 
 
 
 
269
  # Go Button to Start Training
270
  if st.button("Go"):
 
 
271
  st.markdown("### Model Training Progress")
272
+ progress_bar = st.progress(0)
273
+ status_text = st.empty()
274
+ status_text.text("Training in progress...")
275
+
276
+ # Train the model
277
+ if training_method == "Standard":
278
+ logs = train_model(model, train_dataset, tokenizer, training_epochs, batch_size)
279
+ else:
280
+ logs = train_model(model, train_dataset, tokenizer, training_epochs, batch_size, use_ga=True, ga_params=ga_params)
281
 
282
+ # Update progress bar to 100%
283
+ progress_bar.progress(100)
284
+ status_text.text("Training complete!")
285
 
286
+ # Store the model and logs in st.session_state
287
+ st.session_state['model'] = model
288
+ st.session_state['logs'] = logs
 
 
 
 
 
289
 
290
+ # Plot the losses if available
291
+ if 'logs' in st.session_state:
292
+ logs = st.session_state['logs']
293
+ losses = [log['loss'] for log in logs if 'loss' in log]
294
+ steps = list(range(len(losses)))
295
+ if losses:
296
+ # Plot the losses
297
+ fig = go.Figure()
298
+ fig.add_trace(go.Scatter(x=steps, y=losses, mode='lines+markers', name='Training Loss', line=dict(color='#00ff9d')))
299
+ fig.update_layout(
300
+ title="Training Progress",
301
+ xaxis_title="Training Steps",
302
+ yaxis_title="Loss",
303
+ template="plotly_dark",
304
+ plot_bgcolor='rgba(0,0,0,0)',
305
+ paper_bgcolor='rgba(0,0,0,0)',
306
+ font=dict(color='#00ff9d')
307
+ )
308
+ st.plotly_chart(fig, use_container_width=True)
309
+ else:
310
+ st.write("No loss data available to plot.")
311
+ else:
312
+ st.write("Train the model to see the loss plot.")
313
+
314
+ # After training, you can use the model for inference
315
+ st.markdown("### Model Inference")
316
+ with st.form("inference_form"):
317
+ user_input = st.text_input("Enter prompt for the model:")
318
+ submitted = st.form_submit_button("Generate")
319
+ if submitted:
320
+ if 'model' in st.session_state:
321
+ model = st.session_state['model']
322
+ tokenizer = st.session_state['tokenizer']
323
+ inputs = tokenizer(user_input, return_tensors="pt")
324
+ outputs = model.generate(inputs['input_ids'], max_length=100, num_return_sequences=1)
325
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
326
+ st.write("Model output:", response)
327
+ else:
328
+ st.write("Please train the model first.")
329
 
330
  if __name__ == "__main__":
331
  main()