File size: 3,311 Bytes
3d2142b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cee877
3d2142b
 
 
 
4cee877
3d2142b
 
 
 
 
 
 
 
 
 
 
 
 
 
825a49c
3d2142b
 
 
825a49c
3d2142b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Concept projector."""

import pickle

import numpy as np
import torch
from torch import nn


class ConceptProjector(nn.Module):
    """Encode and decode concept using CLIP.

    Embeddings are L2-normalized along the last dimension and multiplied by
    a projection matrix loaded from a pickle file, giving one logit per
    concept. ``src_weights`` are used by :meth:`encode_src` and
    :meth:`decode`; ``tgt_weights`` are used by :meth:`encode_tgt`.

    The weights are stored as plain tensor attributes and are moved lazily
    to the device of the incoming embeddings (see :meth:`maybe_convert`).
    All three attributes are ``None`` until a weight file has been loaded.
    """

    def __init__(self, src_weights=None, tgt_weights=None):
        """Create the projector, optionally loading weights right away.

        Args:
            src_weights: Path to the pickled source weights, or ``None``.
            tgt_weights: Path to the pickled target weights, or ``None``.
        """
        super().__init__()
        # Define every attribute up front so that a projector built without
        # weight files fails with a clear message later on instead of an
        # opaque ``AttributeError`` raised by ``nn.Module.__getattr__``.
        self.src_weights = None
        self.tgt_weights = None
        self.concepts = None
        self.reset_weights(src_weights, tgt_weights)

    @staticmethod
    def _load_weights(path):
        """Read a pickled ``(weights, concepts)`` pair from ``path``.

        Returns:
            A ``(torch.Tensor, numpy.ndarray)`` tuple.
        """
        # NOTE(security): ``pickle.load`` can execute arbitrary code while
        # unpickling; only use weight files that come from a trusted source.
        with open(path, "rb") as f:
            weights, concepts = pickle.load(f)
        return torch.from_numpy(weights), np.array(concepts)

    def reset_weights(self, src_weights=None, tgt_weights=None):
        """Reset the normalized projection weights.

        A falsy path leaves the corresponding weights untouched. Each file
        also carries the concept list; when both paths are given, the list
        stored with ``tgt_weights`` is the one that is kept.

        Args:
            src_weights: Path to the pickled source weights, or ``None``.
            tgt_weights: Path to the pickled target weights, or ``None``.
        """
        if src_weights:
            self.src_weights, self.concepts = self._load_weights(src_weights)
        if tgt_weights:
            self.tgt_weights, self.concepts = self._load_weights(tgt_weights)

    @staticmethod
    def maybe_convert(embeds, proj):
        """Convert inputs for safe projection.

        Args:
            embeds: Embedding tensor; cast to ``float32`` if it is not already.
            proj: Projection tensor; moved to the device of ``embeds`` if the
                two live on different devices.

        Returns:
            The ``(embeds, proj)`` pair after conversion.
        """
        if embeds.dtype != torch.float32:
            embeds = embeds.float()
        if embeds.device != proj.device:
            proj = proj.to(device=embeds.device)
        return embeds, proj

    def _project(self, embeds, name):
        """Return the concept logits of ``embeds`` using the weights ``name``.

        Raises:
            RuntimeError: If the requested weights have not been loaded.
        """
        weights = getattr(self, name)
        if weights is None:
            raise RuntimeError(
                f"{name} are not loaded; pass a weight file to reset_weights() first."
            )
        embeds, weights = self.maybe_convert(embeds, weights)
        # Keep the (possibly device-moved) weights so that the transfer is
        # paid only once for subsequent calls on the same device.
        setattr(self, name, weights)
        return nn.functional.normalize(embeds, dim=-1) @ weights

    def encode_src(self, src_embeds, logpi=True):
        """Encode source visual embedding via concept projection.

        Args:
            src_embeds: Source embedding tensor.
            logpi: Return log-probabilities if ``True``, raw logits otherwise.

        Returns:
            A tensor with one value per concept along the last dimension.

        Raises:
            RuntimeError: If the source weights have not been loaded.
        """
        logits = self._project(src_embeds, "src_weights")
        return nn.functional.log_softmax(logits, dim=-1) if logpi else logits

    def encode_tgt(self, tgt_embeds):
        """Encode target visual embedding via concept projection.

        Args:
            tgt_embeds: Target embedding tensor.

        Returns:
            Log-probabilities over the concepts along the last dimension.

        Raises:
            RuntimeError: If the target weights have not been loaded.
        """
        logits = self._project(tgt_embeds, "tgt_weights")
        return nn.functional.log_softmax(logits, dim=-1)

    def decode(self, src_embeds, k=1, return_index=False, return_prob=False):
        """Return the top-k concepts of source visual embedding.

        Args:
            src_embeds: Source embedding tensor.
            k: Number of concepts to return per embedding.
            return_index: Return concept indices instead of concept values.
            return_prob: Return the full probability array and nothing else;
                ``k`` and ``return_index`` are ignored in that case.

        Returns:
            The probabilities as a NumPy array if ``return_prob`` is set,
            otherwise a ``(concepts_or_indices, scores)`` tuple of NumPy
            arrays whose last dimension has size ``k``.

        Raises:
            RuntimeError: If the source weights have not been loaded.
        """
        logits = self._project(src_embeds, "src_weights")
        # Detach first: ``Tensor.numpy()`` refuses tensors that require grad.
        probs = nn.functional.softmax(logits, dim=-1).detach()
        if return_prob:
            return probs.cpu().numpy()
        score, index = [x.cpu().numpy() for x in probs.topk(k, -1)]
        return (index if return_index else self.concepts[index]), score