skytnt commited on
Commit
ec92952
ยท
1 Parent(s): 1b4fbf7
saved_model/7/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7df7925410fe5775f1d6085a548f816304c43ed2ce84835a4cf9f815b524bad5
3
+ size 1749
saved_model/7/cover.jpg ADDED

Git LFS Details

  • SHA256: cd98e72f9a5de9df03d2cffae41f907dd70116b4ae89d9fe218df6fa45cd1767
  • Pointer size: 130 Bytes
  • Size of remote file: 98.8 kB
saved_model/7/model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f96e046a777407883d4665777118bdfbe0a48fc18c5fdea16c1d05eaa3af7773
3
+ size 476818993
saved_model/info.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1ae450ecf80251796929594abecca61537612c4115cf947d363c805055f0b199
3
- size 905
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43443cfa806bfad9cbb429c96ba440913d01da1bdff63daa564c824037e8070b
3
+ size 1015
utils.py CHANGED
@@ -16,211 +16,211 @@ logger = logging
16
 
17
 
18
  def load_checkpoint(checkpoint_path, model, optimizer=None):
19
- assert os.path.isfile(checkpoint_path)
20
- checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
21
- iteration = checkpoint_dict['iteration']
22
- learning_rate = checkpoint_dict['learning_rate']
23
- if optimizer is not None:
24
- optimizer.load_state_dict(checkpoint_dict['optimizer'])
25
- saved_state_dict = checkpoint_dict['model']
26
- if hasattr(model, 'module'):
27
- state_dict = model.module.state_dict()
28
- else:
29
- state_dict = model.state_dict()
30
- new_state_dict= {}
31
- for k, v in state_dict.items():
32
- try:
33
- new_state_dict[k] = saved_state_dict[k]
34
- except:
35
- logger.info("%s is not in the checkpoint" % k)
36
- new_state_dict[k] = v
37
- if hasattr(model, 'module'):
38
- model.module.load_state_dict(new_state_dict)
39
- else:
40
- model.load_state_dict(new_state_dict)
41
- logger.info("Loaded checkpoint '{}' (iteration {})" .format(
42
- checkpoint_path, iteration))
43
- return model, optimizer, learning_rate, iteration
44
 
45
 
46
  def plot_spectrogram_to_numpy(spectrogram):
47
- global MATPLOTLIB_FLAG
48
- if not MATPLOTLIB_FLAG:
49
- import matplotlib
50
- matplotlib.use("Agg")
51
- MATPLOTLIB_FLAG = True
52
- mpl_logger = logging.getLogger('matplotlib')
53
- mpl_logger.setLevel(logging.WARNING)
54
- import matplotlib.pylab as plt
55
- import numpy as np
56
-
57
- fig, ax = plt.subplots(figsize=(10,2))
58
- im = ax.imshow(spectrogram, aspect="auto", origin="lower",
59
- interpolation='none')
60
- plt.colorbar(im, ax=ax)
61
- plt.xlabel("Frames")
62
- plt.ylabel("Channels")
63
- plt.tight_layout()
64
-
65
- fig.canvas.draw()
66
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
67
- data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
68
- plt.close()
69
- return data
70
 
71
 
72
  def plot_alignment_to_numpy(alignment, info=None):
73
- global MATPLOTLIB_FLAG
74
- if not MATPLOTLIB_FLAG:
75
- import matplotlib
76
- matplotlib.use("Agg")
77
- MATPLOTLIB_FLAG = True
78
- mpl_logger = logging.getLogger('matplotlib')
79
- mpl_logger.setLevel(logging.WARNING)
80
- import matplotlib.pylab as plt
81
- import numpy as np
82
-
83
- fig, ax = plt.subplots(figsize=(6, 4))
84
- im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
85
- interpolation='none')
86
- fig.colorbar(im, ax=ax)
87
- xlabel = 'Decoder timestep'
88
- if info is not None:
89
- xlabel += '\n\n' + info
90
- plt.xlabel(xlabel)
91
- plt.ylabel('Encoder timestep')
92
- plt.tight_layout()
93
-
94
- fig.canvas.draw()
95
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
96
- data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
97
- plt.close()
98
- return data
99
 
