Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,15 +25,15 @@ def generate_text(prompt, temperature, top_p):
|
|
| 25 |
for _ in range(80): # Adjust the range to control the number of tokens generated
|
| 26 |
with torch.no_grad():
|
| 27 |
outputs = model(input_tokens)
|
| 28 |
-
predictions = outputs.logits / temperature
|
| 29 |
-
sorted_logits, sorted_indices = torch.sort(predictions
|
| 30 |
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
| 31 |
sorted_indices_to_remove = cumulative_probs > top_p
|
| 32 |
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 33 |
sorted_indices_to_remove[..., 0] = 0
|
| 34 |
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
| 35 |
-
predictions[:,
|
| 36 |
-
next_token = torch.multinomial(torch.softmax(predictions
|
| 37 |
|
| 38 |
input_tokens = torch.cat((input_tokens, next_token), dim=1)
|
| 39 |
|
|
|
|
| 25 |
for _ in range(80): # Adjust the range to control the number of tokens generated
|
| 26 |
with torch.no_grad():
|
| 27 |
outputs = model(input_tokens)
|
| 28 |
+
predictions = outputs.logits[:, -1, :] / temperature
|
| 29 |
+
sorted_logits, sorted_indices = torch.sort(predictions, descending=True)
|
| 30 |
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
| 31 |
sorted_indices_to_remove = cumulative_probs > top_p
|
| 32 |
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 33 |
sorted_indices_to_remove[..., 0] = 0
|
| 34 |
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
| 35 |
+
predictions[:, indices_to_remove] = -float('Inf')
|
| 36 |
+
next_token = torch.multinomial(torch.softmax(predictions, dim=-1), 1)
|
| 37 |
|
| 38 |
input_tokens = torch.cat((input_tokens, next_token), dim=1)
|
| 39 |
|