Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -79,6 +79,7 @@ usage_stats = {
|
|
79 |
"total_tokens_generated": 0,
|
80 |
"start_time": time.time()
|
81 |
}
|
|
|
82 |
@spaces.GPU
|
83 |
def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150, num_beams=8, repetition_penalty=1.5, progress=gr.Progress()):
|
84 |
if not prompt.strip():
|
@@ -97,24 +98,22 @@ def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150,
|
|
97 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
98 |
progress(0.1, desc="تحليل النص (Tokenizing)")
|
99 |
|
100 |
-
# Generate text
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
eos_token_id=tokenizer.eos_token_id,
|
117 |
-
)
|
118 |
|
119 |
# Decode output
|
120 |
progress(0.9, desc="معالجة النتائج (Processing results)")
|
@@ -192,7 +191,7 @@ def get_stats():
|
|
192 |
def reset_params():
|
193 |
"""Reset parameters to default values"""
|
194 |
logger.info("Parameters reset to defaults")
|
195 |
-
return
|
196 |
|
197 |
def thumbs_up_callback(input_text, output_text):
|
198 |
"""Record positive feedback"""
|
@@ -274,15 +273,15 @@ if __name__ == "__main__":
|
|
274 |
with gr.Accordion("معلمات التوليد (Generation Parameters)", open=False):
|
275 |
with gr.Row():
|
276 |
with gr.Column():
|
277 |
-
max_length = gr.Slider(8, 4096, value=
|
278 |
temperature = gr.Slider(0.0, 2, value=0.7, label="Temperature (درجة الحرارة)")
|
279 |
top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p (أعلى احتمال)")
|
280 |
|
281 |
with gr.Column():
|
282 |
-
top_k = gr.Slider(1, 10000, value=
|
283 |
-
num_beams = gr.Slider(1, 20, value=
|
284 |
-
repetition_penalty = gr.Slider(0.0, 100.0, value=1.
|
285 |
-
|
286 |
with gr.Column(scale=6):
|
287 |
output_text = gr.Textbox(label="النص المولد (Generated Text)", lines=10)
|
288 |
generation_info = gr.Markdown("")
|
|
|
79 |
"total_tokens_generated": 0,
|
80 |
"start_time": time.time()
|
81 |
}
|
82 |
+
|
83 |
@spaces.GPU
|
84 |
def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150, num_beams=8, repetition_penalty=1.5, progress=gr.Progress()):
|
85 |
if not prompt.strip():
|
|
|
98 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
99 |
progress(0.1, desc="تحليل النص (Tokenizing)")
|
100 |
|
101 |
+
# Generate text with optimized parameters for speed
|
102 |
+
progress(0.2, desc="توليد النص (Generating text)")
|
103 |
+
output = model.generate(
|
104 |
+
**inputs,
|
105 |
+
max_length=max_length,
|
106 |
+
temperature=temperature,
|
107 |
+
top_p=top_p,
|
108 |
+
do_sample=True,
|
109 |
+
repetition_penalty=repetition_penalty,
|
110 |
+
num_beams=1 if num_beams > 4 else num_beams, # Requests for more than 4 beams fall back to a single beam for speed; with do_sample=True this is sampling, not greedy decoding
|
111 |
+
top_k=top_k,
|
112 |
+
early_stopping=True,
|
113 |
+
pad_token_id=tokenizer.pad_token_id,
|
114 |
+
eos_token_id=tokenizer.eos_token_id,
|
115 |
+
use_cache=True, # Reuse cached past key/value states between decoding steps to speed up generation
|
116 |
+
)
|
|
|
|
|
117 |
|
118 |
# Decode output
|
119 |
progress(0.9, desc="معالجة النتائج (Processing results)")
|
|
|
191 |
def reset_params():
|
192 |
"""Reset parameters to default values"""
|
193 |
logger.info("Parameters reset to defaults")
|
194 |
+
return 128, 0.7, 0.9, 50, 1, 1.2 # Updated defaults for faster generation
|
195 |
|
196 |
def thumbs_up_callback(input_text, output_text):
|
197 |
"""Record positive feedback"""
|
|
|
273 |
with gr.Accordion("معلمات التوليد (Generation Parameters)", open=False):
|
274 |
with gr.Row():
|
275 |
with gr.Column():
|
276 |
+
max_length = gr.Slider(8, 4096, value=128, label="Max Length (الطول الأقصى)") # Reduced default
|
277 |
temperature = gr.Slider(0.0, 2, value=0.7, label="Temperature (درجة الحرارة)")
|
278 |
top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p (أعلى احتمال)")
|
279 |
|
280 |
with gr.Column():
|
281 |
+
top_k = gr.Slider(1, 10000, value=50, label="Top-k (أعلى ك)") # Reduced default
|
282 |
+
num_beams = gr.Slider(1, 20, value=1, label="Number of Beams (عدد الأشعة)") # Reduced default
|
283 |
+
repetition_penalty = gr.Slider(0.0, 100.0, value=1.2, label="Repetition Penalty (عقوبة التكرار)") # Reduced default
|
284 |
+
|
285 |
with gr.Column(scale=6):
|
286 |
output_text = gr.Textbox(label="النص المولد (Generated Text)", lines=10)
|
287 |
generation_info = gr.Markdown("")
|