100
 
101
  def load_wav_to_torch(full_path):
102
- sampling_rate, data = read(full_path)
103
- return torch.FloatTensor(data.astype(np.float32)), sampling_rate
104
 
105
 
106
  def load_filepaths_and_text(filename, split="|"):
107
- with open(filename, encoding='utf-8') as f:
108
- filepaths_and_text = [line.strip().split(split) for line in f]
109
- return filepaths_and_text
110
 
111
 
112
  def get_hparams(init=True):
113
- parser = argparse.ArgumentParser()
114
- parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
115
- help='JSON file for configuration')
116
- parser.add_argument('-m', '--model', type=str, required=True,
117
- help='Model name')
118
-
119
- args = parser.parse_args()
120
- model_dir = os.path.join("./logs", args.model)
121
-
122
- if not os.path.exists(model_dir):
123
- os.makedirs(model_dir)
124
-
125
- config_path = args.config
126
- config_save_path = os.path.join(model_dir, "config.json")
127
- if init:
128
- with open(config_path, "r") as f:
129
- data = f.read()
130
- with open(config_save_path, "w") as f:
131
- f.write(data)
132
- else:
133
- with open(config_save_path, "r") as f:
134
- data = f.read()
135
- config = json.loads(data)
136
-
137
- hparams = HParams(**config)
138
- hparams.model_dir = model_dir
139
- return hparams
140
 
141
 
142
  def get_hparams_from_dir(model_dir):
143
- config_save_path = os.path.join(model_dir, "config.json")
144
- with open(config_save_path, "r") as f:
145
- data = f.read()
146
- config = json.loads(data)
147
 
148
- hparams =HParams(**config)
149
- hparams.model_dir = model_dir
150
- return hparams
151
 
152
 
153
  def get_hparams_from_file(config_path):
154
- with open(config_path, "r") as f:
155
- data = f.read()
156
- config = json.loads(data)
157
 
158
- hparams =HParams(**config)
159
- return hparams
160
 
161
 
162
  def check_git_hash(model_dir):
163
- source_dir = os.path.dirname(os.path.realpath(__file__))
164
- if not os.path.exists(os.path.join(source_dir, ".git")):
165
- logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
166
- source_dir
167
- ))
168
- return
169
 
170
- cur_hash = subprocess.getoutput("git rev-parse HEAD")
171
 
172
- path = os.path.join(model_dir, "githash")
173
- if os.path.exists(path):
174
- saved_hash = open(path).read()
175
- if saved_hash != cur_hash:
176
- logger.warn("git hash values are different. {}(saved) != {}(current)".format(
177
- saved_hash[:8], cur_hash[:8]))
178
- else:
179
- open(path, "w").write(cur_hash)
180
 
181
 
182
  def get_logger(model_dir, filename="train.log"):
183
- global logger
184
- logger = logging.getLogger(os.path.basename(model_dir))
185
- logger.setLevel(logging.DEBUG)
186
-
187
- formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
188
- if not os.path.exists(model_dir):
189
- os.makedirs(model_dir)
190
- h = logging.FileHandler(os.path.join(model_dir, filename))
191
- h.setLevel(logging.DEBUG)
192
- h.setFormatter(formatter)
193
- logger.addHandler(h)
194
- return logger
195
 
196
 
197
  class HParams():
198
- def __init__(self, **kwargs):
199
- for k, v in kwargs.items():
200
- if type(v) == dict:
201
- v = HParams(**v)
202
- self[k] = v
203
-
204
- def keys(self):
205
- return self.__dict__.keys()
206
 
207
- def items(self):
208
- return self.__dict__.items()
209
 
210
- def values(self):
211
- return self.__dict__.values()
212
 
