import torch


def extract_text_feature(prompt, model, processor, device=None):
    """Extract L2-normalised OwlViT text embeddings for a single prompt.

    Args:
        prompt: a single text query, tokenized with ``processor(text=prompt)``.
        model: OwlViT model; ``model.owlvit.text_model`` and
            ``model.owlvit.text_projection`` are used.
        processor: OwlViT processor used to tokenize ``prompt``.
        device (str, optional): device the token ids are moved to. ``None``
            (the default) picks ``"cuda"`` when available and ``"cpu"``
            otherwise, which is what this function always did before it
            started honouring an explicit ``device``.

    Returns:
        tuple: ``(input_ids, query_embeds)`` -- the token-id tensor on
        ``device`` and the projected, L2-normalised text embeddings.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        input_ids = torch.as_tensor(processor(text=prompt)["input_ids"]).to(device)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # Element 1 of the text-model output is used as the sentence-level
        # representation (presumably the pooled output -- confirm against
        # the OwlViT output ordering).
        text_embeds = model.owlvit.text_projection(text_outputs[1])
        # Unit-normalise each embedding; the epsilon avoids division by zero.
        query_embeds = text_embeds / (
            text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
        )
    return input_ids, query_embeds


def prompt2vec(prompt: str, model, processor):
    """Convert a prompt into token ids and a query vector.

    Args:
        prompt (str): Text to be tokenized and embedded.
        model: OwlViT model, forwarded to ``extract_text_feature``.
        processor: OwlViT processor, forwarded to ``extract_text_feature``.

    Returns:
        tuple: ``(input_ids, xq)`` as NumPy arrays on the CPU -- the token ids
        and the normalised text embedding representing the prompt.
    """
    input_ids, xq = extract_text_feature(prompt, model, processor)
    return input_ids.detach().cpu().numpy(), xq.detach().cpu().numpy()


def tune(clf, X, y, iters=2):
    """Train the zero-shot classifier on user feedback and return its weights.

    Args:
        clf: classifier exposing ``fit(X, y, iters=...)`` and ``get_weights()``.
        X: training inputs passed straight to ``clf.fit``; for ``Classifier``
            these are the object ids of the retrieved items.
        y (list or numpy.ndarray): one label per element of ``X``, given by
            the user.
        iters (int, optional): number of update iterations to run.

    Returns:
        Whatever ``clf.get_weights()`` returns after training.
    """
    assert len(X) == len(y), "X and y must have the same length"
    clf.fit(X, y, iters=iters)
    return clf.get_weights()


class Classifier:
    """Multi-class zero-shot classifier refined by user feedback.

    The classifier acts as a proxy for the user's reaction to the probed
    images. Its weights start as the prompt-derived query vectors and are
    updated from the user's labels. The refined vectors can then replace
    the original queries and produce a more satisfying retrieval result,
    much like a recommendation system. Results improve as more user activity
    accumulates.

    With N queries, labels ``0 .. N-1`` mark the positive samples of each
    query, and label ``N + 1`` marks negative samples.
    """

    def __init__(self, client, obj_db: str, xq: list):
        """
        Args:
            client: database client; ``client.fetch(sql)`` must return an
                indexable of row mappings (``[0]["grad"]`` is read).
            obj_db (str): table holding the ``obj_id`` and ``prelogit`` columns.
            xq: 2-D array-like of initial query vectors, one row per class.
        """
        # Copy into a fresh float32 tensor so the in-place updates in ``fit``
        # never write back into the caller's array.
        self.weight = torch.tensor(xq, dtype=torch.float32)
        self.num_class = self.weight.shape[0]
        self.DIMS = self.weight.shape[1]
        self.client = client
        self.obj_db = obj_db

    def _weight_literals(self):
        """Format each weight row, with a trailing 1 appended, as a SQL array literal."""
        # ``prelogit`` has one more element than X (``arrayPopBack`` drops it);
        # the appended 1 multiplies that element, presumably a bias term.
        return [
            "[" + ", ".join(str(float(v)) for v in row + [1]) + "]"
            for row in self.weight.tolist()
        ]

    def _grad_query(self, xq_literal, labels, objs):
        """Build the SQL that computes the BCE gradient for one class."""
        # NOTE from @fangruil
        # Use SQL to calculate the gradient. For binary cross entropy we have
        #   g = (1/(1+exp(-XW)) - Y)^T X
        # To simplify the query, the calculation is split per class.
        # GT is looked up per row by obj_id, because groupArray collects rows
        # in scan order, not in the order of ``objs``.
        # NOTE(review): ``objs`` and ``obj_db`` are interpolated into SQL;
        # make sure they never come from untrusted input.
        return f"""
            SELECT sumForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
            FROM (
            SELECT groupArray(arrayPopBack(prelogit)) AS X,
            groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_literal}))))) AS Y,
            groupArray(arrayElement({labels}, indexOf({objs}, obj_id))) AS GT
            FROM {self.obj_db}
            WHERE obj_id IN {objs})"""

    def fit(self, X: list, y: list, iters: int = 5, lr: float = 0.1):
        """Run ``iters`` gradient-descent steps on a per-class BCE loss.

        Args:
            X: object ids of the labelled samples (matched against ``obj_id``).
            y: one label per element of ``X``. Use ``n`` (0 <= n < num_class)
                for a positive of class ``n`` and ``num_class + 1`` for a
                negative; other labels are ignored.
            iters: number of update steps.
            lr: gradient-descent step size (previously fixed at 0.1).
        """
        # NOTE(review): negatives use label ``num_class + 1``, not
        # ``num_class``; kept as-is -- confirm against the labelling caller.
        neg_label = self.num_class + 1
        for _ in range(iters):
            # Normalise before inference to constrain the gradient; otherwise
            # the query vectors explode.
            self.weight /= self.weight.norm(p=2, dim=-1, keepdim=True)
            # Rebuild the SQL literals from the *current* weights each step.
            xq_s = self._weight_literals()
            grads = []
            for n in range(self.num_class):
                samples = [
                    (int(label == n), obj)
                    for obj, label in zip(X, y)
                    if label in (n, neg_label)
                ]
                if not samples:
                    # No feedback for this class: leave its weights unchanged.
                    grads.append(torch.zeros(self.DIMS, dtype=self.weight.dtype))
                    continue
                labels, objs = map(list, zip(*samples))
                rows = self.client.fetch(self._grad_query(xq_s[n], labels, objs))
                grads.append(torch.as_tensor(rows[0]["grad"], dtype=self.weight.dtype))
            self.weight -= lr * torch.stack(grads, dim=0)

    def get_weights(self):
        """Return a NumPy copy of the current weights, one row per class."""
        # Copy so later ``fit`` calls do not silently mutate returned arrays.
        return self.weight.detach().clone().numpy()


class SplitLayer(torch.nn.Module):
    """Split a tensor along its last dimension into size-1 chunks."""

    def forward(self, x):
        return torch.split(x, 1, dim=-1)