diff --git a/pipeline.py b/pipeline.py index 7c41e04..842c5b4 100644 --- a/pipeline.py +++ b/pipeline.py @@ -27,7 +27,7 @@ def load_learned_concepts(pipe, root_folder="selected_outputs/", num_scales=10): for exp_name in os.listdir(root_folder): # get everything up to the first numeric pure_names.append(exp_name) - encoder = torch.load(os.path.join(root_folder, exp_name, "text_encoder/pytorch_model.bin")) + encoder = torch.load(os.path.join(root_folder, exp_name, "text_encoder/pytorch_model.bin"), map_location=pipe.device) embeddings = encoder["text_model.embeddings.token_embedding.weight"] param_value = embeddings[-10:] @@ -36,23 +36,23 @@ def load_learned_concepts(pipe, root_folder="selected_outputs/", num_scales=10): string_name = f"<{exp_name}|{t}|>" tokens_to_add.append(string_name) string_to_param_dict[string_name] = torch.nn.Parameter(param_value[t].unsqueeze(0).repeat([num_scales, 1])) - + # Fully Resolution: use appropriate time embedding for the whole generation time. string_name = f"<{exp_name}[{t}]>" tokens_to_add.append(string_name) repeats = t + 1 rep_param = param_value[t].unsqueeze(0).repeat([repeats, 1]) left = param_value[rep_param.shape[0]:] - new_param = torch.cat([rep_param, left]) + new_param = torch.cat([rep_param, left]) string_to_param_dict[string_name] = torch.nn.Parameter(new_param) # Semi Resolution: use appropriate time embedding up to a certain time and then no conditioning. string_name = f"<{exp_name}({t})>" tokens_to_add.append(string_name) - null_embedding = torch.zeros((param_value.shape[1],), device=param_value.device, dtype=param_value.dtype) + null_embedding = torch.zeros((param_value.shape[1],), device=pipe.device, dtype=param_value.dtype) rep_param = null_embedding.unsqueeze(0).repeat([t + 1, 1]) left = param_value[rep_param.shape[0]:] - new_param = torch.cat([rep_param, left]) + new_param = torch.cat([rep_param, left]) string_to_param_dict[string_name] = torch.nn.Parameter(new_param) pipe.tokenizer.add_tokens(tokens_to_add)