213
- def __len__(self):
214
- return len(self.__dict__)
215
 
216
- def __getitem__(self, key):
217
- return getattr(self, key)
218
 
219
- def __setitem__(self, key, value):
220
- return setattr(self, key, value)
221
 
222
- def __contains__(self, key):
223
- return key in self.__dict__
224
 
225
- def __repr__(self):
226
- return self.__dict__.__repr__()
 
16
 
17
 
18
  def load_checkpoint(checkpoint_path, model, optimizer=None):
19
+ assert os.path.isfile(checkpoint_path)
20
+ checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
21
+ iteration = checkpoint_dict['iteration']
22
+ learning_rate = checkpoint_dict['learning_rate']
23
+ if optimizer is not None:
24
+ optimizer.load_state_dict(checkpoint_dict['optimizer'])
25
+ saved_state_dict = checkpoint_dict['model']
26
+ if hasattr(model, 'module'):
27
+ state_dict = model.module.state_dict()
28
+ else:
29
+ state_dict = model.state_dict()
30
+ new_state_dict = {}
31
+ for k, v in state_dict.items():
32
+ try:
33
+ new_state_dict[k] = saved_state_dict[k]
34
+ except:
35
+ logger.info("%s is not in the checkpoint" % k)
36
+ new_state_dict[k] = v
37
+ if hasattr(model, 'module'):
38
+ model.module.load_state_dict(new_state_dict)
39
+ else:
40
+ model.load_state_dict(new_state_dict)
41
+ logger.info("Loaded checkpoint '{}' (iteration {})".format(
42
+ checkpoint_path, iteration))
43
+ return model, optimizer, learning_rate, iteration
44
 
45
 
46
  def plot_spectrogram_to_numpy(spectrogram):
47
+ global MATPLOTLIB_FLAG
48
+ if not MATPLOTLIB_FLAG:
49
+ import matplotlib
50
+ matplotlib.use("Agg")
51
+ MATPLOTLIB_FLAG = True
52
+ mpl_logger = logging.getLogger('matplotlib')
53
+ mpl_logger.setLevel(logging.WARNING)
54
+ import matplotlib.pylab as plt
55
+ import numpy as np
56
+
57
+ fig, ax = plt.subplots(figsize=(10, 2))
58
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
59
+ interpolation='none')
60
+ plt.colorbar(im, ax=ax)
61
+ plt.xlabel("Frames")
62
+ plt.ylabel("Channels")
63
+ plt.tight_layout()
64
+
65
+ fig.canvas.draw()
66
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
67
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
68
+ plt.close()
69
+ return data
70
 
71
 
72
  def plot_alignment_to_numpy(alignment, info=None):
73
+ global MATPLOTLIB_FLAG
74
+ if not MATPLOTLIB_FLAG:
75
+ import matplotlib
76
+ matplotlib.use("Agg")
77
+ MATPLOTLIB_FLAG = True
78
+ mpl_logger = logging.getLogger('matplotlib')
79
+ mpl_logger.setLevel(logging.WARNING)
80
+ import matplotlib.pylab as plt
81
+ import numpy as np
82
+
83
+ fig, ax = plt.subplots(figsize=(6, 4))
84
+ im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
85
+ interpolation='none')
86
+ fig.colorbar(im, ax=ax)
87
+ xlabel = 'Decoder timestep'
88
+ if info is not None:
89
+ xlabel += '\n\n' + info
90
+ plt.xlabel(xlabel)
91
+ plt.ylabel('Encoder timestep')
92
+ plt.tight_layout()
93
+
94
+ fig.canvas.draw()
95
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
96
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
97
+ plt.close()
98
+ return data
99
 
100
 
101
  def load_wav_to_torch(full_path):
102
+ sampling_rate, data = read(full_path)
103
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
104
 
105
 
106
  def load_filepaths_and_text(filename, split="|"):
