mehdidc commited on
Commit
d58b310
·
1 Parent(s): 5a52d7b
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -4,13 +4,10 @@ import gradio as gr
4
  from PIL import Image
5
  from cli import iterative_refinement
6
  from viz import grid_of_images_default
7
- # from subprocess
8
- # subprocess.call("download_models.sh", shell=True)
9
  models = {
10
  "convae": torch.load("convae.th", map_location="cpu"),
11
  "deep_convae": torch.load("deep_convae.th", map_location="cpu"),
12
  }
13
-
14
  def gen(model, seed, nb_iter, nb_samples, width, height):
15
  torch.manual_seed(int(seed))
16
  bs = 64
@@ -26,9 +23,17 @@ def gen(model, seed, nb_iter, nb_samples, width, height):
26
  grid = (grid*255).astype("uint8")
27
  return Image.fromarray(grid)
28
 
 
 
 
 
 
29
  iface = gr.Interface(
30
  fn=gen,
31
- inputs=[gr.Dropdown(list(models.keys()), value="deep_convae"), gr.Number(value=0), gr.Number(value=20), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28)],
 
 
 
32
  outputs="image"
33
  )
34
  iface.launch()
 
4
  from PIL import Image
5
  from cli import iterative_refinement
6
  from viz import grid_of_images_default
 
 
7
  models = {
8
  "convae": torch.load("convae.th", map_location="cpu"),
9
  "deep_convae": torch.load("deep_convae.th", map_location="cpu"),
10
  }
 
11
  def gen(model, seed, nb_iter, nb_samples, width, height):
12
  torch.manual_seed(int(seed))
13
  bs = 64
 
23
  grid = (grid*255).astype("uint8")
24
  return Image.fromarray(grid)
25
 
26
+ text = """
27
+ Interface with ConvAE model (from [here](https://arxiv.org/pdf/1606.04345.pdf)) and DeepConvAE model (from [here](https://tel.archives-ouvertes.fr/tel-01838272/file/75406_CHERTI_2018_diffusion.pdf), Section 10.1 with `L=3`)
28
+
29
+ These models were trained on MNIST only (digits), but were found to generate new kinds of symbols, see the references for more details.
30
+ """
31
  iface = gr.Interface(
32
  fn=gen,
33
+ inputs=[
34
+ gr.Markdown(text),
35
+ gr.Dropdown(list(models.keys()), value="deep_convae"), gr.Number(value=0), gr.Number(value=20), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28)
36
+ ],
37
  outputs="image"
38
  )
39
  iface.launch()