nouamanetazi (HF Staff) committed
Commit ea7e643 · verified · 1 Parent(s): 8cafaac

Update app.py

Files changed (1):
1. app.py (+23, -24)
app.py CHANGED
@@ -79,6 +79,7 @@ usage_stats = {
     "total_tokens_generated": 0,
     "start_time": time.time()
 }
+
 @spaces.GPU
 def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150, num_beams=8, repetition_penalty=1.5, progress=gr.Progress()):
     if not prompt.strip():
@@ -97,24 +98,22 @@ def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150,
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     progress(0.1, desc="تحليل النص (Tokenizing)")
 
-    # Generate text
-    # Since we can't track token generation directly, we'll create artificial steps
-    steps = 10  # Divide generation into 10 steps
-    for i in progress.tqdm(range(steps), desc="توليد النص (Generating text)"):
-        if i == 0:  # Only generate on the first step
-            output = model.generate(
-                **inputs,
-                max_length=max_length,
-                temperature=temperature,
-                top_p=top_p,
-                do_sample=True,
-                repetition_penalty=repetition_penalty,
-                num_beams=num_beams,
-                top_k=top_k,
-                early_stopping=True,
-                pad_token_id=tokenizer.pad_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-            )
+    # Generate text with optimized parameters for speed
+    progress(0.2, desc="توليد النص (Generating text)")
+    output = model.generate(
+        **inputs,
+        max_length=max_length,
+        temperature=temperature,
+        top_p=top_p,
+        do_sample=True,
+        repetition_penalty=repetition_penalty,
+        num_beams=1 if num_beams > 4 else num_beams,  # Reduce beam search or use greedy decoding
+        top_k=top_k,
+        early_stopping=True,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        use_cache=True,  # Ensure cache is used
+    )
 
     # Decode output
     progress(0.9, desc="معالجة النتائج (Processing results)")
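For reference, the speed-oriented settings introduced in this hunk can be exercised outside the Gradio app. The sketch below is hypothetical: the checkpoint name is a placeholder (the diff does not show which model the Space loads), and fast_generate is not a function from app.py.

```python
# Minimal sketch of the generation settings used above (hypothetical helper, not part of app.py).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "org/arabic-causal-lm"  # placeholder checkpoint name, not the Space's real model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")

def fast_generate(prompt, max_length=128, temperature=0.7, top_p=0.9,
                  top_k=50, num_beams=1, repetition_penalty=1.2):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        num_beams=1 if num_beams > 4 else num_beams,  # cap beam search to keep latency low
        top_k=top_k,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,  # reuse the KV cache between decoding steps
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
```

Most of the latency win comes from capping beam search (greedy decoding once more than 4 beams are requested) and from the lower max_length default; use_cache=True is already the transformers default and is only made explicit here.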
@@ -192,7 +191,7 @@ def get_stats():
 def reset_params():
     """Reset parameters to default values"""
     logger.info("Parameters reset to defaults")
-    return 256, 0.7, 0.9, 150, 8, 1.5
+    return 128, 0.7, 0.9, 50, 1, 1.2  # Updated defaults for faster generation
 
 def thumbs_up_callback(input_text, output_text):
     """Record positive feedback"""
@@ -274,15 +273,15 @@ if __name__ == "__main__":
                 with gr.Accordion("معلمات التوليد (Generation Parameters)", open=False):
                     with gr.Row():
                         with gr.Column():
-                            max_length = gr.Slider(8, 4096, value=256, label="Max Length (الطول الأقصى)")
+                            max_length = gr.Slider(8, 4096, value=128, label="Max Length (الطول الأقصى)")  # Reduced default
                             temperature = gr.Slider(0.0, 2, value=0.7, label="Temperature (درجة الحرارة)")
                             top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p (أعلى احتمال)")
 
                         with gr.Column():
-                            top_k = gr.Slider(1, 10000, value=150, label="Top-k (أعلى ك)")
-                            num_beams = gr.Slider(1, 20, value=8, label="Number of Beams (عدد الأشعة)")
-                            repetition_penalty = gr.Slider(0.0, 100.0, value=1.5, label="Repetition Penalty (عقوبة التكرار)")
-
+                            top_k = gr.Slider(1, 10000, value=50, label="Top-k (أعلى ك)")  # Reduced default
+                            num_beams = gr.Slider(1, 20, value=1, label="Number of Beams (عدد الأشعة)")  # Reduced default
+                            repetition_penalty = gr.Slider(0.0, 100.0, value=1.2, label="Repetition Penalty (عقوبة التكرار)")  # Reduced default
+
             with gr.Column(scale=6):
                 output_text = gr.Textbox(label="النص المولد (Generated Text)", lines=10)
                 generation_info = gr.Markdown("")
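The slider defaults above now match the tuple returned by reset_params, but neither the buttons nor the event handlers appear in this diff. A minimal sketch of how such wiring typically looks in Gradio, with hypothetical component names (reset_btn, generate_btn, input_text) and the reset outputs listed in the same order as reset_params' return value:

```python
# Hypothetical wiring; component names are assumptions, not taken from app.py.
reset_btn.click(
    fn=reset_params,
    inputs=None,
    outputs=[max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
)
generate_btn.click(
    fn=generate_text,
    inputs=[input_text, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
    outputs=[output_text],  # add generation_info here if generate_text also returns an info string
)
```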
 