MatthiasC committed on
Commit
e65b549
·
1 Parent(s): 02ffa1d

Load complete ckpt from hub with secret

Browse files
Files changed (2) hide show
  1. app.py +1 -4
  2. server.py +40 -25
app.py CHANGED
@@ -9,11 +9,8 @@ from PIL import Image
9
 
10
  import requests
11
  import logging
12
- from huggingface_hub import hf_hub_download
13
 
14
- logging.info("Start downloading")
15
- hf_hub_download(repo_id="MatthiasC/dall-e-logo", filename="README.md")
16
- logging.info("End downloading")
17
  def start_server():
18
  os.system("uvicorn server:app --port 8080 --host 0.0.0.0 --workers 1")
19
 
 
9
 
10
  import requests
11
  import logging
 
12
 
13
+
 
 
14
  def start_server():
15
  os.system("uvicorn server:app --port 8080 --host 0.0.0.0 --workers 1")
16
 
server.py CHANGED
@@ -13,11 +13,20 @@ from PIL import Image
13
  #import clip
14
  from dalle.models import Dalle
15
  import logging
 
16
  from dalle.utils.utils import clip_score, download
17
 
18
  print("Loading models...")
19
  app = FastAPI()
20
 
 
 
 
 
 
 
 
 
21
 
22
  # url = "https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz"
23
  # root = os.path.expanduser("~/.cache/minDALLE")
@@ -31,33 +40,39 @@ app = FastAPI()
31
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
  model = Dalle.from_pretrained("minDALL-E/1.3B") # This will automatically download the pretrained model.
34
- model.to(device=device)
35
 
 
 
36
  # -----------------------------------------------------------
37
- state_dict_ = torch.load('last.ckpt', map_location='cpu')
38
- vqgan_stage_dict = model.stage1.state_dict()
39
-
40
- for name, param in state_dict_['state_dict'].items():
41
- if name not in model.stage1.state_dict().keys():
42
- continue
43
- if isinstance(param, nn.parameter.Parameter):
44
- param = param.data
45
- vqgan_stage_dict[name].copy_(param)
46
-
47
- model.stage1.load_state_dict(vqgan_stage_dict)
48
- #---------------------------------------------------------
49
- state_dict_dalle = torch.load('dalle_last.ckpt', map_location='cpu')
50
- dalle_stage_dict = model.stage2.state_dict()
51
-
52
- for name, param in state_dict_dalle['state_dict'].items():
53
- if name[6:] not in model.stage2.state_dict().keys():
54
- print(name)
55
- continue
56
- if isinstance(param, nn.parameter.Parameter):
57
- param = param.data
58
- dalle_stage_dict[name[6:]].copy_(param)
59
-
60
- model.stage2.load_state_dict(dalle_stage_dict)
 
 
 
 
61
 
62
  # model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
63
  # model_clip.to(device=device)
 
13
  #import clip
14
  from dalle.models import Dalle
15
  import logging
16
+ import streamlit as st
17
  from dalle.utils.utils import clip_score, download
18
 
19
  print("Loading models...")
20
  app = FastAPI()
21
 
22
+ from huggingface_hub import hf_hub_download
23
+
24
+ logging.info("Start downloading")
25
+ full_dict_path = hf_hub_download(repo_id="MatthiasC/dall-e-logo", filename="full_dict_new.ckpt",
26
+ use_auth_token=st.secrets["model_hub"])
27
+ logging.info("End downloading")
28
+ logging.info(full_dict_path)
29
+
30
 
31
  # url = "https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz"
32
  # root = os.path.expanduser("~/.cache/minDALLE")
 
40
 
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
  model = Dalle.from_pretrained("minDALL-E/1.3B") # This will automatically download the pretrained model.
43
+ #model.to(device=device)
44
 
45
+
46
+ # OLD CODE
47
  # -----------------------------------------------------------
48
+ # state_dict_ = torch.load('last.ckpt', map_location='cpu')
49
+ # vqgan_stage_dict = model.stage1.state_dict()
50
+ #
51
+ # for name, param in state_dict_['state_dict'].items():
52
+ # if name not in model.stage1.state_dict().keys():
53
+ # continue
54
+ # if isinstance(param, nn.parameter.Parameter):
55
+ # param = param.data
56
+ # vqgan_stage_dict[name].copy_(param)
57
+ #
58
+ # model.stage1.load_state_dict(vqgan_stage_dict)
59
+ # #---------------------------------------------------------
60
+ # state_dict_dalle = torch.load('dalle_last.ckpt', map_location='cpu')
61
+ # dalle_stage_dict = model.stage2.state_dict()
62
+ #
63
+ # for name, param in state_dict_dalle['state_dict'].items():
64
+ # if name[6:] not in model.stage2.state_dict().keys():
65
+ # print(name)
66
+ # continue
67
+ # if isinstance(param, nn.parameter.Parameter):
68
+ # param = param.data
69
+ # dalle_stage_dict[name[6:]].copy_(param)
70
+ #
71
+ # model.stage2.load_state_dict(dalle_stage_dict)
72
+
73
+ # NEW METHOD
74
+ model.load_state_dict(torch.load(full_dict_path))
75
+ model.to(device=device)
76
 
77
  # model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
78
  # model_clip.to(device=device)