from dataclasses import dataclass
from typing import Any

from transformers import PretrainedConfig
class CustomGPTConfig(PretrainedConfig):
    """Configuration class for custom GPT model.

    Holds the hyperparameters of the model. Every value is stored on the
    instance in ``__init__`` (not only as a class-level default) so that
    the base class's instance-based serialization includes it.

    Args:
        block_size: Block size (presumably the maximum context length in
            tokens -- confirm against the model code).
        vocab_size: Vocabulary size.
        n_layer: Number of layers.
        n_head: Number of heads.
        n_embd: Embedding width.
        dropout: Dropout rate.
        **kwargs: Any further options, forwarded unchanged to
            ``PretrainedConfig.__init__``.
    """

    model_type = "custom_gpt"

    # Class-level defaults. Kept so that class-attribute access
    # (``CustomGPTConfig.n_embd``) keeps working, and reused below as the
    # single source of truth for the ``__init__`` defaults.
    block_size: int = 768
    vocab_size: int = 50257
    n_layer: int = 8
    n_head: int = 8
    n_embd: int = 768
    dropout: float = 0.1

    def __init__(
        self,
        *,
        # Defaults are evaluated in the class body, so each one resolves to
        # the class-level value declared just above.
        block_size: int = block_size,
        vocab_size: int = vocab_size,
        n_layer: int = n_layer,
        n_head: int = n_head,
        n_embd: int = n_embd,
        dropout: float = dropout,
        **kwargs: Any,
    ) -> None:
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        # The base class consumes the generic options and must receive
        # whatever this class does not handle itself.
        super().__init__(**kwargs)