# RANomAly / prompt_ensemble.py
from typing import Union, List
import packaging.version  # replaces the deprecated `from pkg_resources import packaging`
import torch
import numpy as np
from AnomalyCLIP_lib.simple_tokenizer import SimpleTokenizer as _Tokenizer
# from open_clip import tokenizer
# simple_tokenizer = tokenizer.SimpleTokenizer()
from copy import deepcopy
import torch.nn as nn
_tokenizer = _Tokenizer()
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
truncate: bool
Whether to truncate the text in case its encoding is longer than the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
else:
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
if truncate:
tokens = tokens[:context_length]
tokens[-1] = eot_token
else:
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
result[i, :len(tokens)] = torch.tensor(tokens)
return result
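
# Example (sketch): `tokenize` pads or truncates every string to the CLIP
# context length of 77 tokens.
#   >>> toks = tokenize(["a photo of a bottle.", "damaged bottle"])
#   >>> toks.shape
#   torch.Size([2, 77])
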
def encode_text_with_prompt_ensemble(model, texts, device):
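    """Encode an ensemble of normal/abnormal prompts for one class name.

    Each state phrase (normal or abnormal) is expanded through every template,
    encoded with `model.encode_text`, L2-normalized, and averaged into a single
    embedding per state. Returns a [2, embed_dim] tensor: row 0 is the normal
    ensemble embedding, row 1 the abnormal one.
    """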
prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
prompt_state = [prompt_normal, prompt_abnormal]
prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
text_features = []
for i in range(len(prompt_state)):
prompted_state = [state.format(texts[0]) for state in prompt_state[i]]
prompted_sentence = []
for s in prompted_state:
for template in prompt_templates:
prompted_sentence.append(template.format(s))
prompted_sentence = tokenize(prompted_sentence)
class_embeddings = model.encode_text(prompted_sentence.to(device))
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
class_embedding = class_embeddings.mean(dim=0)
class_embedding /= class_embedding.norm()
text_features.append(class_embedding)
text_features = torch.stack(text_features, dim=1).to(device).t()
return text_features
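
# Example (sketch, assuming `model` exposes a CLIP-style `encode_text` and the
# class-name list holds a single object category):
#   text_features = encode_text_with_prompt_ensemble(model, ["bottle"], device)
#   # text_features: [2, embed_dim]; cosine similarities against image features
#   # yield normal/abnormal scores.
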
def _get_clones(module, N):
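    """Return N independent deep copies of `module` as a ModuleList."""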
    return nn.ModuleList([deepcopy(module) for _ in range(N)])
class AnomalyCLIP_PromptLearner(nn.Module):
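    """Learnable prompt constructor in the AnomalyCLIP style.

    Builds object-agnostic positive (normal) and negative (abnormal) prompts of
    the form "X X ... X <state> object.", where the X tokens are learnable
    context vectors, plus deep compound prompts for later text-encoder layers.
    """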
def __init__(self, clip_model, design_details):
super().__init__()
classnames = ["object"]
self.n_cls = len(classnames)
self.n_ctx = design_details["Prompt_length"]
n_ctx_pos = self.n_ctx
n_ctx_neg = self.n_ctx
self.text_encoder_n_ctx = design_details["learnabel_text_embedding_length"]
ctx_init_pos = ""
ctx_init_neg = ""
dtype = clip_model.transformer.get_cast_dtype()
ctx_dim = clip_model.ln_final.weight.shape[0]
self.classnames = classnames
self.state_normal_list = [
"{}",
]
self.state_anomaly_list = [
"damaged {}",
]
normal_num = len(self.state_normal_list)
anormaly_num = len(self.state_anomaly_list)
self.normal_num = normal_num
self.anormaly_num = anormaly_num
if ctx_init_pos and ctx_init_neg:
# use given words to initialize context vectors
ctx_init_pos = ctx_init_pos.replace("_", " ")
ctx_init_neg = ctx_init_neg.replace("_", " ")
n_ctx_pos = len(ctx_init_pos.split(" "))
n_ctx_neg = len(ctx_init_neg.split(" "))
            # Tokenize the init strings into BPE token ids.
prompt_pos = tokenize(ctx_init_pos)
prompt_neg = tokenize(ctx_init_neg)
with torch.no_grad():
                # Generate the corresponding frozen text embeddings.
embedding_pos = clip_model.token_embedding(prompt_pos).type(dtype)
embedding_neg = clip_model.token_embedding(prompt_neg).type(dtype)
            # Keep only the context tokens (dropping SOS/EOS) as the learnable textual prompt.
ctx_vectors_pos = embedding_pos[0, 1: 1 + n_ctx_pos, :]
ctx_vectors_neg = embedding_neg[0, 1: 1 + n_ctx_neg, :]
prompt_prefix_pos = ctx_init_pos
prompt_prefix_neg = ctx_init_neg
            # Class-specific contexts: clone the initialized context vectors once per class.
            ctx_vectors_pos_ = []
            ctx_vectors_neg_ = []
            for _ in range(self.n_cls):
                ctx_vectors_pos_.append(deepcopy(ctx_vectors_pos))
                ctx_vectors_neg_.append(deepcopy(ctx_vectors_neg))
            ctx_vectors_pos = torch.stack(ctx_vectors_pos_, dim=0)
            ctx_vectors_neg = torch.stack(ctx_vectors_neg_, dim=0)
        else:
            # Random initialization of class-specific contexts. n_cls is the
            # number of classes, n_ctx_* is the number of learnable tokens, and
            # ctx_dim is the prompt embedding dimension.
            print("Initializing class-specific contexts")
            ctx_vectors_pos = torch.empty(self.n_cls, self.normal_num, n_ctx_pos, ctx_dim, dtype=dtype)
            ctx_vectors_neg = torch.empty(self.n_cls, self.anormaly_num, n_ctx_neg, ctx_dim, dtype=dtype)
nn.init.normal_(ctx_vectors_pos, std=0.02)
nn.init.normal_(ctx_vectors_neg, std=0.02)
prompt_prefix_pos = " ".join(["X"] * n_ctx_pos)
prompt_prefix_neg = " ".join(["X"] * n_ctx_neg)
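
        # Deep ("compound") prompts: extra learnable token sequences intended for
        # the first (depth - 1) text-encoder layers, with per-layer linear
        # projections to width 896 (value taken from the source; presumably the
        # dimension expected by the consuming branch).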
self.compound_prompts_depth = design_details["learnabel_text_embedding_depth"]
self.compound_prompts_text = nn.ParameterList([nn.Parameter(torch.empty(self.text_encoder_n_ctx, ctx_dim))
for _ in range(self.compound_prompts_depth - 1)])
for single_para in self.compound_prompts_text:
print("single_para", single_para.shape)
nn.init.normal_(single_para, std=0.02)
single_layer = nn.Linear(ctx_dim, 896)
self.compound_prompt_projections = _get_clones(single_layer, self.compound_prompts_depth - 1)
self.ctx_pos = nn.Parameter(ctx_vectors_pos) # to be optimized
self.ctx_neg = nn.Parameter(ctx_vectors_neg) # to be optimized
classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts_pos = [prompt_prefix_pos + " " + template.format(name) + "." for template in self.state_normal_list for name in classnames]
        prompts_neg = [prompt_prefix_neg + " " + template.format(name) + "." for template in self.state_anomaly_list for name in classnames]
tokenized_prompts_pos = []
tokenized_prompts_neg = []
for p_pos in prompts_pos:
tokenized_prompts_pos.append(tokenize(p_pos))
for p_neg in prompts_neg:
tokenized_prompts_neg.append(tokenize(p_neg))
tokenized_prompts_pos = torch.cat(tokenized_prompts_pos)
tokenized_prompts_neg = torch.cat(tokenized_prompts_neg)
        # Generate the frozen token embeddings for the full prompts.
with torch.no_grad():
embedding_pos = clip_model.token_embedding(tokenized_prompts_pos).type(dtype)
embedding_neg = clip_model.token_embedding(tokenized_prompts_neg).type(dtype)
n, l, d = embedding_pos.shape
print("embedding_pos", embedding_pos.shape)
embedding_pos = embedding_pos.reshape(normal_num, self.n_cls, l, d).permute(1, 0, 2, 3)
embedding_neg = embedding_neg.reshape(anormaly_num, self.n_cls, l, d).permute(1, 0, 2, 3)
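
        # Split the frozen embeddings around the learnable context window:
        # prefix = the SOS token embedding; suffix = class name, EOS, and padding.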
self.register_buffer("token_prefix_pos", embedding_pos[:, :, :1, :] )
self.register_buffer("token_suffix_pos", embedding_pos[:, :,1 + n_ctx_pos:, :])
self.register_buffer("token_prefix_neg", embedding_neg[:,:, :1, :])
self.register_buffer("token_suffix_neg", embedding_neg[:, :, 1 + n_ctx_neg:, :])
n, d = tokenized_prompts_pos.shape
tokenized_prompts_pos = tokenized_prompts_pos.reshape(normal_num, self.n_cls, d).permute(1, 0, 2)
n, d = tokenized_prompts_neg.shape
tokenized_prompts_neg = tokenized_prompts_neg.reshape(anormaly_num, self.n_cls, d).permute(1, 0, 2)
self.n_ctx_pos = n_ctx_pos
self.n_ctx_neg = n_ctx_neg
# tokenized_prompts = torch.cat([tokenized_prompts_pos, tokenized_prompts_neg], dim=0) # torch.Tensor
self.register_buffer("tokenized_prompts_pos", tokenized_prompts_pos)
self.register_buffer("tokenized_prompts_neg", tokenized_prompts_neg)
print("tokenized_prompts shape", self.tokenized_prompts_pos.shape, self.tokenized_prompts_neg.shape)
    def forward(self, cls_id=None):
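        """Assemble full prompts from frozen prefix/suffix and the learnable context.

        Returns (prompts, tokenized_prompts, compound_prompts_text); positive
        (normal) prompts are stacked before negative (abnormal) ones along dim 0.
        """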
ctx_pos = self.ctx_pos
ctx_neg = self.ctx_neg
# print("shape", self.ctx_pos[0:1].shape, ctx_pos.shape)
prefix_pos = self.token_prefix_pos
prefix_neg = self.token_prefix_neg
suffix_pos = self.token_suffix_pos
suffix_neg = self.token_suffix_neg
# print(prefix_pos.shape, prefix_neg.shape)
        prompts_pos = torch.cat(
            [
                prefix_pos,  # (n_cls, n_state, 1, dim): the SOS token embedding
                ctx_pos,     # (n_cls, n_state, n_ctx, dim): learnable context
                suffix_pos,  # (n_cls, n_state, *, dim): class name, EOS, padding
            ],
            dim=2,
        )
        prompts_neg = torch.cat(
            [
                prefix_neg,  # (n_cls, n_state, 1, dim)
                ctx_neg,     # (n_cls, n_state, n_ctx, dim)
                suffix_neg,  # (n_cls, n_state, *, dim)
            ],
            dim=2,
        )
_, _, l, d = prompts_pos.shape
prompts_pos = prompts_pos.reshape(-1, l, d)
_, _, l, d = prompts_neg.shape
prompts_neg = prompts_neg.reshape(-1, l, d)
prompts = torch.cat([prompts_pos, prompts_neg], dim=0)
_, l, d = self.tokenized_prompts_pos.shape
tokenized_prompts_pos = self.tokenized_prompts_pos.reshape(-1, d)
_, l, d = self.tokenized_prompts_neg.shape
tokenized_prompts_neg = self.tokenized_prompts_neg.reshape(-1, d)
        tokenized_prompts = torch.cat((tokenized_prompts_pos, tokenized_prompts_neg), dim=0)
return prompts, tokenized_prompts, self.compound_prompts_text
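
# Minimal usage sketch (assumes an AnomalyCLIP-style `clip_model`; the
# design_details values below are illustrative, not prescribed defaults):
#   design_details = {
#       "Prompt_length": 12,
#       "learnabel_text_embedding_length": 4,   # key spelling follows the repo
#       "learnabel_text_embedding_depth": 9,
#   }
#   prompt_learner = AnomalyCLIP_PromptLearner(clip_model, design_details)
#   prompts, tokenized_prompts, compound_prompts = prompt_learner()
#   # prompts:           [2, 77, ctx_dim] prompt embeddings (normal + abnormal)
#   # tokenized_prompts: [2, 77] token ids (used to locate the EOT position)
#   # compound_prompts:  list of depth-1 deep-prompt parameter tensors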