import torch
import torch.nn as nn

from open_clip.model import _build_vision_tower


class CLIP(nn.Module):
    """Frozen ConvNeXt-Large CLIP vision tower used as a dense feature extractor."""

    def __init__(self):
        super().__init__()
        model_name = 'convnext_large'

        vision_cfg = {
            'timm_model_name': model_name,
            'timm_model_pretrained': False,
            'timm_pool': '',
            'timm_proj': 'mlp',
            'timm_drop': 0.0,
            'timm_drop_path': 0.1,
            'image_size': 320,
        }
        self.visual = _build_vision_tower(embed_dim=768, vision_cfg=vision_cfg, quick_gelu=False)

        # The tower is used purely as a frozen feature extractor.
        self.eval()
        self.freeze_everything()

    def freeze_everything(self):
        # Disable gradients for every parameter of the vision tower.
        for param in self.visual.parameters():
            param.requires_grad = False

    def extract_features(self, x):
        """Run the trunk and collect intermediate feature maps.

        Keys follow the detectron2-style 'res2'..'res5' naming, plus the
        stem output and the final dense CLIP features.
        """
        out = {}
        # Match the input dtype to the trunk's weights (e.g. fp16); stem[1] is
        # the stem's LayerNorm, so indexing it directly avoids rebuilding a
        # full state_dict on every call just to read a dtype.
        x = x.to(self.visual.trunk.stem[1].bias.dtype)
        x = self.visual.trunk.stem(x)
        out['stem'] = x.contiguous()
        # The four ConvNeXt stages produce features at strides 4, 8, 16 and 32.
        for i, stage in enumerate(self.visual.trunk.stages):
            x = stage(x)
            out[f'res{i + 2}'] = x.contiguous()
        x = self.visual.trunk.norm_pre(x)
        out['clip_vis_dense'] = x.contiguous()
        return out

    def forward(self, x):
        # The backbone is frozen: stay in eval mode and skip autograd.
        self.eval()
        with torch.no_grad():
            return self.extract_features(x)
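
# Usage sketch (illustrative, not part of the original module): build the
# frozen backbone and print the shape of each returned feature map for a
# dummy 320x320 batch. Note the tower is randomly initialized here
# ('timm_model_pretrained' is False); in practice CLIP weights would be
# loaded into `model.visual` first.
if __name__ == '__main__':
    model = CLIP()
    dummy = torch.randn(1, 3, 320, 320)
    features = model(dummy)
    for name, feat in features.items():
        print(f'{name}: {tuple(feat.shape)}')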