107
+ with open(filename, encoding='utf-8') as f:
108
+ filepaths_and_text = [line.strip().split(split) for line in f]
109
+ return filepaths_and_text
110
 
111
 
112
  def get_hparams(init=True):
113
+ parser = argparse.ArgumentParser()
114
+ parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
115
+ help='JSON file for configuration')
116
+ parser.add_argument('-m', '--model', type=str, required=True,
117
+ help='Model name')
118
+
119
+ args = parser.parse_args()
120
+ model_dir = os.path.join("./logs", args.model)
121
+
122
+ if not os.path.exists(model_dir):
123
+ os.makedirs(model_dir)
124
+
125
+ config_path = args.config
126
+ config_save_path = os.path.join(model_dir, "config.json")
127
+ if init:
128
+ with open(config_path, "r") as f:
129
+ data = f.read()
130
+ with open(config_save_path, "w") as f:
131
+ f.write(data)
132
+ else:
133
+ with open(config_save_path, "r") as f:
134
+ data = f.read()
135
+ config = json.loads(data)
136
+
137
+ hparams = HParams(**config)
138
+ hparams.model_dir = model_dir
139
+ return hparams
140
 
141
 
142
  def get_hparams_from_dir(model_dir):
143
+ config_save_path = os.path.join(model_dir, "config.json")
144
+ with open(config_save_path, "r") as f:
145
+ data = f.read()
146
+ config = json.loads(data)
147
 
148
+ hparams = HParams(**config)
149
+ hparams.model_dir = model_dir
150
+ return hparams
151
 
152
 
153
  def get_hparams_from_file(config_path):
154
+ with open(config_path, "r", encoding="utf-8") as f:
155
+ data = f.read()
156
+ config = json.loads(data)
157
 
158
+ hparams = HParams(**config)
159
+ return hparams
160
 
161
 
162
  def check_git_hash(model_dir):
163
+ source_dir = os.path.dirname(os.path.realpath(__file__))
164
+ if not os.path.exists(os.path.join(source_dir, ".git")):
165
+ logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
166
+ source_dir
167
+ ))
168
+ return
169
 
170
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
171
 
172
+ path = os.path.join(model_dir, "githash")
173
+ if os.path.exists(path):
174
+ saved_hash = open(path).read()
175
+ if saved_hash != cur_hash:
176
+ logger.warn("git hash values are different. {}(saved) != {}(current)".format(
177
+ saved_hash[:8], cur_hash[:8]))
178
+ else:
179
+ open(path, "w").write(cur_hash)
180
 
181
 
182
  def get_logger(model_dir, filename="train.log"):
183
+ global logger
184
+ logger = logging.getLogger(os.path.basename(model_dir))
185
+ logger.setLevel(logging.DEBUG)
186
+
187
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
188
+ if not os.path.exists(model_dir):
189
+ os.makedirs(model_dir)
190
+ h = logging.FileHandler(os.path.join(model_dir, filename))
191
+ h.setLevel(logging.DEBUG)
192
+ h.setFormatter(formatter)
193
+ logger.addHandler(h)
194
+ return logger
195
 
196
 
197
  class HParams():
198
+ def __init__(self, **kwargs):
199
+ for k, v in kwargs.items():
200
+ if type(v) == dict:
201
+ v = HParams(**v)
202
+ self[k] = v
203
+
204
+ def keys(self):
205
+ return self.__dict__.keys()
206
 
207
+ def items(self):
208
+ return self.__dict__.items()
209
 
210
+ def values(self):
211
+ return self.__dict__.values()
212
 
213
+ def __len__(self):
214
+ return len(self.__dict__)
215
 
216
+ def __getitem__(self, key):
217
+ return getattr(self, key)
218
 
219
+ def __setitem__(self, key, value):
220
+ return setattr(self, key, value)
221
 
222
+ def __contains__(self, key):
223
+ return key in self.__dict__
224
 
225
+ def __repr__(self):
226
+ return self.__dict__.__repr__()