|
from transformers import BertConfig |
|
|
|
|
|
class PunctuationBertConfig(BertConfig):
    r"""
    Configuration class for a BERT-based punctuation-prediction model (e.g. [`BertForPunctuation`]).

    It extends [`BertConfig`]. Every BERT configuration argument is accepted through `**kwargs` and forwarded
    unchanged to [`BertConfig`], and four punctuation-specific arguments are added on top of it. Instantiating
    it with the defaults gives BERT's default architecture plus the punctuation defaults listed below.

    Args:
        backward_context (`int`, *optional*, defaults to 15):
            Size of the backward context window. The unit (tokens, words, ...) is defined by the model that
            consumes this config.
        forward_context (`int`, *optional*, defaults to 16):
            Size of the forward context window. The unit is likewise defined by the consuming model.
        output_size (`int`, *optional*, defaults to 4):
            Number of punctuation classes.
        dropout (`float`, *optional*, defaults to 0.3):
            Dropout rate, stored as the `dropout` attribute. It is presumably applied by the punctuation-specific
            layers of the model, which are not defined in this file. BERT's own dropout settings are still
            configured through `**kwargs`.
        kwargs:
            Any remaining keyword arguments, passed unchanged to [`BertConfig`].

    Examples:

    ```python
    >>> # Initializing a configuration with BERT's default (bert-base-uncased style) settings
    >>> # plus the punctuation-specific defaults
    >>> configuration = PunctuationBertConfig()

    >>> # Initializing a model (with random weights) from that configuration.
    >>> # NOTE: assumes a `BertForPunctuation` model class is importable alongside this config.
    >>> model = BertForPunctuation(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    def __init__(
        self,
        backward_context: int = 15,
        forward_context: int = 16,
        output_size: int = 4,
        dropout: float = 0.3,
        **kwargs,
    ) -> None:
        # Let BertConfig handle all of its own arguments (and any extra keys) first;
        # the punctuation-specific attributes are then set explicitly on top.
        super().__init__(**kwargs)
        self.backward_context = backward_context
        self.forward_context = forward_context
        self.output_size = output_size
        self.dropout = dropout
|
|