Daniel Gil-U Fuhge commited on
Commit
d9c6096
·
1 Parent(s): 076948a

add dataset helper

Browse files
Files changed (1) hide show
  1. dataset_helper.py +326 -0
dataset_helper.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import Tuple, Any
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+
8
+ # SEQUENCE GENERATION
9
+ PADDING_VALUE = float('-100')
10
+
11
+ # ANIMATION_PARAMETER_INDICES = {
12
+ # 0: [], # EOS
13
+ # 1: [10, 11, 12, 13], # translate: begin, dur, x, y
14
+ # 2: [10, 11, 14, 15], # curve: begin, dur, via_x, via_y
15
+ # 3: [10, 11, 16], # scale: begin, dur, from_factor
16
+ # 4: [10, 11, 17], # rotate: begin, dur, from_degree
17
+ # 5: [10, 11, 18], # skewX: begin, dur, from_x
18
+ # 6: [10, 11, 19], # skewY: begin, dur, from_y
19
+ # 7: [10, 11, 20, 21, 22], # fill: begin, dur, from_r, from_g, from_b
20
+ # 8: [10, 11, 23], # opcaity: begin, dur, from_f
21
+ # 9: [10, 11, 24], # blur: begin, dur, from_f
22
+ # }
23
+
24
+ ANIMATION_PARAMETER_INDICES = {
25
+ 0: [], # EOS
26
+ 1: [0, 1, 2, 3], # translate: begin, dur, x, y
27
+ 2: [0, 1, 4, 5], # curve: begin, dur, via_x, via_y
28
+ 3: [0, 1, 6], # scale: begin, dur, from_factor
29
+ 4: [0, 1, 7], # rotate: begin, dur, from_degree
30
+ 5: [0, 1, 8], # skewX: begin, dur, from_x
31
+ 6: [0, 1, 9], # skewY: begin, dur, from_y
32
+ 7: [0, 1, 10, 11, 12], # fill: begin, dur, from_r, from_g, from_b
33
+ 8: [0, 1, 13], # opcaity: begin, dur, from_f
34
+ 9: [0, 1, 14], # blur: begin, dur, from_f
35
+ }
36
+
37
+
38
+ def unpack_embedding(embedding: torch.Tensor, dim=0, device="cpu") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
39
+ """
40
+ Args:
41
+ device: cpu / gpu
42
+ dim: dimension where the embedding is positioned
43
+ embedding: embedding of dimension 270
44
+
45
+ Returns: tuple of tensors: deep-svg embedding, type of prediction, animation parameters
46
+
47
+ """
48
+ if embedding.shape[dim] != 282:
49
+ print(embedding.shape)
50
+ raise ValueError('Dimension of 270 required.')
51
+
52
+ if dim == 0:
53
+ deep_svg = embedding[: -26].to(device)
54
+ types = embedding[-26: -15].to(device)
55
+ parameters = embedding[-15:].to(device)
56
+
57
+ elif dim == 1:
58
+ deep_svg = embedding[:, : -26].to(device)
59
+ types = embedding[:, -26: -15].to(device)
60
+ parameters = embedding[:, -15:].to(device)
61
+
62
+ elif dim == 2:
63
+ deep_svg = embedding[:, :, : -26].to(device)
64
+ types = embedding[:, :, -26: -15].to(device)
65
+ parameters = embedding[:, :, -15:].to(device)
66
+
67
+ else:
68
+ raise ValueError('Dimension > 2 not possible.')
69
+ return deep_svg, types, parameters
70
+
71
+
72
+ def generate_dataset(dataframe_index: pd.DataFrame,
73
+ input_sequences_dict_used: dict,
74
+ input_sequences_dict_unused: dict,
75
+ output_sequences: pd.DataFrame,
76
+ logos_list: dict,
77
+ sequence_length_input: int,
78
+ sequence_length_output: int,
79
+ ) -> dict:
80
+ """
81
+ Builds the dataset and returns it
82
+
83
+ Args:
84
+ input_sequences_dict_used: dictionary containing input sequences per logo
85
+ input_sequences_dict_unused: dictionary containing all unused paths
86
+ dataframe_index: dataframe containing the relevant indexes for the dataframes
87
+ output_sequences: dataframe containing animations
88
+ logos_list: dictionary in train/test split containing list for logo ids
89
+ sequence_length_input: length of input sequence for padding
90
+ sequence_length_output: length of output sequence for padding
91
+
92
+ Returns: dictionary containing the dataset for training/testing
93
+
94
+ """
95
+ dataset = {
96
+ "is_bucketing": False,
97
+ "train": {
98
+ "input": [],
99
+ "output": []
100
+ },
101
+ "test": {
102
+ "input": [],
103
+ "output": []
104
+ }
105
+ }
106
+ for i, logo_info in dataframe_index.iterrows():
107
+ logo = logo_info['filename'] # e.g. logo_1
108
+ file = logo_info['file'] # e.g. logo_1_animation_2
109
+ oversample = logo_info['repeat']
110
+ print(f"Processing {logo} with {file}: ")
111
+
112
+ if input_sequences_dict_used.keys().__contains__(logo) and input_sequences_dict_unused.keys().__contains__(logo):
113
+ for j in range(oversample):
114
+ input_tensor = _generate_input_sequence(
115
+ input_sequences_dict_used[logo].copy(),
116
+ input_sequences_dict_unused[logo].copy(),
117
+ #pd.DataFrame(),
118
+ null_features=26, # TODO depends on architecture later
119
+ sequence_length=sequence_length_input,
120
+ # is_randomized=True, always now
121
+ is_padding=True
122
+ )
123
+
124
+ output_tensor = _generate_output_sequence(
125
+ output_sequences[(output_sequences['filename'] == logo) & (output_sequences['file'] == file)].copy(),
126
+ sequence_length=sequence_length_output,
127
+ is_randomized=False,
128
+ is_padding=True
129
+ )
130
+ # append to lists
131
+ if logo in logos_list["train"]:
132
+ random_index = random.randint(0, len(dataset["train"]["input"]))
133
+ dataset["train"]["input"].insert(random_index, input_tensor)
134
+ dataset["train"]["output"].insert(random_index, output_tensor)
135
+
136
+ elif logo in logos_list["test"]:
137
+ dataset["test"]["input"].append(input_tensor)
138
+ dataset["test"]["output"].append(output_tensor)
139
+ break # no oversampling in testing
140
+
141
+ else:
142
+ print(f"Some problem with {logo}. Neither in train or test set list.")
143
+ break
144
+
145
+ dataset["train"]["input"] = torch.stack(dataset["train"]["input"])
146
+ dataset["train"]["output"] = torch.stack(dataset["train"]["output"])
147
+ dataset["test"]["input"] = torch.stack(dataset["test"]["input"])
148
+ dataset["test"]["output"] = torch.stack(dataset["test"]["output"])
149
+
150
+ return dataset
151
+
152
+
153
+ def _generate_input_sequence(logo_embeddings_used: pd.DataFrame,
154
+ logo_embeddings_unused: pd.DataFrame,
155
+ null_features: int,
156
+ sequence_length: int,
157
+ is_padding: bool) -> torch.Tensor:
158
+ """
159
+ Build a torch tensor for the transformer input sequences.
160
+ Includes
161
+ - Ensuring all used embeddings are included
162
+ - Filling the remainder with unused embeddings up to sequence length
163
+ - Generation of padding
164
+
165
+ Args:
166
+ logo_embeddings (pd.DataFrame): DataFrame containing logo embeddings.
167
+ null_features (int): Number of null features to add to each embedding.
168
+ sequence_length (int): Target length for padding sequences.
169
+ is_padding: if true, function adds padding
170
+
171
+ Returns:
172
+ torch.Tensor: Tensor representing the input sequences.
173
+ """
174
+ logo_embeddings_used.drop(columns=['filename', 'animation_id'], inplace=True)
175
+ logo_embeddings_unused.drop(columns=['filename', 'animation_id'], inplace=True)
176
+
177
+ # Combine used and unused. Fill used with random unused samples
178
+ logo_embeddings = logo_embeddings_unused
179
+ remaining_slots = sequence_length - len(logo_embeddings)
180
+ if remaining_slots > 0:
181
+ sample_size = min(len(logo_embeddings_unused), remaining_slots)
182
+ additional_embeddings = logo_embeddings_unused.sample(n=sample_size, replace=False)
183
+ logo_embeddings = pd.concat([logo_embeddings, additional_embeddings], ignore_index=True)
184
+ logo_embeddings.reset_index()
185
+
186
+ # Randomization
187
+ logo_embeddings = logo_embeddings.sample(frac=1).reset_index(drop=True)
188
+
189
+ # Null Features
190
+ if null_features > 0:
191
+ logo_embeddings = pd.concat([logo_embeddings,
192
+ pd.DataFrame(0,
193
+ index=logo_embeddings.index,
194
+ columns=range(logo_embeddings.shape[1],
195
+ logo_embeddings.shape[1] + null_features))],
196
+ axis=1,
197
+ ignore_index=True)
198
+
199
+ if is_padding:
200
+ logo_embeddings = _add_padding(logo_embeddings, sequence_length)
201
+
202
+ return torch.tensor(logo_embeddings.values)
203
+
204
+
205
+ def _generate_output_sequence(animation: pd.DataFrame,
206
+ sequence_length: int,
207
+ is_randomized: bool,
208
+ is_padding: bool) -> torch.Tensor:
209
+ """
210
+ Build a torch tensor for the transformer output sequences.
211
+ Includes
212
+ - Randomization (later, when same start time)
213
+ - Generation of padding
214
+ - Add EOS Token
215
+
216
+ Args:
217
+ animation (pd.DataFrame): DataFrame containing logo embeddings.
218
+ sequence_length (int): Target length for padding sequences.
219
+ is_randomized: shuffle order of paths, applies when same start time
220
+ is_padding: if true, function adds padding
221
+
222
+ Returns:
223
+ torch.Tensor: Tensor representing the input sequences.
224
+ """
225
+ if is_randomized:
226
+ animation = animation.sample(frac=1).reset_index(drop=True)
227
+ print("Note: Randomization not implemented yet")
228
+
229
+ animation.sort_values(by=['a10'], inplace=True) # again ordered by time start.
230
+ animation.drop(columns=['file', 'filename', "Unnamed: 0", "id"], inplace=True)
231
+
232
+ # Append the EOS row to the DataFrame
233
+ sos_eos_row = {col: 0 for col in animation.columns}
234
+ sos_eos_row["a0"] = 1
235
+ sos_eos_row = pd.DataFrame([sos_eos_row])
236
+ animation = pd.concat([sos_eos_row, animation, sos_eos_row],
237
+ ignore_index=True)
238
+
239
+ # Padding Generation: Add padding rows or cut off excess rows
240
+ if is_padding:
241
+ animation = _add_padding(animation, sequence_length)
242
+
243
+ return torch.Tensor(animation.values)
244
+
245
+
246
+ def _add_padding(dataframe: pd.DataFrame, sequence_length: int) -> pd.DataFrame:
247
+ """
248
+ Add padding to a dataframe
249
+
250
+ Args:
251
+ dataframe: dataframe to add padding to
252
+ sequence_length: length of final sequences
253
+
254
+ Returns:
255
+
256
+ """
257
+ if len(dataframe) < sequence_length:
258
+ padding_rows = pd.DataFrame([[PADDING_VALUE] * len(dataframe.columns)] * (sequence_length - len(dataframe)),
259
+ columns=dataframe.columns)
260
+ dataframe = pd.concat([dataframe, padding_rows], ignore_index=True)
261
+ elif len(dataframe) > sequence_length:
262
+ # Cut off excess rows
263
+ dataframe = dataframe.iloc[:sequence_length]
264
+
265
+ return dataframe
266
+
267
+
268
+ # BUCKETING
269
+ def generate_buckets_2D(dataset, column1, column2, quantiles1, quantiles2, print_histogram=True):
270
+ """
271
+
272
+ Args:
273
+ dataset: dataset to generate buckets for
274
+ column1: first column name
275
+ column2: second column name
276
+ quantiles1: initial quantiles for column1
277
+ quantiles2: initial quantiles for column2
278
+ print_histogram: if true, a histogram of the 2D buckets is printed
279
+
280
+ Returns: dictionary object with bucket edges
281
+
282
+ """
283
+ x_edges = dataset[column1].quantile(quantiles1)
284
+ y_edges = dataset[column2].quantile(quantiles2)
285
+
286
+ x_edges = np.array(x_edges)
287
+ y_edges = np.unique(y_edges)
288
+
289
+ if print_histogram:
290
+ hist, x_edges, y_edges = np.histogram2d(dataset[column1],
291
+ dataset[column2],
292
+ bins=[x_edges, y_edges])
293
+ print(hist)
294
+
295
+ return {
296
+ "input_edges": list(x_edges),
297
+ "output_edges": list(y_edges)
298
+ }
299
+
300
+
301
+ def get_bucket(input_length, output_length, buckets):
302
+ bucket_name = ""
303
+
304
+ for i, input_edge in enumerate(buckets["input_edges"]):
305
+ # print(f"{i}: {input_length} < {input_edge}")
306
+ if input_length > input_edge:
307
+ continue
308
+
309
+ bucket_name = bucket_name + str(int(i)) # chr(ord('A')+i)
310
+ break
311
+
312
+ bucket_name = bucket_name + "-"
313
+
314
+ for i, output_edge in enumerate(buckets["output_edges"]):
315
+ if output_length > output_edge:
316
+ continue
317
+
318
+ bucket_name = bucket_name + str(int(i))
319
+ break
320
+
321
+ return bucket_name
322
+
323
+
324
+ def warn_if_contains_NaN(dataset: torch.Tensor):
325
+ if torch.isnan(dataset).any():
326
+ print("There are NaN values in the dataset")