ChenWu98 committed on
Commit
334ea23
1 Parent(s): ed1066d

Create seq_aligner.py

Browse files
Files changed (1) hide show
  1. seq_aligner.py +196 -0
seq_aligner.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import numpy as np
16
+
17
+
18
class ScoreParams:
    """Scores used by the alignment: gap penalty, match score, mismatch score."""

    def __init__(self, gap, match, mismatch):
        self.gap, self.match, self.mismatch = gap, match, mismatch

    def mis_match_char(self, x, y):
        """Return the mismatch score when ``x != y``, otherwise the match score."""
        return self.mismatch if x != y else self.match
30
+
31
+
32
def get_matrix(size_x, size_y, gap):
    """Build the initial alignment score table as a list of lists.

    NOTE: a NumPy-based function with the same name is defined later in this
    file and replaces this one at import time; this pure-Python variant is
    kept for reference.

    Args:
        size_x: Number of rows minus one. Either an integer size or a sized
            sequence whose length is used (the historical calling style).
        size_y: Number of columns minus one; same conventions as ``size_x``.
        gap: Gap penalty; cell ``[0][j]`` is ``j * gap`` and ``[i][0]`` is
            ``i * gap``.

    Returns:
        A ``(size_x + 1) x (size_y + 1)`` list of lists, zero everywhere
        except for the first row and first column.
    """
    # Accept both plain sizes and the sequences themselves, so this agrees
    # with the NumPy variant (which takes integer sizes) without breaking
    # callers that passed sequences.
    rows = len(size_x) if hasattr(size_x, "__len__") else int(size_x)
    cols = len(size_y) if hasattr(size_y, "__len__") else int(size_y)

    matrix = [[0] * (cols + 1) for _ in range(rows + 1)]
    for j in range(1, cols + 1):
        matrix[0][j] = j * gap
    for i in range(1, rows + 1):
        matrix[i][0] = i * gap
    return matrix
44
+
45
+
46
def get_matrix(size_x, size_y, gap):
    """Create the int32 score table of shape ``(size_x + 1, size_y + 1)``.

    The first row and first column hold cumulative gap penalties
    (``k * gap`` at offset ``k``); every other cell starts at zero.
    """
    shape = (size_x + 1, size_y + 1)
    scores = np.zeros(shape, dtype=np.int32)
    row_penalties = (np.arange(size_y) + 1) * gap
    col_penalties = (np.arange(size_x) + 1) * gap
    scores[0, 1:] = row_penalties
    scores[1:, 0] = col_penalties
    return scores
51
+
52
+
53
def get_traceback_matrix(size_x, size_y):
    """Create the int32 traceback table of shape ``(size_x + 1, size_y + 1)``.

    Border cells are pre-filled with move codes: ``1`` along the first row,
    ``2`` down the first column and ``4`` at the origin. Interior cells start
    at zero.
    """
    moves = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    moves[0, 0] = 4
    moves[0, 1:] = 1
    moves[1:, 0] = 2
    return moves
59
+
60
+
61
def global_align(x, y, score):
    """Fill the score and traceback tables for a global alignment of x and y.

    Args:
        x: First sequence.
        y: Second sequence.
        score: Object providing ``gap`` and ``mis_match_char(a, b)``.

    Returns:
        ``(matrix, trace_back)``. Each interior traceback cell holds ``1``
        (came from the left), ``2`` (came from above) or ``3`` (came from the
        diagonal). Ties are resolved in that order.
    """
    n_x, n_y = len(x), len(y)
    matrix = get_matrix(n_x, n_y, score.gap)
    trace_back = get_traceback_matrix(n_x, n_y)
    for i, x_item in enumerate(x, start=1):
        for j, y_item in enumerate(y, start=1):
            from_left = matrix[i, j - 1] + score.gap
            from_up = matrix[i - 1, j] + score.gap
            from_diag = matrix[i - 1, j - 1] + score.mis_match_char(x_item, y_item)
            matrix[i, j] = max(from_left, from_up, from_diag)
            # Compare against the value as stored in the table, not the raw max.
            stored = matrix[i, j]
            if stored == from_left:
                move = 1
            elif stored == from_up:
                move = 2
            else:
                move = 3
            trace_back[i, j] = move
    return matrix, trace_back
77
+
78
+
79
def get_aligned_sequences(x, y, trace_back):
    """Follow the traceback table from the bottom-right cell to recover the alignment.

    Args:
        x: First sequence (indexable).
        y: Second sequence (indexable).
        trace_back: 2-D array indexed as ``trace_back[i, j]`` holding move
            codes: ``3`` diagonal (x[i-1] paired with y[j-1]), ``1`` left
            (y[j-1] paired with a gap), ``2`` up (x[i-1] paired with a gap),
            ``4`` stop.

    Returns:
        ``(x_seq, y_seq, mapper)``. ``x_seq`` and ``y_seq`` are the aligned
        sequences with ``'-'`` marking gaps; they are collected while walking
        backwards and are NOT reversed, so they run from the end of the
        alignment to its start. ``mapper`` is an int64 tensor of shape
        ``(k, 2)`` in ascending ``y`` order whose rows are
        ``(y_index, x_index)``, with ``x_index == -1`` where the ``y`` element
        is aligned to a gap.

    Raises:
        ValueError: If a visited cell holds a code other than 1, 2, 3 or 4.
            Previously such a cell made the loop spin forever.
    """
    x_seq, y_seq, mapper_y_to_x = [], [], []
    i, j = len(x), len(y)
    while i > 0 or j > 0:
        move = trace_back[i, j]
        if move == 3:
            i -= 1
            j -= 1
            x_seq.append(x[i])
            y_seq.append(y[j])
            mapper_y_to_x.append((j, i))
        elif move == 1:
            j -= 1
            x_seq.append('-')
            y_seq.append(y[j])
            mapper_y_to_x.append((j, -1))
        elif move == 2:
            i -= 1
            x_seq.append(x[i])
            y_seq.append('-')
        elif move == 4:
            break
        else:
            # Neither index would change, so bail out instead of looping forever.
            raise ValueError(f"unexpected traceback code {move} at cell ({i}, {j})")
    mapper_y_to_x.reverse()
    # Explicit reshape keeps the (k, 2) layout even when no pairs were
    # collected, so callers can always slice columns.
    mapper = torch.tensor(mapper_y_to_x, dtype=torch.int64).reshape(len(mapper_y_to_x), 2)
    return x_seq, y_seq, mapper
105
+
106
+
107
def get_mapper(x: str, y: str, tokenizer, max_len=77):
    """Align the tokens of ``y`` to the tokens of ``x``.

    Args:
        x: Source prompt.
        y: Target prompt.
        tokenizer: Object whose ``encode(text)`` returns a token sequence.
        max_len: Length of the returned tensors.

    Returns:
        ``(mapper, alphas)``, both of length ``max_len``. For each aligned
        position of ``y``, ``mapper`` holds the matching token index in ``x``
        (``-1`` when there is none) and ``alphas`` holds ``1.0`` when a match
        exists and ``0.0`` otherwise. Positions past the aligned part get
        consecutive indices starting at the token count of ``y`` in
        ``mapper`` and ``1.0`` in ``alphas``.
    """
    tokens_x = tokenizer.encode(x)
    tokens_y = tokenizer.encode(y)
    # Scoring: gap 0, match +1, mismatch -1.
    params = ScoreParams(0, 1, -1)
    _, trace_back = global_align(tokens_x, tokens_y, params)
    mapper_base = get_aligned_sequences(tokens_x, tokens_y, trace_back)[-1]

    n_aligned = mapper_base.shape[0]
    n_target = len(tokens_y)

    alphas = torch.ones(max_len)
    alphas[:n_aligned] = mapper_base[:, 1].ne(-1).float()

    mapper = torch.zeros(max_len, dtype=torch.int64)
    mapper[:n_aligned] = mapper_base[:, 1]
    mapper[n_aligned:] = n_target + torch.arange(max_len - n_target)
    return mapper, alphas
