import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from . import eva_vit
from .transformer import text_transformer


class CLIP(nn.Module):
    """Two-tower CLIP model pairing an EVA ViT image encoder with a text transformer."""

    def __init__(
        self,
        vision_model: str = 'eva_base_p16',
    ):
        super().__init__()
        # Look up the vision backbone constructor by name in the eva_vit module.
        self.visual = eva_vit.__dict__[vision_model]()
        self.text = text_transformer()
        # Learnable temperature, initialized to 1/0.07 as in the original CLIP paper.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(self, image, text):
        # Both embeddings are returned L2-normalized so the caller can build the
        # contrastive similarity matrix; logit_scale.exp() is the temperature.
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        return image_features, text_features, self.logit_scale.exp()
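

# --- Minimal usage sketch ---
# Assumptions not confirmed by this file: 'eva_base_p16' takes 224x224 RGB
# images, and text_transformer() consumes integer token ids with a context
# length of 77 and a vocabulary of 49408 (the standard CLIP tokenizer).
# Because of the relative imports, run this as a module (python -m ...).
if __name__ == '__main__':
    model = CLIP(vision_model='eva_base_p16').eval()

    images = torch.randn(2, 3, 224, 224)        # dummy image batch
    tokens = torch.randint(0, 49408, (2, 77))   # dummy token-id batch

    with torch.no_grad():
        image_features, text_features, logit_scale = model(images, tokens)

    # Contrastive logits: the features come back L2-normalized, so the matrix
    # product is a cosine-similarity matrix scaled by the learned temperature.
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    # Symmetric InfoNCE loss over the in-batch image/text pairs.
    labels = torch.arange(images.shape[0])
    loss = (F.cross_entropy(logits_per_image, labels)
            + F.cross_entropy(logits_per_text, labels)) / 2
    print(logits_per_image.shape, loss.item())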