File size: 4,835 Bytes
3f1124e
 
 
88c0383
3f1124e
 
 
 
 
 
 
 
88c0383
3f1124e
88c0383
3f1124e
88c0383
3f1124e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88c0383
3f1124e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88c0383
3f1124e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88c0383
3f1124e
88c0383
3f1124e
 
 
 
88c0383
3f1124e
 
88c0383
3f1124e
88c0383
 
 
3f1124e
 
 
88c0383
 
 
 
 
 
3f1124e
88c0383
3f1124e
 
88c0383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f1124e
88c0383
 
3f1124e
 
88c0383
3f1124e
88c0383
 
 
3f1124e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch


def extract_text_feature(prompt, model, processor, device=None):
    """Extract L2-normalized text features for a prompt with an OwlViT model.

    Args:
        prompt: a single text query, tokenized via ``processor(text=prompt)``.
        model: OwlViT model; its ``owlvit.text_model`` and
            ``owlvit.text_projection`` sub-modules are used.
        processor: OwlViT processor used to tokenize ``prompt``.
        device (str or torch.device, optional): device the token ids are moved
            to before running the text model. When ``None`` (the default),
            ``"cuda"`` is used if available, otherwise ``"cpu"``.

    Returns:
        tuple: ``(input_ids, query_embeds)`` -- the token-id tensor on ``device``
        and the projected text embeddings, each row divided by its L2 norm
        (plus ``1e-6`` to avoid division by zero).
    """
    if device is None:
        # Auto-select: prefer the GPU whenever one is present.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        input_ids = torch.as_tensor(processor(text=prompt)["input_ids"]).to(device)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # Element 1 of the text-model output is used as the sentence-level
        # embedding (presumably the pooled output -- confirm for the model used).
        text_embeds = model.owlvit.text_projection(text_outputs[1])
        # Normalize each row to unit length; the epsilon guards a zero vector.
        query_embeds = text_embeds / (
            text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
        )
    return input_ids, query_embeds


def prompt2vec(prompt: str, model, processor, device=None):
    """Convert a text prompt into token ids and a query vector.

    Args:
        prompt (str): Text to be tokenized and embedded.
        model: OwlViT model, forwarded to ``extract_text_feature``.
        processor: OwlViT processor, forwarded to ``extract_text_feature``.
        device (optional): forwarded unchanged to ``extract_text_feature``.

    Returns:
        tuple: ``(input_ids, xq)`` as NumPy arrays -- the token ids and the
        query vector representing the original prompt, both moved to the CPU.
    """
    input_ids, xq = extract_text_feature(prompt, model, processor, device=device)
    # Detach and move to host memory so callers receive plain NumPy arrays.
    return input_ids.detach().cpu().numpy(), xq.detach().cpu().numpy()


def tune(clf, X, y, iters=2):
    """Train the zero-shot classifier on user feedback and return its new weights.

    Args:
        clf: classifier exposing ``fit(X, y, iters=...)`` and ``get_weights()``
            (e.g. ``Classifier``).
        X (list or numpy.ndarray): training inputs, one per entry of ``y``
            (for ``Classifier`` these are object ids of retrieved items).
        y (list of floats or numpy.ndarray): labels/scores given by the user.
        iters (int, optional): number of update iterations passed to
            ``clf.fit``. Defaults to 2.

    Returns:
        The value of ``clf.get_weights()`` after fitting (the refined query
        vectors).

    Raises:
        ValueError: if ``X`` and ``y`` have different lengths.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length, got {len(X)} and {len(y)}"
        )
    # train the classifier
    clf.fit(X, y, iters=iters)
    # extract the refined query vectors
    return clf.get_weights()

class Classifier:
    """Multi-class zero-shot classifier refined by user feedback.

    This classifier acts as a proxy for the user's reaction to the probed
    images. Its weights start as the prompt-derived query vectors and are
    updated from feedback, so the refined vectors (``get_weights()``) can
    replace the original queries and give the user a more satisfying
    retrieval result.

    This is similar to a recommendation system: the more feedback the user
    gives, the more precise the recommended results become.

    With N queries, label ``n`` (``0 <= n < N``) marks a positive sample for
    query ``n``, and label ``N + 1`` marks a negative sample shared by every
    query.

    Gradients are computed in the database: ``client.fetch(sql)`` must return
    rows as dicts, and the gradient is read from ``rows[0]['grad']``. The table
    ``obj_db`` is queried through its ``obj_id`` and ``prelogit`` columns.
    """

    def __init__(self, client, obj_db: str, xq: list):
        # Copy the initial queries: ``torch.Tensor(ndarray)`` can share memory
        # with the caller's array, and ``fit`` updates the weights in place.
        init_weight = torch.as_tensor(xq, dtype=torch.get_default_dtype()).clone()
        # xq is expected to be 2-D: one row per query (class), one column per feature.
        self.num_class = init_weight.shape[0]
        self.DIMS = init_weight.shape[1]
        # The initial query vectors become the classifier weights.
        self.weight = init_weight
        self.client = client
        self.obj_db = obj_db

    def fit(self, X: list, y: list, iters: int = 5, lr: float = 0.1):
        """Refine the query weights with ``iters`` gradient-descent steps.

        Args:
            X (list): object ids (matched against ``obj_id`` in ``self.obj_db``),
                one per label in ``y``.
            y (list): label of each object -- ``n`` for a positive sample of
                query ``n``, ``num_class + 1`` for a negative sample.
            iters (int, optional): number of update steps. Defaults to 5.
            lr (float, optional): step size of each update. Defaults to 0.1.
        """
        for _ in range(iters):
            # Normalize the weight before inference.
            # This constrains the gradient; otherwise the query vector can explode.
            self.weight.data /= torch.norm(
                self.weight.data, p=2, dim=-1, keepdim=True
            )
            # Serialize the *current* weights every step so the SQL gradient is
            # evaluated at the point being updated.
            xq_s = self._weights_as_sql_arrays()
            grad = torch.stack(
                [self._class_gradient(n, X, y, xq_s[n]) for n in range(self.num_class)],
                dim=0,
            )
            # update weights
            self.weight -= lr * grad

    def _weights_as_sql_arrays(self):
        """Render each weight row, with a constant 1 appended, as a SQL array literal."""
        # The trailing 1 is multiplied with the extra last element of ``prelogit``
        # (the one ``arrayPopBack`` drops) -- presumably a bias term; confirm.
        return [
            "[" + ", ".join(str(float(value)) for value in row + [1]) + "]"
            for row in self.get_weights().tolist()
        ]

    def _class_gradient(self, n, X, y, xq_literal):
        """Return the binary cross-entropy gradient for query ``n``, computed in SQL."""
        # Positives (label n) get target 1, shared negatives (label num_class + 1)
        # get target 0; every other sample is ignored for this query.
        # NOTE(review): no query ever uses samples labelled exactly ``num_class``;
        # confirm the negative label is meant to be ``num_class + 1``.
        negative = self.num_class + 1
        labels, objs = [], []
        for i, obj in enumerate(X):
            if y[i] == n or y[i] == negative:
                labels.append(1 if y[i] == n else 0)
                objs.append(obj)
        if not objs:
            # No feedback for this query: contribute a zero gradient.
            return torch.zeros(self.DIMS, dtype=self.weight.dtype)

        # NOTE from @fangruil
        # Use SQL to calculate the gradient
        # For binary cross entropy we have
        #   g = (1/(1+\exp(-XW))-Y)^TX
        # To simplify the query, we separated
        # the calculation into class numbers
        # NOTE(review): the table name and object ids are interpolated into the
        # SQL text, so they must come from trusted sources.
        grad_q_str = f"""
                    SELECT sumForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
                    FROM (
                        SELECT groupArray(arrayPopBack(prelogit)) AS X, 
                            groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_literal}))))) AS Y, {labels} AS GT
                        FROM {self.obj_db} WHERE obj_id IN {objs})"""
        return torch.as_tensor(
            self.client.fetch(grad_q_str)[0]['grad'], dtype=self.weight.dtype
        )

    def get_weights(self):
        """Return a copy of the current weights as a NumPy array of shape (num_class, DIMS)."""
        # Clone so callers (e.g. ``tune``) never alias the live, in-place-updated weights.
        return self.weight.detach().clone().numpy()



class SplitLayer(torch.nn.Module):
    """Split the input into width-1 slices along its last dimension.

    ``forward`` returns a tuple with one tensor per position of the last axis;
    each slice keeps that axis with size 1.
    """

    def forward(self, x):
        # Tensor.split is the method form of torch.split.
        slices = x.split(1, dim=-1)
        return slices