updates for GPU compatibility

Files changed:
- .DS_Store (+0, -0)
- ContraCLIP/lib/aux.py (+17, -6)
- ContraCLIP/models/load_generator.py (+4, -1)

.DS_Store
ADDED
Binary file (6.15 kB)
ContraCLIP/lib/aux.py
CHANGED
@@ -68,16 +68,27 @@ def create_exp_dir(args):
 class PromptFeatures:
     def __init__(self, prompt_corpus, clip_model):
         self.prompt_corpus = prompt_corpus
-        self.clip_model = clip_model.cpu()
+        # self.clip_model = clip_model.cpu()
+        self.clip_model = clip_model
         self.num_prompts = len(self.prompt_corpus)
         self.prompt_features_dim = 512
 
+    # def get_prompt_features(self):
+    #     prompt_features = [
+    #         self.clip_model.encode_text(clip.tokenize(self.prompt_corpus[t]).cpu()).unsqueeze(0) for t in
+    #         range(len(self.prompt_corpus))
+    #     ]
+    #     return torch.cat(prompt_features, dim=0)
     def get_prompt_features(self):
-        prompt_features = [
-            self.clip_model.encode_text(clip.tokenize(self.prompt_corpus[t]).cpu()).unsqueeze(0) for t in
-            range(len(self.prompt_corpus))
-        ]
-        return torch.cat(prompt_features, dim=0)
+        # Get the device of the CLIP model
+        device = next(self.clip_model.parameters()).device
+
+        # Move tokenized text to the same device as the model
+        prompt_features = [
+            self.clip_model.encode_text(clip.tokenize(self.prompt_corpus[t]).to(device)).unsqueeze(0)
+            for t in range(len(self.prompt_corpus))
+        ]
+        return torch.cat(prompt_features, dim=0)
 
 
 class TrainingStatTracker(object):
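
The effect of this change is that PromptFeatures no longer forces the CLIP model onto the CPU; get_prompt_features() instead reads the model's current device and moves the tokenized prompts there before encoding. A minimal usage sketch, assuming OpenAI's clip package and the patched class above (the model name, prompts, and import path are illustrative, not taken from this commit):

import clip
import torch

from lib.aux import PromptFeatures  # import path assumed from the repo layout

# Load CLIP directly onto the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_model, _ = clip.load('ViT-B/32', device=device)  # model name is an assumption

# With the patch, encode_text() runs on whatever device the model sits on,
# because the token tensors are moved to that device first.
pf = PromptFeatures(prompt_corpus=['a smiling face', 'a sad face'],
                    clip_model=clip_model)
features = pf.get_prompt_features()
print(features.shape, features.device)  # torch.Size([2, 1, 512]), cuda:0 on a GPU machine

Each prompt is tokenized and encoded individually, unsqueezed to shape (1, 1, 512), and concatenated along dim 0, so the result has one row per prompt.
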
ContraCLIP/models/load_generator.py
CHANGED
@@ -30,7 +30,10 @@ def load_generator(model_name, latent_is_w=False, verbose=False, CHECKPOINT_DIR=
 
     if not osp.exists(checkpoint_path):
         subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
-    checkpoint = torch.load(checkpoint_path, map_location='cpu')
+    # checkpoint = torch.load(checkpoint_path, map_location='cpu')
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+
     if 'generator_smooth' in checkpoint:
         generator.load_state_dict(checkpoint['generator_smooth'])
     else:
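
Instead of always mapping the downloaded checkpoint to the CPU, the loader now targets CUDA when it is available, so the generator weights land directly on the GPU. A minimal sketch of the same pattern in isolation (the checkpoint filename is a placeholder; the 'generator_smooth' key fallback follows the diff above):

import torch

# Map the checkpoint to the best available device at load time; map_location
# also lets a checkpoint saved on a GPU machine load on a CPU-only machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load('generator.pth', map_location=device)  # placeholder path

# Some checkpoints wrap the weights under a 'generator_smooth' key, as in
# load_generator() above; fall back to the raw dict otherwise.
state_dict = checkpoint.get('generator_smooth', checkpoint)
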