|
import torch


def extract_text_feature(prompt, model, processor, device='cpu'):
    """Extract text features for a prompt with OwlViT.

    Args:
        prompt: a single text query
        model: OwlViT model
        processor: OwlViT processor
        device (str, optional): device to run on. Defaults to 'cpu';
            upgraded to 'cuda' automatically when available.

    Returns:
        input_ids: token ids of the prompt
        query_embeds: L2-normalized text embedding of the prompt
    """
    if torch.cuda.is_available():
        device = 'cuda'
    # Keep the model on the same device as the inputs.
    model = model.to(device)
    with torch.no_grad():
        input_ids = torch.as_tensor(
            processor(text=prompt)['input_ids']).to(device)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # Pooled output, projected into the joint image-text embedding space.
        text_embeds = text_outputs[1]
        text_embeds = model.owlvit.text_projection(text_embeds)
        # L2-normalize; the epsilon guards against division by zero.
        text_embeds /= text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
        query_embeds = text_embeds
    return input_ids, query_embeds


def prompt2vec(prompt: str, model, processor):
    """Convert a prompt into a query vector.

    Args:
        prompt (str): text to be tokenized
        model: OwlViT model
        processor: OwlViT processor

    Returns:
        input_ids: token ids produced by the tokenizer
        xq: vector representing the original prompt
    """
    input_ids, xq = extract_text_feature(prompt, model, processor)
    input_ids = input_ids.detach().cpu().numpy()
    xq = xq.detach().cpu().numpy()
    return input_ids, xq
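
# Usage sketch for the text pipeline. This is a hedged example: it assumes
# the Hugging Face `transformers` OwlViT classes and the public
# "google/owlvit-base-patch32" checkpoint; any OwlViT checkpoint should work.
def _demo_prompt2vec():
    from transformers import OwlViTForObjectDetection, OwlViTProcessor
    processor = OwlViTProcessor.from_pretrained('google/owlvit-base-patch32')
    model = OwlViTForObjectDetection.from_pretrained(
        'google/owlvit-base-patch32')
    input_ids, xq = prompt2vec('a photo of a cat', model, processor)
    # For the base checkpoint, xq has one L2-normalized 512-dim query
    # vector per prompt.
    print(input_ids.shape, xq.shape)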
|
|
|
|
|
def tune(clf, X, y, iters=2):
    """Train the zero-shot classifier.

    Args:
        clf (Classifier): the classifier to be tuned
        X (numpy.ndarray): input vectors (retrieved vectors)
        y (list of floats or numpy.ndarray): scores given by the user
        iters (int, optional): number of update iterations to run

    Returns:
        numpy.ndarray: the refined query vectors
    """
    assert len(X) == len(y)
    clf.fit(X, y, iters=iters)
    return clf.get_weights()
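
# Sketch of the feedback loop (hypothetical names): `xq` comes from
# prompt2vec, `X` holds vectors of the retrieved images, and `y` holds the
# labels derived from user feedback. See _demo_classifier below for a
# runnable end-to-end example with synthetic data.
#
#   clf = Classifier(xq)
#   xq_refined = tune(clf, X, y, iters=2)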
|
|
|
|
|
class Classifier:
    """Multi-class zero-shot classifier.

    The classifier serves as a proxy for the user's reactions to the probed
    images. Its learned weights replace the query vectors generated from the
    prompts, so the user eventually gets a satisfying retrieval result.

    This pattern is common in recommendation systems: the classifier
    recommends more precisely as it accumulates user activity.

    This is a multi-class classifier. For N queries, labels 0..N-1 map to
    the corresponding query classes, and label N marks a negative example.
    """
|
|
|
    def __init__(self, xq):
        # Start from the prompt vectors so tuning refines the original
        # queries; xq is a (num_queries, dims) array.
        init_weight = torch.Tensor(xq)
        self.num_class = init_weight.shape[0]
        dims = init_weight.shape[1]
        self.model = torch.nn.Linear(dims, self.num_class, bias=False)
        self.model.weight = torch.nn.Parameter(init_weight)
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
|
|
|
    def fit(self, X: list, y: list, iters: int = 5):
        X = torch.Tensor(X)
        X /= torch.norm(X, p=2, dim=-1, keepdim=True)
        y = torch.Tensor(y).long()

        # Label `num_class` (and above) marks negative examples: their
        # one-hot rows are zeroed so every class is trained against them.
        non_ind = y >= self.num_class
        y = torch.nn.functional.one_hot(
            y % self.num_class, num_classes=self.num_class).float()
        y[non_ind] = 0

        for _ in range(iters):
            self.optimizer.zero_grad()
            # Keep the weights on the unit hypersphere so they stay
            # comparable to the normalized query vectors.
            self.model.weight.data /= torch.norm(
                self.model.weight.data, p=2, dim=-1, keepdim=True)
            out = self.model(X)
            loss = self.loss(out, y)
            loss.backward()
            self.optimizer.step()
|
|
|
    def get_weights(self):
        """Return the refined query vectors as a numpy array."""
        xq = self.model.weight.detach().numpy()
        return xq
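
# End-to-end sketch with synthetic data. Hedged: the 512-dim vectors and the
# labels below are made up for illustration; in practice `xq` comes from
# prompt2vec and `X` from the vector store.
def _demo_classifier():
    import numpy as np
    xq = np.random.randn(2, 512).astype('float32')   # two query prompts
    clf = Classifier(xq)
    X = np.random.randn(8, 512).astype('float32')    # "retrieved" vectors
    y = [0, 1, 0, 2, 1, 2, 0, 1]                     # 2 == negative label
    weights = tune(clf, X, y, iters=2)
    # One refined vector per query, same shape as the input prompts.
    print(weights.shape)                              # (2, 512)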
|
|
|
class SplitLayer(torch.nn.Module):
    """Split (batch, num_class) logits into one (batch, 1) tensor per class."""

    def forward(self, x):
        return torch.split(x, 1, dim=-1)
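
# Minimal sketch: SplitLayer turns the (batch, num_class) logits of
# `Classifier.model` into a tuple of per-class score tensors.
def _demo_split_layer():
    split = SplitLayer()
    scores = split(torch.randn(4, 3))
    # Tuple of three (4, 1) tensors, one score column per class.
    print(len(scores), scores[0].shape)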
|
|