BioMike committed on
Commit
e888cf9
1 Parent(s): 538c136

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +155 -148
model.py CHANGED
@@ -1,148 +1,155 @@
1
- import json
2
- import torch
3
- import torch.nn as nn
4
- import os
5
- from pathlib import Path
6
- from typing import Optional, Union, Dict
7
- from huggingface_hub import snapshot_download
8
- import warnings
9
-
10
class ConvVAE(nn.Module):
    """Convolutional VAE: four stride-2 conv stages down, four transposed-conv stages up.

    The linear heads are sized for a 512 x 8 x 8 encoder output, which the
    stride-2 stack produces from 3 x 128 x 128 inputs. The decoder ends in
    ``Tanh``, so reconstructions lie in [-1, 1].
    """

    def __init__(self, latent_size):
        """Create the encoder, the latent heads and the decoder.

        Args:
            latent_size: Number of dimensions of the latent vector.
        """
        super().__init__()

        # Encoder: each stage halves the spatial size (128 -> 64 -> 32 -> 16 -> 8).
        down_layers = []
        for c_in, c_out in ((3, 64), (64, 128), (128, 256), (256, 512)):
            down_layers += [
                nn.Conv2d(c_in, c_out, 3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        self.encoder = nn.Sequential(*down_layers)

        # Flattened size of the (512, 8, 8) encoder output.
        flat_features = 512 * 8 * 8
        self.fc_mu = nn.Linear(flat_features, latent_size)
        self.fc_logvar = nn.Linear(flat_features, latent_size)

        # Projects a latent vector back to the decoder's input feature map.
        self.fc2 = nn.Linear(latent_size, flat_features)

        # Decoder: each stage doubles the spatial size (8 -> 16 -> 32 -> 64 -> 128).
        up_layers = []
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            up_layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        up_layers += [
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(decoded, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar

    def encode(self, x):
        """Return ``(mu, logvar)`` computed from the flattened encoder features."""
        features = self.encoder(x)
        flat = features.view(features.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``std = exp(0.5 * logvar)`` and ``eps ~ N(0, 1)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors ``z`` to images through ``fc2`` and the decoder."""
        feature_map = self.fc2(z).view(-1, 512, 8, 8)
        return self.decoder(feature_map)

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        force_download: bool = False,
        proxies: Optional[Dict] = None,
        resume_download: bool = False,
        local_files_only: bool = False,
        token: Union[str, bool, None] = None,
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """
        Build a model from a local directory or a Hub repository snapshot.

        Args:
            model_id (str): Existing local path, or repo ID passed to ``snapshot_download``.
            revision (Optional[str]): Forwarded to ``snapshot_download``.
            cache_dir (Optional[Union[str, Path]]): Forwarded to ``snapshot_download``.
            force_download (bool): Forwarded to ``snapshot_download``.
            proxies (Optional[Dict]): Forwarded to ``snapshot_download``.
            resume_download (bool): Forwarded to ``snapshot_download``.
            local_files_only (bool): Forwarded to ``snapshot_download``.
            token (Union[str, bool, None]): Forwarded to ``snapshot_download``.
            map_location (str): Device used for ``torch.load`` and for the final ``model.to``.
            strict (bool): Passed to ``load_state_dict``.
            **model_kwargs: Extra keyword arguments for the model constructor.

        Returns:
            The model with the checkpoint weights loaded.

        Raises:
            ValueError: If ``config.json`` has no ``latent_size`` entry.
            FileNotFoundError: If the checkpoint file is absent.
        """
        # A path that exists locally is used as-is; anything else is fetched as a repo ID.
        local_dir = Path(model_id)
        if local_dir.exists():
            model_dir = local_dir
        else:
            download_options = dict(
                repo_id=model_id,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
            model_dir = Path(snapshot_download(**download_options))

        with (model_dir / "config.json").open('r') as handle:
            config = json.load(handle)

        latent_size = config.get('latent_size')
        if latent_size is None:
            raise ValueError("The configuration file is missing the 'latent_size' key.")

        model = cls(latent_size, **model_kwargs)

        model_file = model_dir / "model_conv_vae_256_epoch_304.pth"
        if not model_file.exists():
            raise FileNotFoundError(f"The model checkpoint '{model_file}' does not exist.")

        raw_weights = torch.load(model_file, map_location=map_location)

        # Drop a leading '_orig_mod.' from parameter names; other keys pass through unchanged.
        prefix = '_orig_mod.'
        cleaned_weights = {
            (name[len(prefix):] if name.startswith(prefix) else name): tensor
            for name, tensor in raw_weights.items()
        }

        model.load_state_dict(cleaned_weights, strict=strict)
        model.to(map_location)
        return model
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import torch.nn as nn
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Optional, Union, Dict
7
+ from huggingface_hub import snapshot_download
8
+ import warnings
9
+
10
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 128x128 images.

    The encoder applies four stride-2 convolutions (128 -> 64 -> 32 -> 16 -> 8)
    and two linear heads produce the latent mean and log-variance. The decoder
    mirrors it with four stride-2 transposed convolutions and a final ``Tanh``,
    so reconstructions lie in [-1, 1].
    """

    # Shape (channels, height, width) of the encoder output / decoder input.
    _FEATURE_SHAPE = (512, 8, 8)
    # Number of elements in one flattened feature map.
    _FEATURE_SIZE = 512 * 8 * 8
    # Checkpoint file loaded by ``from_pretrained`` unless the caller overrides it.
    DEFAULT_CHECKPOINT_NAME = "model_conv_vae_256_epoch_304.pth"
    # Key prefix stripped from checkpoints; presumably left behind when the
    # state dict was saved from a ``torch.compile``-wrapped module (TODO confirm).
    _COMPILED_PREFIX = '_orig_mod.'

    def __init__(self, latent_size: int):
        """Build the encoder, latent heads and decoder.

        Args:
            latent_size: Dimensionality of the latent vector.
        """
        super().__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1),  # (batch, 64, 64, 64)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # (batch, 128, 32, 32)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),  # (batch, 256, 16, 16)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, 3, stride=2, padding=1),  # (batch, 512, 8, 8)
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )

        # Latent heads: mean and log-variance of the approximate posterior.
        self.fc_mu = nn.Linear(self._FEATURE_SIZE, latent_size)
        self.fc_logvar = nn.Linear(self._FEATURE_SIZE, latent_size)

        # Projects a latent vector back to the decoder's input feature map.
        self.fc2 = nn.Linear(latent_size, self._FEATURE_SIZE)

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),  # (batch, 256, 16, 16)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # (batch, 128, 32, 32)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # (batch, 64, 64, 64)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),  # (batch, 3, 128, 128)
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(decoded, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z)
        return decoded, mu, logvar

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of images.

        The linear heads expect ``512 * 8 * 8`` features per sample, which the
        encoder produces for 128x128 inputs.
        """
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, 1)``.

        Noise is drawn on every call, regardless of train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latent vectors ``z`` into images in [-1, 1]."""
        x = self.fc2(z)
        x = x.view(-1, *self._FEATURE_SHAPE)
        return self.decoder(x)

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        force_download: bool = False,
        proxies: Optional[Dict] = None,
        resume_download: bool = False,
        local_files_only: bool = False,
        token: Union[str, bool, None] = None,
        map_location: str = "cpu",
        strict: bool = False,
        checkpoint_name: Optional[str] = None,
        **model_kwargs,
    ) -> "ConvVAE":
        """
        Load a pretrained model from a local directory or a Hub repository.

        If ``model_id`` is an existing local path it is used directly;
        otherwise it is treated as a repo ID and fetched with
        ``snapshot_download``. The directory must contain ``config.json``
        (with a ``latent_size`` entry) and the checkpoint file.

        Args:
            model_id (str): Local directory or Hub repo ID of the model.
            revision (Optional[str]): Specific model revision to use.
            cache_dir (Optional[Union[str, Path]]): Directory to store downloaded models.
            force_download (bool): Force re-download even if the model exists.
            proxies (Optional[Dict]): Proxy configuration for downloads.
            resume_download (bool): Resume interrupted downloads.
            local_files_only (bool): Use only local files, don't download.
            token (Union[str, bool, None]): Token for API authentication.
            map_location (str): Device used for ``torch.load`` and the final
                ``model.to``. Defaults to "cpu".
            strict (bool): Enforce strict state_dict loading.
            checkpoint_name (Optional[str]): File name of the checkpoint inside
                the model directory. Defaults to ``DEFAULT_CHECKPOINT_NAME``.
            **model_kwargs: Additional keyword arguments for model initialization.

        Returns:
            An instance of the model loaded from the pretrained weights.

        Raises:
            ValueError: If ``config.json`` has no ``latent_size`` entry.
            FileNotFoundError: If ``config.json`` or the checkpoint is missing.
        """
        model_dir = Path(model_id)
        if not model_dir.exists():
            model_dir = Path(
                snapshot_download(
                    repo_id=model_id,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            )

        # Decode as UTF-8 explicitly so parsing does not depend on the locale.
        config_file = model_dir / "config.json"
        with open(config_file, 'r', encoding="utf-8") as f:
            config = json.load(f)

        latent_size = config.get('latent_size')
        if latent_size is None:
            raise ValueError("The configuration file is missing the 'latent_size' key.")

        model = cls(latent_size, **model_kwargs)

        if checkpoint_name is None:
            checkpoint_name = cls.DEFAULT_CHECKPOINT_NAME
        model_file = model_dir / checkpoint_name
        if not model_file.exists():
            raise FileNotFoundError(f"The model checkpoint '{model_file}' does not exist.")

        # SECURITY NOTE: torch.load unpickles the file. Only load checkpoints
        # from trusted sources, or pass weights_only=True where the installed
        # torch version supports it.
        state_dict = torch.load(model_file, map_location=map_location)

        # Strip the wrapper prefix so keys match this module's parameter names.
        prefix = cls._COMPILED_PREFIX
        new_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

        model.load_state_dict(new_state_dict, strict=strict)
        model.to(map_location)

        return model
149
+
150
+
151
# Module-level instance built at import time: loads the
# "BioMike/classical_portrait_vae" weights (from a local path of that name if
# one exists, otherwise via the Hub with ./model_cache as cache directory),
# requires an exact state-dict match, and puts the model in evaluation mode.
model = ConvVAE.from_pretrained(
    "BioMike/classical_portrait_vae",
    cache_dir="./model_cache",
    map_location="cpu",
    strict=True,
)
model = model.eval()