mehdidc commited on
Commit
383cba8
·
1 Parent(s): df95636

minor bug + better defaults in test()

Browse files
Files changed (2) hide show
  1. cli.py +6 -5
  2. viz.py +2 -2
cli.py CHANGED
@@ -223,7 +223,7 @@ def train(*, dataset='mnist', folder='mnist', resume=False, model='convae', walk
223
  nb_updates += 1
224
 
225
 
226
- def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_generate=100, tsne=False):
227
  if not os.path.exists(folder):
228
  os.makedirs(folder, exist_ok=True)
229
  dataset = load_dataset(dataset, split='train')
@@ -235,6 +235,7 @@ def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_gene
235
  model_path = os.path.join(folder, "model.th")
236
  ae = torch.load(model_path, map_location="cpu")
237
  ae = ae.to(device)
 
238
  def enc(X):
239
  batch_size = 64
240
  h_list = []
@@ -267,12 +268,12 @@ def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_gene
267
  np.savez('{}/generated.npz'.format(folder), X=g.numpy())
268
  g_subset = g[:, 0:100]
269
  gr = grid_of_images_default(g_subset.reshape((g_subset.shape[0]*g_subset.shape[1], h, w, 1)).numpy(), shape=(g_subset.shape[0], g_subset.shape[1]))
270
- imsave('{}/gen_full_iters.png'.format(folder), gr)
271
 
272
  g = g[-1] # last iter
273
  print(g.shape)
274
  gr = grid_of_images_default(g.numpy())
275
- imsave('{}/gen_full.png'.format(folder), gr)
276
 
277
  if tsne:
278
  from sklearn.manifold import TSNE
@@ -300,13 +301,13 @@ def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_gene
300
  print('fit tsne...')
301
  ah = sne.fit_transform(ah)
302
  print('grid embedding...')
303
-
304
  asmall = np.concatenate((a[0:450], a[nb:nb + 450]), axis=0)
305
  ahsmall = np.concatenate((ah[0:450], ah[nb:nb + 450]), axis=0)
306
  rows = grid_embedding(ahsmall)
307
  asmall = asmall[rows]
308
  gr = grid_of_images_default(asmall)
309
- imsave('{}/sne_grid.png'.format(folder), gr)
310
 
311
  fig = plt.figure(figsize=(10, 10))
312
  plot_dataset(ah, labels)
 
223
  nb_updates += 1
224
 
225
 
226
+ def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=25, nb_generate=100, nb_active=160, tsne=False):
227
  if not os.path.exists(folder):
228
  os.makedirs(folder, exist_ok=True)
229
  dataset = load_dataset(dataset, split='train')
 
235
  model_path = os.path.join(folder, "model.th")
236
  ae = torch.load(model_path, map_location="cpu")
237
  ae = ae.to(device)
238
+ ae.nb_active = nb_active # for fc_sparse.th only
239
  def enc(X):
240
  batch_size = 64
241
  h_list = []
 
268
  np.savez('{}/generated.npz'.format(folder), X=g.numpy())
269
  g_subset = g[:, 0:100]
270
  gr = grid_of_images_default(g_subset.reshape((g_subset.shape[0]*g_subset.shape[1], h, w, 1)).numpy(), shape=(g_subset.shape[0], g_subset.shape[1]))
271
+ imsave('{}/gen_full_iters.png'.format(folder), (gr*255).astype("uint8") )
272
 
273
  g = g[-1] # last iter
274
  print(g.shape)
275
  gr = grid_of_images_default(g.numpy())
276
+ imsave('{}/gen_full.png'.format(folder), (gr*255).astype("uint8") )
277
 
278
  if tsne:
279
  from sklearn.manifold import TSNE
 
301
  print('fit tsne...')
302
  ah = sne.fit_transform(ah)
303
  print('grid embedding...')
304
+ assert nb_generate >= 450
305
  asmall = np.concatenate((a[0:450], a[nb:nb + 450]), axis=0)
306
  ahsmall = np.concatenate((ah[0:450], ah[nb:nb + 450]), axis=0)
307
  rows = grid_embedding(ahsmall)
308
  asmall = asmall[rows]
309
  gr = grid_of_images_default(asmall)
310
+ imsave('{}/sne_grid.png'.format(folder), (gr*255).astype("uint8") )
311
 
312
  fig = plt.figure(figsize=(10, 10))
313
  plot_dataset(ah, labels)
viz.py CHANGED
@@ -116,8 +116,8 @@ def grid_of_images(M, border=0, bordercolor=[0.0, 0.0, 0.0], shape=None, normali
116
  height, width, color = M[0].shape
117
  assert color == 3, 'Nb of color channels are {}'.format(color)
118
  if shape is None:
119
- n0 = np.int(np.ceil(np.sqrt(numimages)))
120
- n1 = np.int(np.ceil(np.sqrt(numimages)))
121
  else:
122
  n0 = shape[0]
123
  n1 = shape[1]
 
116
  height, width, color = M[0].shape
117
  assert color == 3, 'Nb of color channels are {}'.format(color)
118
  if shape is None:
119
+ n0 = np.int32(np.ceil(np.sqrt(numimages)))
120
+ n1 = np.int32(np.ceil(np.sqrt(numimages)))
121
  else:
122
  n0 = shape[0]
123
  n1 = shape[1]