robowaifudev committed on
Commit
5ee52db
·
1 Parent(s): c0fd3c3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +25 -5
README.md CHANGED
@@ -1,3 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  <!---
2
  # ##############################################################################################
3
  #
@@ -48,29 +64,33 @@ model = GPT2LMHeadModel.from_pretrained("robowaifudev/megatron-gpt2-345m")
48
 
49
  if torch.cuda.is_available():
50
  device = torch.device("cuda")
51
- model.to(device)
52
  model.half()
53
  else:
54
  device = torch.device("cpu")
 
55
  model.eval()
56
 
57
  # Generate
58
- text = "Hello world!"
59
- input_ids = tokenizer.encode(text, return_tensors="pt")
60
  output = model.generate(
61
  input_ids=input_ids,
62
- max_length=len(input_ids) + 32,
63
  do_sample=True,
64
  top_k=64,
65
  top_p=0.9,
66
  temperature=0.8,
67
- num_return_sequences=1
 
68
  )
69
 
70
  # Output the text.
 
 
71
  for i, sentence in enumerate(output):
72
  text = tokenizer.decode(sentence, clean_up_tokenization_spaces=True)
73
  print(f"{i}:", text)
 
74
  ```
75
 
76
  # Original code
 
1
+ ---
2
+ language:
3
+ - en
4
+ tags:
5
+ - gpt2
6
+ license: apache-2.0
7
+ datasets:
8
+ - wikitext
9
+ - openwebtext
10
+ - cc-stories
11
+ metrics:
12
+ - type: wikitext
13
+ value: 19.31
14
+ name: WikiText-103
15
+ ---
16
+
17
  <!---
18
  # ##############################################################################################
19
  #
 
64
 
65
  if torch.cuda.is_available():
66
  device = torch.device("cuda")
 
67
  model.half()
68
  else:
69
  device = torch.device("cpu")
70
+ model.to(device)
71
  model.eval()
72
 
73
  # Generate
74
+ prompt = "It was a bright cold day in April, and the clocks were striking thirteen. Winston Smith,"
75
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
76
  output = model.generate(
77
  input_ids=input_ids,
78
+ max_length=len(input_ids) + 128,
79
  do_sample=True,
80
  top_k=64,
81
  top_p=0.9,
82
  temperature=0.8,
83
+ num_return_sequences=2,
84
+ repetition_penalty=1.025
85
  )
86
 
87
  # Output the text.
88
+ print("Prompt:", prompt)
89
+ print("*" * 3)
90
  for i, sentence in enumerate(output):
91
  text = tokenizer.decode(sentence, clean_up_tokenization_spaces=True)
92
  print(f"{i}:", text)
93
+ print("*" * 3)
94
  ```
95
 
96
  # Original code