119
+
120
+
121
def get_refinement_mapper(prompts, tokenizer, max_len=77):
    """Compute ``get_mapper`` between the first prompt and every later prompt.

    Returns:
        ``(mappers, alphas)``: the per-prompt results stacked along a new
        leading dimension, in the order of ``prompts[1:]``.
    """
    source = prompts[0]
    results = [get_mapper(source, target, tokenizer, max_len) for target in prompts[1:]]
    mappers = [mapper for mapper, _ in results]
    alphas = [alpha for _, alpha in results]
    return torch.stack(mappers), torch.stack(alphas)
129
+
130
+
131
def get_word_inds(text: str, word_place: int, tokenizer):
    """Find the token positions that belong to selected words of ``text``.

    Args:
        text: Prompt; words are obtained with ``text.split(" ")``.
        word_place: A word index (``int``), a word (``str``, every occurrence
            is selected), or a collection of word indices.
        tokenizer: Object providing ``encode(text)`` and ``decode([token])``.

    Returns:
        NumPy array of token positions. Positions are counted over the full
        encoded sequence; the first and last encoded tokens are skipped
        (presumably start/end markers -- confirm against the tokenizer).
    """
    words = text.split(" ")
    if type(word_place) is str:
        word_place = [idx for idx, word in enumerate(words) if word_place == word]
    elif type(word_place) is int:
        word_place = [word_place]

    if not len(word_place) > 0:
        return np.array([])

    decoded = [tokenizer.decode([token]).strip("#") for token in tokenizer.encode(text)]
    pieces = decoded[1:-1]

    token_inds = []
    consumed = 0   # characters of the current word covered so far
    word_idx = 0   # word that the current token belongs to
    # Positions start at 1 because the first encoded token was dropped above.
    for position, piece in enumerate(pieces, start=1):
        consumed += len(piece)
        if word_idx in word_place:
            token_inds.append(position)
        if consumed >= len(words[word_idx]):
            word_idx += 1
            consumed = 0
    return np.array(token_inds)
150
+
151
+
152
def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
    """Build a ``(max_len, max_len)`` token mapping for a word-swap edit.

    Both prompts must have the same number of space-separated words. Words
    that differ are looked up with ``get_word_inds`` in each prompt. Their
    token groups are linked with weight 1 when the groups have the same size,
    otherwise with weight ``1 / len(target group)``. All other positions are
    linked one-to-one.

    Raises:
        ValueError: If the prompts have different word counts.

    Returns:
        Float tensor of shape ``(max_len, max_len)``.
    """
    source_words = x.split(' ')
    target_words = y.split(' ')
    if len(source_words) != len(target_words):
        raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
                         f" but prompt A has {len(source_words)} words and prompt B has {len(target_words)} words.")

    changed = [k for k in range(len(target_words)) if target_words[k] != source_words[k]]
    source_groups = [get_word_inds(x, k, tokenizer) for k in changed]
    target_groups = [get_word_inds(y, k, tokenizer) for k in changed]
    n_groups = len(source_groups)

    mapper = np.zeros((max_len, max_len))
    row = col = 0
    group = 0
    while row < max_len and col < max_len:
        if group >= n_groups:
            # Every replaced word has been handled: identity on the column index.
            mapper[col, col] = 1
            row += 1
            col += 1
            continue

        src = source_groups[group]
        if src[0] != row:
            # Not yet at the next replaced word: link the current pair.
            mapper[row, col] = 1
            row += 1
            col += 1
            continue

        tgt = target_groups[group]
        if len(src) == len(tgt):
            mapper[src, tgt] = 1
        else:
            weight = 1 / len(tgt)
            for tgt_idx in tgt:
                mapper[src, tgt_idx] = weight
        group += 1
        row += len(src)
        col += len(tgt)

    return torch.from_numpy(mapper).float()
186
+
187
+
188
+
189
def get_replacement_mapper(prompts, tokenizer, max_len=77):
    """Stack the replacement mappers from the first prompt to each later prompt.

    Returns:
        Tensor with one ``get_replacement_mapper_`` result per entry of
        ``prompts[1:]``, stacked along a new leading dimension.
    """
    source = prompts[0]
    return torch.stack([
        get_replacement_mapper_(source, target, tokenizer, max_len)
        for target in prompts[1:]
    ])
196
+