# SuyeonJ's picture
# Upload folder using huggingface_hub
# 8d015d4 verified
import torch
import torch.nn as nn
import timm
import numpy as np
class twins_svt_large(nn.Module):
    """Feature extractor built from the first two stages of timm's ``twins_svt_large``.

    ``__init__`` removes the classifier head and every stage after the second
    one (patch embeddings, blocks and ``pos_block`` modules). ``forward`` can
    therefore run at most two stages.
    """

    def __init__(self, pretrained=True):
        """Create the timm model and truncate it to its first two stages.

        Args:
            pretrained: forwarded unchanged to ``timm.create_model``.
        """
        super().__init__()
        self.svt = timm.create_model('twins_svt_large', pretrained=pretrained)
        del self.svt.head
        # Keep only stages 0 and 1. ``pos_drops`` is left intact; ``forward``
        # zips it with the truncated lists, so the extra entries are never used.
        for stage_list in (self.svt.patch_embeds, self.svt.blocks, self.svt.pos_block):
            del stage_list[2:]

    def forward(self, x, data=None, layer=2):
        """Run ``x`` through the first ``layer`` remaining stages.

        Args:
            x: input batch. Each stage output is reshaped with
                ``reshape(B, *size, -1).permute(0, 3, 1, 2)``, so presumably
                (B, C, H, W) — TODO confirm with callers.
            data: unused; kept for interface compatibility.
            layer: number of stages to run. Only two stages remain after
                ``__init__``, so values above 2 behave like 2.

        Returns:
            The output of the last stage that ran. When that stage is not the
            final stage listed in ``self.svt.depths``, it is permuted to
            (B, channels, H', W').
        """
        B = x.shape[0]
        for i, (embed, drop, blocks, pos_blk) in enumerate(
                zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
            x, size = embed(x)
            x = drop(x)
            for j, blk in enumerate(blocks):
                x = blk(x, size)
                if j == 0:
                    # Each stage applies its ``pos_block`` once, right after its first block.
                    x = pos_blk(x, size)
            if i < len(self.svt.depths) - 1:
                # Token sequence -> spatial map so the next patch embedding
                # (or the caller) receives a channels-first tensor.
                x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
            if i == layer - 1:
                break
        return x

    def compute_params(self, layer=2):
        """Count the parameters used by the first ``layer`` stages, plus the head.

        Per stage, this sums the patch embedding, dropout, blocks and
        ``pos_block`` parameters. ``__init__`` deletes ``self.svt.head``, so the
        head is added only when it is present. Reading it unconditionally
        raised AttributeError on every instance built by ``__init__``.

        Args:
            layer: number of leading stages to include.

        Returns:
            Total number of scalar parameters, as an ``int``.
        """
        def _count(module):
            # Number of scalar parameters owned by ``module`` (recursively).
            return sum(p.numel() for p in module.parameters())

        num = 0
        stages = zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)
        for i, stage_modules in enumerate(stages):
            num += sum(_count(m) for m in stage_modules)
            if i == layer - 1:
                break
        head = getattr(self.svt, 'head', None)
        if head is not None:
            num += _count(head)
        return num
class twins_svt_large_context(nn.Module):
    """Untruncated timm Twins backbone, instantiated by model name.

    NOTE(review): timm's Twins registry is known to provide
    ``twins_svt_{small,base,large}`` and ``twins_pcpvt_*``. I could not confirm
    that ``'twins_svt_large_context'`` is registered; if it is not,
    ``timm.create_model`` will raise. Verify against the installed timm version.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.svt = timm.create_model('twins_svt_large_context', pretrained=pretrained)

    def forward(self, x, data=None, layer=2):
        """Run ``x`` through the first ``layer`` stages of the backbone.

        ``data`` is accepted but ignored. Every stage except the last one listed
        in ``self.svt.depths`` is permuted back to a channels-first spatial map.
        The final stage's output is returned as the patch embedding and blocks
        produced it.
        """
        batch = x.shape[0]
        svt = self.svt
        stages = zip(svt.patch_embeds, svt.pos_drops, svt.blocks, svt.pos_block)
        for stage_idx, (patch_embed, pos_drop, stage_blocks, pos_module) in enumerate(stages):
            x, hw = patch_embed(x)
            x = pos_drop(x)
            for block_idx, block in enumerate(stage_blocks):
                x = block(x, hw)
                # The stage's pos_block runs exactly once, after its first block.
                x = pos_module(x, hw) if block_idx == 0 else x
            is_final_stage = stage_idx >= len(svt.depths) - 1
            if not is_final_stage:
                # Presumably (B, H'*W', C) tokens -> (B, C, H', W').
                x = x.reshape(batch, *hw, -1).permute(0, 3, 1, 2).contiguous()
            if stage_idx + 1 == layer:
                break
        return x
if __name__ == "__main__":
    # Smoke test: build the truncated backbone and print the output feature shape.
    model = twins_svt_large()
    model.eval()
    images = torch.randn(2, 3, 400, 800)
    with torch.no_grad():
        # The class defines no ``extract_feature`` method (the old call raised
        # AttributeError); calling the module runs ``forward``.
        features = model(images)
    print(features.shape)