Divyasreepat commited on
Commit
023c32c
1 Parent(s): bf6cda7

Update README.md with new model card content

Browse files
Files changed (1) hide show
  1. README.md +176 -13
README.md CHANGED
@@ -1,16 +1,179 @@
1
  ---
2
  library_name: keras-hub
3
  ---
4
- This is a [`GPT2` model](https://keras.io/api/keras_hub/models/gpt2) uploaded using the KerasHub library and can be used with JAX, TensorFlow, and PyTorch backends.
5
- Model config:
6
- * **name:** gpt2_backbone
7
- * **trainable:** True
8
- * **vocabulary_size:** 50257
9
- * **num_layers:** 12
10
- * **num_heads:** 12
11
- * **hidden_dim:** 768
12
- * **intermediate_dim:** 3072
13
- * **dropout:** 0.1
14
- * **max_sequence_length:** 1024
15
-
16
- This model card has been generated automatically and should be completed by the model author. See [Model Cards documentation](https://huggingface.co/docs/hub/model-cards) for more information.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  library_name: keras-hub
3
  ---
4
+ ### Model Overview
5
+ GPT-2 is a language model published by OpenAI. Models are fine tuned on WebText, and range in size from 125 million to 1.5 billion parameters. See the model card below for benchmarks, data sources, and intended use cases.
6
+
7
+ Weights are released under the [MIT License](https://opensource.org/license/mit). Keras model code is released under the [Apache 2 License](https://github.com/keras-team/keras-hub/blob/master/LICENSE).
8
+
9
+ ## Links
10
+
11
+ * [GPT-2 Quickstart Notebook](https://www.kaggle.com/code/gabrielrasskin/gpt-2-quickstart)
12
+ * [GPT-2 API Documentation](https://keras.io/api/keras_hub/models/gpt2/)
13
+ * [GPT-2 Model Card](https://github.com/openai/gpt-2/blob/master/model_card.md)
14
+ * [KerasHub Beginner Guide](https://keras.io/guides/keras_hub/getting_started/)
15
+ * [KerasHub Model Publishing Guide](https://keras.io/guides/keras_hub/upload/)
16
+
17
+ ## Installation
18
+
19
+ Keras and KerasHub can be installed with:
20
+
21
+ ```
22
+ pip install -U -q keras-hub
23
+ pip install -U -q keras>=3
24
+ ```
25
+
26
+ Jax, TensorFlow, and Torch come preinstalled in Kaggle Notebooks. For instruction on installing them in another environment see the [Keras Getting Started](https://keras.io/getting_started/) page.
27
+
28
+ ## Presets
29
+
30
+ The following model checkpoints are provided by the Keras team. Full code examples for each are available below.
31
+
32
+ | Preset name | Parameters | Description |
33
+ |----------------------------|------------|------------------------------------------------------------------------------------------------------|
34
+ | `gpt2_base_en` | 124.44M | 12-layer GPT-2 model where case is maintained. Trained on WebText. |
35
+ | `gpt2_medium_en` | 354.82M | 24-layer GPT-2 model where case is maintained. Trained on WebText. |
36
+ | `gpt2_large_en` | 774.03M | 36-layer GPT-2 model where case is maintained. Trained on WebText. |
37
+ | `gpt2_extra_large_en` | 1.56B | 48-layer GPT-2 model where case is maintained. Trained on WebText. |
38
+ | `gpt2_base_en_cnn_dailymail` | 124.44M | 12-layer GPT-2 model where case is maintained. Finetuned on the CNN/DailyMail summarization dataset. |
39
+
40
+ ## Prompts
41
+
42
+ GPT-2 models are fine tuned on WebText. Prompting should follow text completion formatting. See the following for an example:
43
+
44
+ ```python
45
+ prompt = "Keras is a "
46
+ ```
47
+
48
+ would have GPT-2 aim to complete the sentence.
49
+
50
+ ### Example Usage
51
+ ```python
52
+ import keras
53
+ import keras_hub
54
+ import numpy as np
55
+ ```
56
+
57
+ Use `generate()` to do text generation.
58
+ ```python
59
+ gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en_cnn_dailymail")
60
+ gpt2_lm.generate("I want to say", max_length=30)
61
+
62
+ # Generate with batched prompts.
63
+ gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
64
+ ```
65
+
66
+ Compile the `generate()` function with a custom sampler.
67
+ ```python
68
+ gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en_cnn_dailymail")
69
+ gpt2_lm.compile(sampler="greedy")
70
+ gpt2_lm.generate("I want to say", max_length=30)
71
+
72
+ gpt2_lm.compile(sampler=keras_hub.samplers.BeamSampler(num_beams=2))
73
+ gpt2_lm.generate("I want to say", max_length=30)
74
+ ```
75
+
76
+ Use `generate()` without preprocessing.
77
+ ```python
78
+ # Prompt the model with `5338, 318` (the token ids for `"Who is"`).
79
+ # Use `"padding_mask"` to indicate values that should not be overridden.
80
+ prompt = {
81
+ "token_ids": np.array([[5338, 318, 0, 0, 0]] * 2),
82
+ "padding_mask": np.array([[1, 1, 0, 0, 0]] * 2),
83
+ }
84
+
85
+ gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset(
86
+ "gpt2_base_en_cnn_dailymail",
87
+ preprocessor=None,
88
+ )
89
+ gpt2_lm.generate(prompt)
90
+ ```
91
+
92
+ Call `fit()` on a single batch.
93
+ ```python
94
+ features = ["The quick brown fox jumped.", "I forgot my homework."]
95
+ gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
96
+ gpt2_lm.fit(x=features, batch_size=2)
97
+ ```
98
+
99
+ Call `fit()` without preprocessing.
100
+ ```python
101
+ x = {
102
+ "token_ids": np.array([[50256, 1, 2, 3, 4]] * 2),
103
+ "padding_mask": np.array([[1, 1, 1, 1, 1]] * 2),
104
+ }
105
+ y = np.array([[1, 2, 3, 4, 50256]] * 2)
106
+ sw = np.array([[1, 1, 1, 1, 1]] * 2)
107
+
108
+ gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset(
109
+ "gpt2_base_en_cnn_dailymail",
110
+ preprocessor=None,
111
+ )
112
+ gpt2_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
113
+ ```
114
+
115
+ ## Example Usage with Hugging Face URI
116
+
117
+ ```python
118
+ import keras
119
+ import keras_hub
120
+ import numpy as np
121
+ ```
122
+
123
+ Use `generate()` to do text generation.
124
+ ```python
125
+ gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset("hf://keras/gpt2_base_en_cnn_dailymail")
126
+ gpt2_lm.generate("I want to say", max_length=30)
127
+
128
+ # Generate with batched prompts.
129
+ gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
130
+ ```
131
+
132
+ Compile the `generate()` function with a custom sampler.
133
+ ```python
134
+ gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset("hf://keras/gpt2_base_en_cnn_dailymail")
135
+ gpt2_lm.compile(sampler="greedy")
136
+ gpt2_lm.generate("I want to say", max_length=30)
137
+
138
+ gpt2_lm.compile(sampler=keras_hub.samplers.BeamSampler(num_beams=2))
139
+ gpt2_lm.generate("I want to say", max_length=30)
140
+ ```
141
+
142
+ Use `generate()` without preprocessing.
143
+ ```python
144
+ # Prompt the model with `5338, 318` (the token ids for `"Who is"`).
145
+ # Use `"padding_mask"` to indicate values that should not be overridden.
146
+ prompt = {
147
+ "token_ids": np.array([[5338, 318, 0, 0, 0]] * 2),
148
+ "padding_mask": np.array([[1, 1, 0, 0, 0]] * 2),
149
+ }
150
+
151
+ gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset(
152
+ "hf://keras/gpt2_base_en_cnn_dailymail",
153
+ preprocessor=None,
154
+ )
155
+ gpt2_lm.generate(prompt)
156
+ ```
157
+
158
+ Call `fit()` on a single batch.
159
+ ```python
160
+ features = ["The quick brown fox jumped.", "I forgot my homework."]
161
+ gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
162
+ gpt2_lm.fit(x=features, batch_size=2)
163
+ ```
164
+
165
+ Call `fit()` without preprocessing.
166
+ ```python
167
+ x = {
168
+ "token_ids": np.array([[50256, 1, 2, 3, 4]] * 2),
169
+ "padding_mask": np.array([[1, 1, 1, 1, 1]] * 2),
170
+ }
171
+ y = np.array([[1, 2, 3, 4, 50256]] * 2)
172
+ sw = np.array([[1, 1, 1, 1, 1]] * 2)
173
+
174
+ gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset(
175
+ "hf://keras/gpt2_base_en_cnn_dailymail",
176
+ preprocessor=None,
177
+ )
178
+ gpt2_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
179
+ ```