|
from transformers import PretrainedConfig |
|
|
|
|
|
class VISTA3DConfig(PretrainedConfig):
    """Hyperparameter container for the VISTA3D model.

    Holds the values accepted by ``__init__`` as instance attributes and
    passes every other keyword argument on to ``PretrainedConfig``.
    """

    # Model-type identifier string for this configuration class.
    model_type = "VISTA3D"

    def __init__(
        self,
        encoder_embed_dim: int = 48,
        input_channels: int = 1,
        **kwargs,
    ):
        """Record the VISTA3D hyperparameters on the config instance.

        Args:
            encoder_embed_dim: The encoder embedding dimension of the
                VISTA3D model.
            input_channels: Number of channels in the input images.
            **kwargs: Any additional keyword arguments; forwarded unchanged
                to ``PretrainedConfig.__init__``.
        """
        # The model-specific fields are assigned first; the base-class
        # initializer then receives only the remaining keyword arguments.
        self.input_channels = input_channels
        self.encoder_embed_dim = encoder_embed_dim
        super().__init__(**kwargs)
|
|