dattarij commited on
Commit
b02a0ab
·
1 Parent(s): 8c212a5

updates for GPU compatibility

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
ContraCLIP/lib/aux.py CHANGED
@@ -68,16 +68,27 @@ def create_exp_dir(args):
68
  class PromptFeatures:
69
  def __init__(self, prompt_corpus, clip_model):
70
  self.prompt_corpus = prompt_corpus
71
- self.clip_model = clip_model.cpu()
 
72
  self.num_prompts = len(self.prompt_corpus)
73
  self.prompt_features_dim = 512
74
 
 
 
 
 
 
 
75
  def get_prompt_features(self):
76
- prompt_features = [
77
- self.clip_model.encode_text(clip.tokenize(self.prompt_corpus[t]).cpu()).unsqueeze(0) for t in
78
- range(len(self.prompt_corpus))
79
- ]
80
- return torch.cat(prompt_features, dim=0)
 
 
 
 
81
 
82
 
83
  class TrainingStatTracker(object):
 
68
  class PromptFeatures:
69
  def __init__(self, prompt_corpus, clip_model):
70
  self.prompt_corpus = prompt_corpus
71
+ # self.clip_model = clip_model.cpu()
72
+ self.clip_model = clip_model
73
  self.num_prompts = len(self.prompt_corpus)
74
  self.prompt_features_dim = 512
75
 
76
+ # def get_prompt_features(self):
77
+ # prompt_features = [
78
+ # self.clip_model.encode_text(clip.tokenize(self.prompt_corpus[t]).cpu()).unsqueeze(0) for t in
79
+ # range(len(self.prompt_corpus))
80
+ # ]
81
+ # return torch.cat(prompt_features, dim=0)
82
  def get_prompt_features(self):
83
+ # Get the device of the CLIP model
84
+ device = next(self.clip_model.parameters()).device
85
+
86
+ # Move tokenized text to the same device as the model
87
+ prompt_features = [
88
+ self.clip_model.encode_text(clip.tokenize(self.prompt_corpus[t]).to(device)).unsqueeze(0)
89
+ for t in range(len(self.prompt_corpus))
90
+ ]
91
+ return torch.cat(prompt_features, dim=0)
92
 
93
 
94
  class TrainingStatTracker(object):
ContraCLIP/models/load_generator.py CHANGED
@@ -30,7 +30,10 @@ def load_generator(model_name, latent_is_w=False, verbose=False, CHECKPOINT_DIR=
30
 
31
  if not osp.exists(checkpoint_path):
32
  subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
33
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
 
 
 
34
  if 'generator_smooth' in checkpoint:
35
  generator.load_state_dict(checkpoint['generator_smooth'])
36
  else:
 
30
 
31
  if not osp.exists(checkpoint_path):
32
  subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
33
+ # checkpoint = torch.load(checkpoint_path, map_location='cpu')
34
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35
+ checkpoint = torch.load(checkpoint_path, map_location=device)
36
+
37
  if 'generator_smooth' in checkpoint:
38
  generator.load_state_dict(checkpoint['generator_smooth'])
39
  else: