Spaces:
Build error
Build error
adymaharana
commited on
Commit
·
77e955b
1
Parent(s):
da4893d
cuda device
Browse files- app.py +2 -1
- dalle/models/__init__.py +2 -2
app.py
CHANGED
@@ -66,7 +66,8 @@ def save_story_results(images, video_len=4, n_candidates=1, mask=None):
|
|
66 |
|
67 |
|
68 |
def main(args):
|
69 |
-
device = 'cuda:0'
|
|
|
70 |
|
71 |
model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'
|
72 |
|
|
|
66 |
|
67 |
|
68 |
def main(args):
|
69 |
+
#device = 'cuda:0'
|
70 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
71 |
|
72 |
model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'
|
73 |
|
dalle/models/__init__.py
CHANGED
@@ -1193,9 +1193,9 @@ class StoryDalle(Dalle):
|
|
1193 |
print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
|
1194 |
# model.from_ckpt(args.model_name_or_path)
|
1195 |
try:
|
1196 |
-
model.load_state_dict(torch.load(args.model_name_or_path)['state_dict'])
|
1197 |
except KeyError:
|
1198 |
-
model.load_state_dict(torch.load(args.model_name_or_path)['model_state_dict'])
|
1199 |
else:
|
1200 |
model = cls(config_update)
|
1201 |
print(model.cross_attention_idxs)
|
|
|
1193 |
print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
|
1194 |
# model.from_ckpt(args.model_name_or_path)
|
1195 |
try:
|
1196 |
+
model.load_state_dict(torch.load(args.model_name_or_path, map_location=torch.device('cpu'))['state_dict'])
|
1197 |
except KeyError:
|
1198 |
+
model.load_state_dict(torch.load(args.model_name_or_path, map_location=torch.device('cpu'))['model_state_dict'])
|
1199 |
else:
|
1200 |
model = cls(config_update)
|
1201 |
print(model.cross_attention_idxs)
|