File size: 4,835 Bytes
3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import torch
def extract_text_feature(prompt, model, processor, device="cpu"):
    """Embed a text prompt with the text branch of an OwlViT model.

    Args:
        prompt: a single text query
        model: OwlViT model; must expose ``owlvit.text_model`` and
            ``owlvit.text_projection``
        processor: OwlViT processor; called as ``processor(text=prompt)`` and
            expected to return a mapping with an ``"input_ids"`` entry
        device (str, optional): kept for backward compatibility only. The
            value is overridden: the token ids are placed on ``"cuda"`` when
            CUDA is available and on ``"cpu"`` otherwise.

    Returns:
        tuple: ``(input_ids, query_embeds)`` where ``input_ids`` is the token
        id tensor sent to the text model and ``query_embeds`` is the projected
        text embedding divided by its L2 norm (plus a small epsilon).
    """
    # NOTE(review): the caller-supplied `device` has never been honoured by
    # this function. The automatic choice is preserved so existing callers
    # (which rely on the default) keep working; presumably the model lives on
    # the same device -- confirm before honouring the argument.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with torch.no_grad():
        input_ids = torch.as_tensor(processor(text=prompt)["input_ids"]).to(device)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # Element 1 of the text-model output is what gets projected
        # (presumably the pooled sentence representation -- TODO confirm).
        text_embeds = model.owlvit.text_projection(text_outputs[1])
        # Out-of-place division: never mutates the projection's own output.
        # The epsilon guards against a zero-norm embedding.
        norm = text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
        query_embeds = text_embeds / norm
    return input_ids, query_embeds
def prompt2vec(prompt: str, model, processor):
    """Turn a text prompt into NumPy arrays usable as a query.

    Args:
        prompt (str): Text to be tokenized
        model: forwarded unchanged to ``extract_text_feature``
        processor: forwarded unchanged to ``extract_text_feature``

    Returns:
        tuple: ``(input_ids, xq)`` -- the two values produced by
        ``extract_text_feature``, each detached, moved to the CPU and
        converted to a NumPy array.
    """
    token_ids, embedding = extract_text_feature(prompt, model, processor)
    token_ids_np = token_ids.detach().cpu().numpy()
    embedding_np = embedding.detach().cpu().numpy()
    return token_ids_np, embedding_np
def tune(clf, X, y, iters=2):
    """Fit the zero-shot classifier and hand back its updated weights.

    Args:
        clf: classifier exposing ``fit(X, y, iters=...)`` and ``get_weights()``
        X (numpy.ndarray): Input vectors (retrieved vectors)
        y (list of floats or numpy.ndarray): Scores given by user
        iters (int, optional): iterations of updates to be run

    Returns:
        Whatever ``clf.get_weights()`` returns after fitting.
    """
    n_samples, n_scores = len(X), len(y)
    # every sample needs exactly one score
    assert n_samples == n_scores
    clf.fit(X, y, iters=iters)
    updated_weights = clf.get_weights()
    return updated_weights
class Classifier:
    """Multi-class zero-shot classifier.

    The classifier acts as a proxy for the user's reaction to the probed
    images. Its weights replace the query vectors originally generated from
    the prompts, so that retrieval results improve as more user feedback is
    accumulated (much like a recommendation system).

    Each of the N initial query vectors is one class (labels ``0 .. N-1``).
    Samples labelled ``num_class + 1`` are used as negatives for every class.

    NOTE(review): the negative label is ``num_class + 1`` in the code while
    the original description says "the last one" -- confirm against the
    caller which label value is actually produced for negatives.
    """

    # Step size of each gradient-descent update.
    LEARNING_RATE = 0.1

    def __init__(self, client, obj_db: str, xq: list):
        """Create the classifier from the initial query vectors.

        Args:
            client: database client exposing ``fetch(sql)``
            obj_db (str): name of the table holding the object rows
            xq: 2-D array-like of shape ``(num_class, dims)`` with the
                initial query vectors (NumPy array or nested list)
        """
        # Always work on a private float32 copy so the in-place updates in
        # ``fit`` can never write through to the caller's array.
        init_weight = torch.as_tensor(xq, dtype=torch.float32).detach().clone()
        self.num_class = init_weight.shape[0]
        self.DIMS = init_weight.shape[1]
        # the initial queries `xq` are the initial weights
        self.weight = init_weight
        self.client = client
        self.obj_db = obj_db

    def fit(self, X: list, y: list, iters: int = 5):
        """Run ``iters`` gradient-descent updates on the weights.

        Args:
            X (list): object ids of the rated samples
            y (list): class label of each sample, aligned with ``X``
            iters (int, optional): number of update steps
        """
        for _ in range(iters):
            # Normalize the weight before inference.
            # This will constrain the gradient or you will have an explosion
            # on the query vector.
            self.weight.data /= torch.norm(
                self.weight.data, p=2, dim=-1, keepdim=True
            )
            # Serialise the CURRENT (normalised) weights for the SQL query.
            # This must happen inside the loop: doing it once up front makes
            # every later iteration compute its gradient from stale weights.
            # A trailing 1 is appended to each vector so that it lines up
            # with the full ``prelogit`` array in the query.
            xq_s = [
                f"[{', '.join([str(float(fnum)) for fnum in _xq + [1]])}]"
                for _xq in self.weight.tolist()
            ]
            grad = torch.stack(
                [
                    self._class_gradient(n, X, y, xq_s[n])
                    for n in range(self.num_class)
                ],
                dim=0,
            )
            # update weights
            self.weight -= self.LEARNING_RATE * grad

    def _class_gradient(self, n, X, y, xq_literal):
        """Return the gradient for class ``n``, computed by the database."""
        # Keep the samples labelled with this class (target 1) or with the
        # negative label (target 0).
        selected = [
            (1 if label == n else 0, obj)
            for obj, label in zip(X, y)
            if label in (n, self.num_class + 1)
        ]
        if not selected:
            # No feedback for this class: the gradient over an empty sample
            # set is zero, so the class weights are left untouched.
            return torch.zeros(self.DIMS, dtype=self.weight.dtype)
        labels = [target for target, _ in selected]
        objs = [obj for _, obj in selected]
        # NOTE from @fangruil
        # Use SQL to calculate the gradient
        # For binary cross entropy we have
        # g = (1/(1+\exp(-XW))-Y)^TX
        # To simplify the query, we separated
        # the calculation into class numbers
        # NOTE(review): GT is paired with the rows by position, which
        # presumably relies on the rows coming back in the order of `objs`
        # -- confirm that the table guarantees this.
        grad_q_str = f"""
SELECT sumForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
FROM (
SELECT groupArray(arrayPopBack(prelogit)) AS X,
groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_literal}))))) AS Y, {labels} AS GT
FROM {self.obj_db} WHERE obj_id IN {objs})"""
        rows = self.client.fetch(grad_q_str)
        return torch.as_tensor(rows[0]['grad'], dtype=self.weight.dtype)

    def get_weights(self):
        """Return the current weights as a NumPy array.

        The result is a copy: later calls to ``fit`` do not modify arrays
        returned earlier.
        """
        xq = self.weight.detach().cpu().numpy().copy()
        return xq
class SplitLayer(torch.nn.Module):
    """Module that cuts its input into width-1 chunks along the last axis."""

    def forward(self, x):
        """Return the tuple of chunks given by ``torch.split(x, 1, dim=-1)``."""
        chunk_width = 1
        pieces = torch.split(x, chunk_width, dim=-1)
        return pieces
|