Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchinfo import summary | |
# Conv1d -> BatchNorm1d -> Dropout1d building block (used by MMCNN_SUM).
class Conv1d_layer(nn.Module):
    """Conv1d -> BatchNorm1d -> Dropout1d block.

    Args:
        in_channel: Number of channels of the input tensor.
        out_channel: Number of channels produced by the convolution.
        kernel_size: Kernel size of the (unpadded, stride-1) convolution.
        dropout: Probability handed to ``nn.Dropout1d``. The default of 0.5
            reproduces the previously hard-coded value.
    """

    def __init__(self, in_channel, out_channel, kernel_size, dropout=0.5) -> None:
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=kernel_size,
        )
        self.batch_norm = nn.BatchNorm1d(out_channel)
        self.dropout = nn.Dropout1d(p=dropout)

    def forward(self, x):
        """Apply convolution, batch normalisation and channel dropout in order.

        No activation is applied here; callers wrap the result themselves.
        """
        return self.dropout(self.batch_norm(self.conv(x)))
class CNN(nn.Module):
    """ECG-only 1-D CNN: three conv/pool stages followed by a three-layer MLP.

    Args:
        ecg_channels: Number of channels of the ECG input.
        seq_len: Number of samples per channel the model will be fed. It
            determines the input width of ``fc0``; the default of 1000 yields
            the previously hard-coded 5856 flattened features.
        num_classes: Width of the final linear layer (default 5, as before).

    Raises:
        ValueError: If ``seq_len`` is too short to survive the three unpadded
            conv + max-pool stages.
    """

    def __init__(self, ecg_channels=12, seq_len=1000, num_classes=5):
        super().__init__()
        self.name = "CNN"

        # Each stage is an unpadded stride-1 conv (length - kernel + 1)
        # followed by a kernel-2/stride-2 max-pool (floor division by 2).
        length = seq_len
        for kernel_size in (7, 5, 3):
            length = (length - kernel_size + 1) // 2
            if length < 1:
                raise ValueError(
                    f"seq_len={seq_len} is too short for the three conv/pool stages"
                )

        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(48 * length, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.activation = nn.ReLU()

    def forward(self, x, notes=None):
        """Return the raw ``fc2`` outputs for a batch of ECG signals.

        Args:
            x: Tensor of shape ``(batch, ecg_channels, seq_len)``.
            notes: Ignored by this model (presumably accepted so it can be
                called like the multimodal variants -- confirm with callers).

        Returns:
            Tensor of shape ``(batch, num_classes)``; no softmax is applied.
            The trailing ``squeeze(1)`` only has an effect when
            ``num_classes == 1``.
        """
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.pool2(self.activation(self.conv2(x)))
        x = self.pool3(self.activation(self.conv3(x)))
        # reshape (rather than view) also accepts non-contiguous tensors.
        x = x.reshape(x.size(0), -1)
        x = self.activation(self.fc0(x))
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x.squeeze(1)
class MMCNN_SUM(nn.Module):
    """Multimodal CNN that fuses ECG and notes features by element-wise sum.

    The ECG branch uses ``Conv1d_layer`` blocks (conv + batch norm + dropout),
    so its output differs between ``train()`` and ``eval()`` mode.

    Args:
        ecg_channels: Number of channels of the ECG input.
        seq_len: Number of samples per channel the model will be fed. It
            determines the input width of ``fc0``; the default of 1000 yields
            the previously hard-coded 5856 flattened features.
        num_classes: Width of the final linear layer (default 5, as before).

    Raises:
        ValueError: If ``seq_len`` is too short to survive the three unpadded
            conv + max-pool stages.
    """

    def __init__(self, ecg_channels=12, seq_len=1000, num_classes=5):
        super().__init__()

        # Each stage is an unpadded stride-1 conv (length - kernel + 1)
        # followed by a kernel-2/stride-2 max-pool (floor division by 2).
        length = seq_len
        for kernel_size in (7, 5, 3):
            length = (length - kernel_size + 1) // 2
            if length < 1:
                raise ValueError(
                    f"seq_len={seq_len} is too short for the three conv/pool stages"
                )

        # ECG processing layers
        self.name = "MMCNN_SUM"
        self.conv1 = Conv1d_layer(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = Conv1d_layer(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = Conv1d_layer(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(48 * length, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, num_classes)

        # Clinical notes processing layers
        self.fc_emb = nn.Linear(768, 128)
        self.norm = nn.LayerNorm(128)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Return the raw ``fc2`` outputs for paired ECG signals and notes.

        Args:
            x: Tensor of shape ``(batch, ecg_channels, seq_len)``.
            notes: Notes embedding that flattens to 768 values per sample,
                e.g. ``(batch, 768)`` or ``(batch, 1, 768)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``; no softmax is applied.
            The trailing ``squeeze(1)`` only has an effect when
            ``num_classes == 1``.
        """
        # ECG branch -> (batch, 128)
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.pool2(self.activation(self.conv2(x)))
        x = self.pool3(self.activation(self.conv3(x)))
        # reshape (rather than view) also accepts non-contiguous tensors.
        x = x.reshape(x.size(0), -1)
        x = self.activation(self.fc0(x))
        x = self.activation(self.fc1(x))

        # Notes branch -> (batch, 128)
        notes = notes.reshape(notes.size(0), -1)
        notes = self.activation(self.fc_emb(notes))

        # Fuse by summation, normalise, then classify.
        x = self.fc2(self.norm(x + notes))
        return x.squeeze(1)
class MMCNN_CAT(nn.Module):
    """Multimodal CNN that fuses ECG and notes features by concatenation.

    Args:
        ecg_channels: Number of channels of the ECG input.
        seq_len: Number of samples per channel the model will be fed. It
            determines the input width of ``fc0``; the default of 1000 yields
            the previously hard-coded 5856 flattened features.
        num_classes: Width of the final linear layer (default 5, as before).

    Raises:
        ValueError: If ``seq_len`` is too short to survive the three unpadded
            conv + max-pool stages.
    """

    def __init__(self, ecg_channels=12, seq_len=1000, num_classes=5):
        super().__init__()

        # Each stage is an unpadded stride-1 conv (length - kernel + 1)
        # followed by a kernel-2/stride-2 max-pool (floor division by 2).
        length = seq_len
        for kernel_size in (7, 5, 3):
            length = (length - kernel_size + 1) // 2
            if length < 1:
                raise ValueError(
                    f"seq_len={seq_len} is too short for the three conv/pool stages"
                )

        # ECG processing layers
        self.name = "MMCNN_CAT"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(48 * length, 512)
        self.fc1 = nn.Linear(512, 128)
        # 256 = 128 ECG features + 128 notes features after concatenation.
        self.fc2 = nn.Linear(256, num_classes)

        # Clinical notes processing layers
        self.fc_emb = nn.Linear(768, 128)
        # NOTE: ``norm`` is not used in forward(); it is kept so that the
        # module's state_dict keys stay the same as before.
        self.norm = nn.LayerNorm(128)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Return the raw ``fc2`` outputs for paired ECG signals and notes.

        Args:
            x: Tensor of shape ``(batch, ecg_channels, seq_len)``.
            notes: Notes embedding that flattens to 768 values per sample,
                e.g. ``(batch, 768)`` or ``(batch, 1, 768)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``; no softmax is applied.
            The trailing ``squeeze(1)`` only has an effect when
            ``num_classes == 1``.
        """
        # ECG branch -> (batch, 128)
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.pool2(self.activation(self.conv2(x)))
        x = self.pool3(self.activation(self.conv3(x)))
        # reshape (rather than view) also accepts non-contiguous tensors.
        x = x.reshape(x.size(0), -1)
        x = self.activation(self.fc0(x))
        x = self.activation(self.fc1(x))

        # Notes branch -> (batch, 128)
        notes = notes.reshape(notes.size(0), -1)
        notes = self.activation(self.fc_emb(notes))

        # Fuse by concatenation along the feature axis, then classify.
        x = self.fc2(torch.cat((x, notes), dim=1))
        return x.squeeze(1)
class MMCNN_ATT(nn.Module):
    """Multimodal CNN that fuses ECG and notes features with attention.

    The notes feature is the attention query; keys and values are the
    two-token sequence ``[ecg_feature, notes_feature]``.

    Args:
        ecg_channels: Number of channels of the ECG input.
        seq_len: Number of samples per channel the model will be fed. It
            determines the input width of ``fc0``; the default of 1000 yields
            the previously hard-coded 5856 flattened features.
        num_classes: Width of the final linear layer (default 5, as before).

    Raises:
        ValueError: If ``seq_len`` is too short to survive the three unpadded
            conv + max-pool stages.
    """

    def __init__(self, ecg_channels=12, seq_len=1000, num_classes=5):
        super().__init__()

        # Each stage is an unpadded stride-1 conv (length - kernel + 1)
        # followed by a kernel-2/stride-2 max-pool (floor division by 2).
        length = seq_len
        for kernel_size in (7, 5, 3):
            length = (length - kernel_size + 1) // 2
            if length < 1:
                raise ValueError(
                    f"seq_len={seq_len} is too short for the three conv/pool stages"
                )

        # ECG processing layers
        self.name = "MMCNN_ATT"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(48 * length, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, num_classes)

        # Clinical notes processing layers
        self.fc_emb = nn.Linear(768, 128)
        self.norm1 = nn.LayerNorm(128)
        self.norm2 = nn.LayerNorm(128)
        self.attention = nn.MultiheadAttention(128, 8, batch_first=True)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Return the raw ``fc2`` outputs for paired ECG signals and notes.

        Args:
            x: Tensor of shape ``(batch, ecg_channels, seq_len)``.
            notes: Notes embedding that flattens to 768 values per sample,
                e.g. ``(batch, 768)`` or ``(batch, 1, 768)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``; no softmax is applied.
        """
        # ECG branch -> one token of shape (batch, 1, 128)
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.pool2(self.activation(self.conv2(x)))
        x = self.pool3(self.activation(self.conv3(x)))
        # reshape (rather than view) also accepts non-contiguous tensors.
        x = x.reshape(x.size(0), -1)
        x = self.activation(self.fc0(x))
        x = self.activation(self.fc1(x))
        ecg_token = self.norm1(x).unsqueeze(1)

        # Notes branch -> one token of shape (batch, 1, 128)
        notes = notes.reshape(notes.size(0), -1)
        notes = self.activation(self.fc_emb(notes))
        notes_token = self.norm2(notes).unsqueeze(1)

        # Attending over a single key/value token makes the softmax weight
        # identically 1, so the result would not depend on the notes query at
        # all. Using both tokens as keys/values lets the notes influence the
        # fused feature while keeping every parameter shape unchanged.
        memory = torch.cat((ecg_token, notes_token), dim=1)  # (batch, 2, 128)
        fused, _ = self.attention(notes_token, memory, memory)  # (batch, 1, 128)

        # Drop the length-1 token axis: (batch, 1, classes) -> (batch, classes)
        return self.fc2(fused).squeeze(1)
class MMCNN_SUM_ATT(nn.Module):
    """Multimodal CNN: sum-fused ECG/notes features passed through attention.

    Args:
        ecg_channels: Number of channels of the ECG input.
        seq_len: Number of samples per channel the model will be fed. It
            determines the input width of ``fc0``; the default of 1000 yields
            the previously hard-coded 5856 flattened features.
        num_classes: Width of the final linear layer (default 5, as before).

    Raises:
        ValueError: If ``seq_len`` is too short to survive the three unpadded
            conv + max-pool stages.
    """

    def __init__(self, ecg_channels=12, seq_len=1000, num_classes=5):
        super().__init__()

        # Each stage is an unpadded stride-1 conv (length - kernel + 1)
        # followed by a kernel-2/stride-2 max-pool (floor division by 2).
        length = seq_len
        for kernel_size in (7, 5, 3):
            length = (length - kernel_size + 1) // 2
            if length < 1:
                raise ValueError(
                    f"seq_len={seq_len} is too short for the three conv/pool stages"
                )

        # ECG processing layers
        self.name = "MMCNN_SUM_ATT"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(48 * length, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, num_classes)

        # Clinical notes processing layers
        self.fc_emb = nn.Linear(768, 128)
        self.norm = nn.LayerNorm(128)
        self.attention = nn.MultiheadAttention(128, 8, batch_first=True)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Return the raw ``fc2`` outputs for paired ECG signals and notes.

        Args:
            x: Tensor of shape ``(batch, ecg_channels, seq_len)``.
            notes: Notes embedding that flattens to 768 values per sample,
                e.g. ``(batch, 768)`` or ``(batch, 1, 768)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``; no softmax is applied.
        """
        # ECG branch -> (batch, 128)
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.pool2(self.activation(self.conv2(x)))
        x = self.pool3(self.activation(self.conv3(x)))
        # reshape (rather than view) also accepts non-contiguous tensors.
        x = x.reshape(x.size(0), -1)
        x = self.activation(self.fc0(x))
        x = self.activation(self.fc1(x))

        # Notes branch -> (batch, 128)
        notes = notes.reshape(notes.size(0), -1)
        notes = self.activation(self.fc_emb(notes))

        # Fuse by summation and add a length-1 token axis: (batch, 1, 128)
        fused = self.norm(x + notes).unsqueeze(1)

        # NOTE(review): self-attention over a single token has softmax weight 1,
        # so this step reduces to the attention's value/output projections of
        # ``fused``. Behaviour is kept as-is; confirm this is intended.
        fused, _ = self.attention(fused, fused, fused)

        # Drop the length-1 token axis: (batch, 1, classes) -> (batch, classes)
        return self.fc2(fused).squeeze(1)
if __name__ == "__main__":
    # Smoke check: print the torchinfo summary of the ECG-only CNN for one
    # 12-channel, 1000-sample input.
    # To inspect a single conv block instead, use Conv1d_layer(12, 16, 7).
    baseline = CNN()
    summary(baseline, input_size=(1, 12, 1000))