Update app.py
app.py CHANGED
@@ -8,18 +8,24 @@ from huggingface_hub import HfApi
 import plotly.graph_objects as go
 import time
 from datetime import datetime
+import threading
 
 # Cyberpunk and Loading Animation Styling
 def setup_cyberpunk_style():
     st.markdown("""
         <style>
-
-
-
+        body, button, input, select, textarea {
+            font-family: 'Orbitron', sans-serif !important;
+            color: #00ff9d !important;
+        }
         .stApp {
             background: radial-gradient(circle, rgba(0, 0, 0, 0.95) 20%, rgba(0, 50, 80, 0.95) 90%);
             color: #00ff9d;
             font-family: 'Orbitron', sans-serif;
+            font-size: 16px;
+            line-height: 1.6;
+            padding: 20px;
+            box-sizing: border-box;
         }
 
         .main-title {
@@ -145,7 +151,9 @@ def initialize_model(model_name="gpt2"):
 # Load Dataset Function with Uploaded File Option
 def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None):
     if data_source == "demo":
-        data = ["
+        data = ["In the neon-lit streets of Neo-Tokyo, a lone hacker fights against the oppressive megacorporations.",
+                "The rain falls in sheets, washing away the bloodstains from the alleyways.",
+                "She plugs into the matrix, seeking answers to questions that have haunted her for years."]
     elif uploaded_file is not None:
         if uploaded_file.name.endswith(".txt"):
             data = [uploaded_file.read().decode("utf-8")]
@@ -160,7 +168,7 @@ def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None):
     return dataset
 
 # Train Model Function with Customized Progress Bar
-def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4):
+def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4, progress_callback=None):
     training_args = TrainingArguments(
         output_dir="./results",
         overwrite_output_dir=True,
@@ -179,14 +187,26 @@ def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4):
         args=training_args,
         data_collator=data_collator,
         train_dataset=train_dataset,
+        callbacks=[ProgressCallback(progress_callback)]
     )
 
     trainer.train()
 
+class ProgressCallback(TrainerCallback):
+    def __init__(self, progress_callback):
+        super().__init__()
+        self.progress_callback = progress_callback
+
+    def on_epoch_end(self, args, state, control, **kwargs):
+        loss = state.log_history[-1]['loss']
+        generation = state.global_step // args.gradient_accumulation_steps + 1
+        individual = args.gradient_accumulation_steps
+        self.progress_callback(loss, generation, individual)
+
 # Main App Logic
 def main():
     setup_cyberpunk_style()
-    st.markdown('<h1 class="main-title">
+    st.markdown('<h1 class="main-title">Neural Training Hub</h1>', unsafe_allow_html=True)
 
     # Initialize model and tokenizer
     model, tokenizer = initialize_model()
@@ -225,6 +245,15 @@ def main():
     # Load Dataset
     train_dataset = load_dataset(data_source, tokenizer, uploaded_file=uploaded_file)
 
+    # Chatbot Interaction
+    if st.checkbox("Enable Chatbot"):
+        user_input = st.text_input("You:", placeholder="Type your message here...")
+        if user_input:
+            inputs = tokenizer(user_input, return_tensors="pt")
+            outputs = model.generate(inputs['input_ids'], max_length=100, num_return_sequences=1)
+            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            st.write("Bot:", response)
+
     # Go Button to Start Training
     if st.button("Go"):
         progress_placeholder = st.empty()
@@ -233,22 +262,21 @@
 
         dashboard = TrainingDashboard()
 
-
-
-            <div class="loading-animation"></div>
-        """, unsafe_allow_html=True)
-
-        train_model(model, train_dataset, tokenizer, epochs=1, batch_size=batch_size)
-
-            # Update Progress Bar
-            progress = (epoch + 1) / training_epochs * 100
+        def train_progress(loss, generation, individual):
+            progress = (generation + 1) / dashboard.metrics['training_epochs'] * 100
             progress_placeholder.markdown(f"""
                 <div class="progress-bar-container">
                     <div class="progress-bar" style="width: {progress}%;"></div>
                 </div>
            """, unsafe_allow_html=True)
+            dashboard.update(loss=loss, generation=generation, individual=individual)
+
+        thread = threading.Thread(target=train_model, args=(model, train_dataset, tokenizer, training_epochs, batch_size, train_progress))
+        thread.start()
+        loading_animation.markdown("""
+            <div class="loading-animation"></div>
+        """, unsafe_allow_html=True)
+        thread.join()
 
         loading_animation.empty()
         st.success("Training Complete!")
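
For reference, the new ProgressCallback above uses the transformers Trainer callback API. The sketch below is a minimal, self-contained version of that pattern; the class and function names are illustrative rather than taken from the Space, it assumes `TrainerCallback` is imported from transformers (which the change relies on), and it guards the loss lookup because not every `state.log_history` entry contains a 'loss' key.

# Minimal sketch of the Trainer callback pattern (illustrative names, not the Space's code).
from transformers import TrainerCallback

class EpochProgressCallback(TrainerCallback):
    """Report the most recently logged loss at the end of each epoch."""

    def __init__(self, report_fn):
        self.report_fn = report_fn  # any callable, e.g. one that updates a progress bar

    def on_epoch_end(self, args, state, control, **kwargs):
        # state.log_history holds the logged dicts; skip entries without a 'loss'
        # key (e.g. eval or learning-rate logs) to avoid a KeyError.
        losses = [entry["loss"] for entry in state.log_history if "loss" in entry]
        latest_loss = losses[-1] if losses else float("nan")
        self.report_fn(epoch=int(state.epoch or 0), loss=latest_loss)

# Usage: pass an instance when constructing the Trainer, e.g.
#   trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
#                     callbacks=[EpochProgressCallback(my_report_fn)])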
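Because Streamlit widgets generally have to be driven from the script's main thread, one common way to surface progress from a background training thread is to hand metrics over through a queue and poll it while the worker runs. The sketch below shows that pattern under those assumptions; it is illustrative only, with `run_training` standing in for the real `train_model`, and is not the Space's implementation.

# Illustrative sketch: report progress from a worker thread via a queue and
# render it from the main Streamlit script. run_training stands in for train_model.
import queue
import threading
import time

import streamlit as st

def run_training(progress_q, epochs=3):
    for epoch in range(epochs):
        time.sleep(1)  # placeholder for one real training epoch
        progress_q.put({"epoch": epoch + 1, "loss": 1.0 / (epoch + 1)})

progress_q = queue.Queue()
worker = threading.Thread(target=run_training, args=(progress_q,), daemon=True)
worker.start()

placeholder = st.empty()
while worker.is_alive() or not progress_q.empty():
    try:
        update = progress_q.get(timeout=0.5)
        placeholder.write(f"Epoch {update['epoch']}: loss {update['loss']:.3f}")
    except queue.Empty:
        pass

st.success("Training Complete!")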