import torch


def extract_text_feature(prompt, model, processor, device="cpu"):
    """Extract text features from a single prompt with the OwlViT text encoder.

    Args:
        prompt: a single text query
        model: OwlViT model
        processor: OwlViT processor
        device (str, optional): device to run on. Defaults to 'cpu'.
    """
    # NOTE: the passed-in device is overridden by CUDA auto-detection below
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    with torch.no_grad():
        input_ids = torch.as_tensor(processor(text=prompt)["input_ids"]).to(device)
        print(input_ids.device)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # Pooled output -> text projection -> L2 normalization
        # (the epsilon avoids division by zero)
        text_embeds = text_outputs[1]
        text_embeds = model.owlvit.text_projection(text_embeds)
        text_embeds /= text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
        query_embeds = text_embeds
    return input_ids, query_embeds


def prompt2vec(prompt: str, model, processor):
    """Convert a text prompt into a query vector.

    Args:
        prompt (str): Text to be tokenized
        model: OwlViT model
        processor: OwlViT processor

    Returns:
        input_ids: token ids produced by the tokenizer
        xq: query vector representing the original prompt
    """
    input_ids, xq = extract_text_feature(prompt, model, processor)
    input_ids = input_ids.detach().cpu().numpy()
    xq = xq.detach().cpu().numpy()
    return input_ids, xq
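

# A minimal usage sketch for prompt2vec. The OwlViT checkpoint name and the sample
# prompt below are illustrative assumptions, not something this module depends on.
def _example_prompt2vec():
    from transformers import OwlViTForObjectDetection, OwlViTProcessor

    processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
    model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
    # Returns the token ids and one L2-normalized query vector for the prompt
    input_ids, xq = prompt2vec("a photo of a cat", model, processor)
    return input_ids, xq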


def tune(clf, X, y, iters=2):
    """Train the Zero-shot Classifier.

    Args:
        clf: Classifier to be tuned
        X (numpy.ndarray): Input vectors (retrieved vectors)
        y (list of floats or numpy.ndarray): Scores given by the user
        iters (int, optional): number of update iterations to run

    Returns:
        numpy.ndarray: the updated query weights
    """
    assert len(X) == len(y)

    clf.fit(X, y, iters=iters)

    return clf.get_weights()


class Classifier:
    """Multi-Class Zero-shot Classifier

    This classifier acts as a proxy for the user's reaction to the probed images.
    The proxy replaces the original query vector generated from the prompt and
    ultimately gives the user a more satisfying retrieval result.

    This pattern is common in recommendation systems: the classifier returns more
    precise results as it accumulates the user's activity.

    This is a multi-class classifier. For N queries it assigns the queries to the
    first N classes, and the last class is used for negative samples.
    """

    def __init__(self, client, obj_db: str, xq):
        # xq: initial query vectors (numpy.ndarray), one row per class
        init_weight = torch.Tensor(xq)
        self.num_class = xq.shape[0]
        self.DIMS = xq.shape[1]

        self.weight = init_weight
        self.client = client
        self.obj_db = obj_db

    def fit(self, X: list, y: list, iters: int = 5):
        # Serialize the current weights, with a fixed bias weight of 1 appended,
        # as SQL array literals usable inside the gradient query
        xq_s = [
            f"[{', '.join([str(float(fnum)) for fnum in _xq + [1]])}]"
            for _xq in self.get_weights().tolist()
        ]

        for _ in range(iters):
            grad = []
            # Keep each class weight vector at unit L2 norm
            self.weight.data /= torch.norm(
                self.weight.data, p=2, dim=-1, keepdim=True
            )
            for n in range(self.num_class):
                # Keep samples from class n or from the negative class, with binary
                # labels for a one-vs-negative logistic regression
                pairs = [
                    [1 if y[i] == n else 0, x]
                    for i, x in enumerate(X)
                    if y[i] in [n, self.num_class + 1]
                ]
                labels, objs = list(map(list, zip(*pairs)))
                # Push the gradient computation down to the database:
                # grad = sum_i x_i * (sigmoid(prelogit_i . w) - label_i)
                grad_q_str = f"""
                    SELECT sumForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
                    FROM (
                        SELECT groupArray(arrayPopBack(prelogit)) AS X,
                               groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y,
                               {labels} AS GT
                        FROM {self.obj_db} WHERE obj_id IN {objs})"""
                grad_ = [r['grad'] for r in self.client.query(grad_q_str).named_results()][0]
                grad.append(torch.as_tensor(grad_))

            # Stack per-class gradients and take a fixed-step gradient-descent update
            grad = torch.stack(grad, dim=0)
            self.weight -= 0.1 * grad

    def get_weights(self):
        xq = self.weight.detach().numpy()
        return xq
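

# A minimal reference sketch, in plain PyTorch, of the per-class gradient that
# Classifier.fit asks the database to compute; running it in-database means only
# the summed gradient travels back to the client. The tensor names are illustrative:
# `feats` corresponds to arrayPopBack(prelogit), `bias` to the last element of
# prelogit (whose weight is fixed to 1 by the `_xq + [1]` serialization), and
# `labels` to the GT array in the SQL query.
def _logistic_gradient_reference(feats, bias, labels, w):
    probs = torch.sigmoid(feats @ w + bias)  # Y in the SQL query
    # sumForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT))
    return (feats * (probs - labels).unsqueeze(-1)).sum(dim=0)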


class SplitLayer(torch.nn.Module):
    """Split a batched score tensor into size-1 chunks along the last dimension."""

    def forward(self, x):
        return torch.split(x, 1, dim=-1)
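

# A tiny illustration (assumed shapes) of what SplitLayer returns: a (4, 3) score
# tensor becomes three tensors of shape (4, 1), one per class/query.
def _example_split_layer():
    scores = torch.randn(4, 3)
    chunks = SplitLayer()(scores)
    assert len(chunks) == 3 and chunks[0].shape == (4, 1)
    return chunks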