nouamanetazi HF Staff committed on
Commit
c45c066
·
verified ·
1 Parent(s): 0a908c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -24
app.py CHANGED
@@ -81,7 +81,7 @@ usage_stats = {
81
  }
82
 
83
  @spaces.GPU
84
- def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150, num_beams=8, repetition_penalty=1.5, progress=gr.Progress()):
85
  if not prompt.strip():
86
  logger.warning("Empty prompt submitted")
87
  return "", "الرجاء إدخال نص للتوليد (Please enter text to generate)"
@@ -91,29 +91,26 @@ def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150,
91
 
92
  start_time = time.time()
93
 
94
- # Simply use the progress.tqdm wrapper for automatic progress tracking
95
- with progress.tqdm("توليد النص (Generating text)", total=1) as pbar:
96
- # Tokenize input
97
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
98
-
99
- # Generate text
100
- output = model.generate(
101
- **inputs,
102
- max_length=max_length,
103
- temperature=temperature,
104
- top_p=top_p,
105
- do_sample=True,
106
- repetition_penalty=repetition_penalty,
107
- num_beams=num_beams,
108
- top_k=top_k,
109
- early_stopping=True,
110
- pad_token_id=tokenizer.pad_token_id,
111
- eos_token_id=tokenizer.eos_token_id,
112
- )
113
-
114
- # Decode output
115
- result = tokenizer.decode(output[0], skip_special_tokens=True)
116
- pbar.update(1)
117
 
118
  # Update stats
119
  generation_time = time.time() - start_time
 
81
  }
82
 
83
  @spaces.GPU
84
+ def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150, num_beams=8, repetition_penalty=1.5):
85
  if not prompt.strip():
86
  logger.warning("Empty prompt submitted")
87
  return "", "الرجاء إدخال نص للتوليد (Please enter text to generate)"
 
91
 
92
  start_time = time.time()
93
 
94
+ # Tokenize input
95
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
96
+
97
+ # Generate text
98
+ output = model.generate(
99
+ **inputs,
100
+ max_length=max_length,
101
+ temperature=temperature,
102
+ top_p=top_p,
103
+ do_sample=True,
104
+ repetition_penalty=repetition_penalty,
105
+ num_beams=num_beams,
106
+ top_k=top_k,
107
+ early_stopping=True,
108
+ pad_token_id=tokenizer.pad_token_id,
109
+ eos_token_id=tokenizer.eos_token_id,
110
+ )
111
+
112
+ # Decode output
113
+ result = tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
114
 
115
  # Update stats
116
  generation_time = time.time() - start_time