Spaces:
Runtime error
Update app.py
app.py CHANGED
@@ -25,28 +25,32 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=0.0, filter_value=-float('Inf') ):
        logits[indices_to_remove] = filter_value
    return logits

+def generate0(input_text):
+    input_ids = [tokenizer.cls_token_id]
+    input_ids.extend( tokenizer.encode(input_text + "-", add_special_tokens=False) )
+    input_ids = torch.tensor( [input_ids] )
+
+    generated = []
+    for _ in range(100):
+        output = model(input_ids)
+
+        next_token_logits = output.logits[0, -1, :]
+        next_token_logits[ tokenizer.convert_tokens_to_ids('[UNK]') ] = -float('Inf')
+        filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=8, top_p=1)
+        next_token = torch.multinomial( F.softmax(filtered_logits, dim=-1), num_samples=1 )
+        if next_token == tokenizer.sep_token_id:
+            break
+        generated.append( next_token.item() )
+        input_ids = torch.cat( (input_ids, next_token.unsqueeze(0)), dim=1 )
+
+    return "".join( tokenizer.convert_ids_to_tokens(generated) )
+
def generate(input_text):
-    result = []
-    for i in range(
-
-
-
-    input_ids.extend( tokenizer.encode(input_text + "-", add_special_tokens=False) )
-    input_ids = torch.tensor( [input_ids] )
-
-    output = model(input_ids)
-
-    next_token_logits = output.logits[0, -1, :]
-    next_token_logits[ tokenizer.convert_tokens_to_ids('[UNK]') ] = -float('Inf')
-    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=8, top_p=1)
-    next_token = torch.multinomial( F.softmax(filtered_logits, dim=-1), num_samples=1 )
-    if next_token == tokenizer.sep_token_id:
-        break
-    generated.append( next_token.item() )
-    input_ids = torch.cat( (input_ids, next_token.unsqueeze(0)), dim=1 )
-    result.append("".join(tokenizer.convert_ids_to_tokens(generated)));
-
-    return "|".join( result )
+    result = []
+    for i in range(100):
+        text = generate0(input_text)
+        result.append(text)
+    return "".join( result )

if __name__ == "__main__":

@@ -54,5 +58,4 @@ if __name__ == "__main__":
        fn=generate,
        inputs="text",
        outputs="text"
-    ).launch()
-
+    ).launch()
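For orientation: the hunks above lean on several module-level names (torch, F, gr, tokenizer, model) that are defined outside the changed region, and only the last two lines of top_k_top_p_filtering appear as context. Below is a minimal sketch of what the unshown top of app.py plausibly contains, assuming a transformers causal LM with a BERT-style vocabulary (the [UNK]/cls_token_id/sep_token_id handling in generate0 points that way); the checkpoint name is a placeholder, and the filtering body follows the widely circulated reference implementation whose final two lines match the context lines shown in the first hunk.

# Sketch of the unshown top of app.py: assumptions, not the Space's actual code.
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import BertTokenizer, GPT2LMHeadModel

MODEL_NAME = "model-name-here"  # placeholder: the real checkpoint does not appear in the diff

# generate0's use of [UNK], cls_token_id and sep_token_id implies a BERT-style vocab.
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
model.eval()

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Keep the top-k highest logits and/or the nucleus (top-p) set; set the rest
    to filter_value. Operates on the 1-D logits tensor that generate0 passes in.
    Only the last two lines of this body are visible in the diff."""
    top_k = min(top_k, logits.size(-1))  # safety: k cannot exceed the vocab size
    if top_k > 0:
        # Drop every logit smaller than the k-th largest one
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens once cumulative probability exceeds top_p ...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ... shifted right so the first token crossing the threshold survives
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

Under these assumptions, output.logits has shape (batch, sequence, vocab), so output.logits[0, -1, :] in generate0 is the next-token distribution for the single batch row, which is exactly what the top-k filtering and multinomial sampling steps expect.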