brainiac / modeling_brainiac.py
Divytak's picture
Upload modeling_brainiac.py with huggingface_hub
3010446 verified
raw
history blame
1.27 kB
import torch
import torch.nn as nn
from typing import Optional
from transformers import PreTrainedModel
from .configuration_brainiac import BrainiacConfig
from monai.networks.nets import resnet50
class BrainiacModel(PreTrainedModel):
    """Hugging Face wrapper around a MONAI ``resnet50`` used as a backbone.

    The network's first convolution is replaced by an ``nn.Conv3d`` that takes
    ``config.in_channels`` input channels, and its final ``fc`` layer is
    replaced by ``nn.Identity`` so that ``forward`` returns whatever the
    network produces right before that layer instead of class logits.
    """

    config_class = BrainiacConfig
    base_model_prefix = "brainiac"

    def __init__(self, config: BrainiacConfig):
        """Build the backbone described by ``config``.

        Args:
            config: Model configuration. Only ``config.in_channels`` is read
                in this constructor.
        """
        super().__init__(config)
        self.config = config

        # MONAI ResNet50 without loading any pretrained weights.
        self.resnet = resnet50(pretrained=False)

        # Replace the stem convolution so the number of input channels is
        # driven by the configuration.
        self.resnet.conv1 = nn.Conv3d(
            config.in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

        # Drop the classification head: the model outputs features, not logits.
        self.resnet.fc = nn.Identity()

        # Standard Hugging Face end-of-__init__ hook. Without it,
        # `_init_weights` is never applied when the model is built directly
        # from a config, so the freshly created `conv1` above would keep
        # PyTorch's default Conv3d initialisation.
        self.post_init()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped ResNet and return its output.

        Args:
            x: Input tensor passed unchanged to the network.
                NOTE(review): presumably shaped
                (batch, config.in_channels, D, H, W), since the first layer
                is an ``nn.Conv3d`` -- confirm against callers.

        Returns:
            The tensor produced by ``self.resnet`` (whose ``fc`` layer is an
            identity).
        """
        return self.resnet(x)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise one sub-module (hook used by ``PreTrainedModel``).

        * ``nn.Conv3d``: Kaiming-normal weights (``fan_out``, ReLU gain).
        * ``nn.BatchNorm3d`` / ``nn.GroupNorm``: weight set to 1, bias to 0.

        Any other module type is left untouched.
        """
        if isinstance(module, nn.Conv3d):
            nn.init.kaiming_normal_(
                module.weight, mode='fan_out', nonlinearity='relu'
            )
        elif isinstance(module, (nn.BatchNorm3d, nn.GroupNorm)):
            # Norm layers created with affine=False have weight/bias set to
            # None; skip them instead of crashing inside nn.init.
            if module.weight is not None:
                nn.init.constant_(module.weight, 1)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)