dbaranchuk committed on
Commit fac3c6f · verified · 1 Parent(s): 8f3a280

Delete seq_aligner.py

Files changed (1)
seq_aligner.py +0 -181
seq_aligner.py DELETED
@@ -1,181 +0,0 @@
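-# Sequence-alignment utilities: map token positions between a source prompt and an edited prompt.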
-import torch
-import numpy as np
-
-
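-# Gap/match/mismatch scores used by the global alignment.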
-class ScoreParams:
-
-    def __init__(self, gap, match, mismatch):
-        self.gap = gap
-        self.match = match
-        self.mismatch = mismatch
-
-    def mis_match_char(self, x, y):
-        if x != y:
-            return self.mismatch
-        else:
-            return self.match
-
-
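-# Pure-Python score-matrix initializer; dead code, shadowed by the vectorized NumPy redefinition below.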
-def get_matrix(size_x, size_y, gap):
-    matrix = []
-    for i in range(size_x + 1):
-        sub_matrix = []
-        for j in range(size_y + 1):
-            sub_matrix.append(0)
-        matrix.append(sub_matrix)
-    for j in range(1, size_y + 1):
-        matrix[0][j] = j * gap
-    for i in range(1, size_x + 1):
-        matrix[i][0] = i * gap
-    return matrix
-
-
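-# Vectorized score-matrix initialization: first row/column pre-filled with cumulative gap penalties.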
-def get_matrix(size_x, size_y, gap):
-    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
-    matrix[0, 1:] = (np.arange(size_y) + 1) * gap
-    matrix[1:, 0] = (np.arange(size_x) + 1) * gap
-    return matrix
-
-
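-# Boundary traceback codes: first row = 1 (gap in x), first column = 2 (gap in y), origin = 4 (stop).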
-def get_traceback_matrix(size_x, size_y):
-    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
-    matrix[0, 1:] = 1
-    matrix[1:, 0] = 2
-    matrix[0, 0] = 4
-    return matrix
-
-
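-# Needleman-Wunsch-style global alignment of two token sequences; fills the score and traceback matrices.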
-def global_align(x, y, score):
-    matrix = get_matrix(len(x), len(y), score.gap)
-    trace_back = get_traceback_matrix(len(x), len(y))
-    for i in range(1, len(x) + 1):
-        for j in range(1, len(y) + 1):
-            left = matrix[i, j - 1] + score.gap
-            up = matrix[i - 1, j] + score.gap
-            diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
-            matrix[i, j] = max(left, up, diag)
-            if matrix[i, j] == left:
-                trace_back[i, j] = 1
-            elif matrix[i, j] == up:
-                trace_back[i, j] = 2
-            else:
-                trace_back[i, j] = 3
-    return matrix, trace_back
-
-
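-# Follows the traceback to recover the aligned sequences plus a y-to-x index mapper ((j, -1) where y[j] has no counterpart in x).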
-def get_aligned_sequences(x, y, trace_back):
-    x_seq = []
-    y_seq = []
-    i = len(x)
-    j = len(y)
-    mapper_y_to_x = []
-    while i > 0 or j > 0:
-        if trace_back[i, j] == 3:
-            x_seq.append(x[i - 1])
-            y_seq.append(y[j - 1])
-            i = i - 1
-            j = j - 1
-            mapper_y_to_x.append((j, i))
-        elif trace_back[i][j] == 1:
-            x_seq.append('-')
-            y_seq.append(y[j - 1])
-            j = j - 1
-            mapper_y_to_x.append((j, -1))
-        elif trace_back[i][j] == 2:
-            x_seq.append(x[i - 1])
-            y_seq.append('-')
-            i = i - 1
-        elif trace_back[i][j] == 4:
-            break
-    mapper_y_to_x.reverse()
-    return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
-
-
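-# Aligns the two encoded prompts and pads the y-to-x mapper and alpha weights out to max_len (alpha 0 marks tokens new in y).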
-def get_mapper(x: str, y: str, tokenizer, max_len=77):
-    x_seq = tokenizer.encode(x)
-    y_seq = tokenizer.encode(y)
-    score = ScoreParams(0, 1, -1)
-    matrix, trace_back = global_align(x_seq, y_seq, score)
-    mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
-    alphas = torch.ones(max_len)
-    alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
-    mapper = torch.zeros(max_len, dtype=torch.int64)
-    mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
-    mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
-    return mapper, alphas
-
-
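-# One mapper/alpha pair per edited prompt, each aligned against prompts[0].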
-def get_refinement_mapper(prompts, tokenizer, max_len=77):
-    x_seq = prompts[0]
-    mappers, alphas = [], []
-    for i in range(1, len(prompts)):
-        mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
-        mappers.append(mapper)
-        alphas.append(alpha)
-    return torch.stack(mappers), torch.stack(alphas)
-
-
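-# Token indices covering a given word (by position or by string) in text, offset by 1 for the BOS token.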
-def get_word_inds(text: str, word_place: int, tokenizer):
-    split_text = text.split(" ")
-    if type(word_place) is str:
-        word_place = [i for i, word in enumerate(split_text) if word_place == word]
-    elif type(word_place) is int:
-        word_place = [word_place]
-    out = []
-    if len(word_place) > 0:
-        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
-        cur_len, ptr = 0, 0
-
-        for i in range(len(words_encode)):
-            cur_len += len(words_encode[i])
-            if ptr in word_place:
-                out.append(i + 1)
-            if cur_len >= len(split_text[ptr]):
-                ptr += 1
-                cur_len = 0
-    return np.array(out)
-
-
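-# Builds a max_len x max_len correspondence matrix between source and target tokens; weight is spread uniformly when a replaced word tokenizes to a different number of tokens.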
-def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
-    words_x = x.split(' ')
-    words_y = y.split(' ')
-    if len(words_x) != len(words_y):
-        raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
-                         f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
-    inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
-    inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
-    inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
-    mapper = np.zeros((max_len, max_len))
-    i = j = 0
-    cur_inds = 0
-    while i < max_len and j < max_len:
-        if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
-            inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
-            if len(inds_source_) == len(inds_target_):
-                mapper[inds_source_, inds_target_] = 1
-            else:
-                ratio = 1 / len(inds_target_)
-                for i_t in inds_target_:
-                    mapper[inds_source_, i_t] = ratio
-            cur_inds += 1
-            i += len(inds_source_)
-            j += len(inds_target_)
-        elif cur_inds < len(inds_source):
-            mapper[i, j] = 1
-            i += 1
-            j += 1
-        else:
-            mapper[j, j] = 1
-            i += 1
-            j += 1
-
-    return torch.from_numpy(mapper).float()
-
-
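-# One replacement mapper per edited prompt, each against prompts[0].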
-def get_replacement_mapper(prompts, tokenizer, max_len=77):
-    x_seq = prompts[0]
-    mappers = []
-    for i in range(1, len(prompts)):
-        mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
-        mappers.append(mapper)
-    return torch.stack(mappers)
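
For context, the deleted module was driven from a list of prompts plus a tokenizer. A minimal usage sketch of how it was typically called before this deletion, assuming the Hugging Face transformers CLIPTokenizer; the checkpoint name and prompts below are illustrative, not taken from this repo:

import seq_aligner
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Refinement edit: the target prompt adds words to the source prompt.
prompts = ["a cat sitting on a chair", "a fluffy cat sitting on a chair"]
mappers, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer, max_len=77)

# Replacement edit: source and target must have the same word count.
prompts = ["a cat sitting on a chair", "a dog sitting on a chair"]
mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer, max_len=77)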