# Zero-Shot Classification
# UTA / eva_vit_model / uta_clip.py
# (origin: commit f73bf08, "update eval files", author jihao)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import eva_vit
from .transformer import text_transformer
class CLIP(nn.Module):
    """Dual-encoder CLIP model: an EVA ViT image tower paired with a text transformer.

    `forward` returns L2-normalized embeddings for both modalities plus the
    exponentiated learnable temperature, ready for contrastive similarity.
    """

    def __init__(
        self,
        vision_model: str = 'eva_base_p16',
    ):
        super().__init__()
        # Resolve the vision backbone constructor by name from the eva_vit module.
        self.visual = eva_vit.__dict__[vision_model]()
        self.text = text_transformer()
        # Learnable temperature kept in log space; init corresponds to 1/0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def encode_image(self, image, normalize: bool = False):
        """Embed a batch of images; optionally L2-normalize along the last dim."""
        feats = self.visual(image)
        if normalize:
            return F.normalize(feats, dim=-1)
        return feats

    def encode_text(self, text, normalize: bool = False):
        """Embed tokenized text; optionally L2-normalize along the last dim."""
        feats = self.text(text)
        if normalize:
            return F.normalize(feats, dim=-1)
        return feats

    def forward(self, image, text):
        """Return (image_features, text_features, logit_scale.exp()), features normalized."""
        img_feats = self.encode_image(image, normalize=True)
        txt_feats = self.encode_text(text, normalize=True)
        return img_feats, txt_feats, self.logit_scale.exp()