add info
Browse files
app.py
CHANGED
@@ -4,13 +4,10 @@ import gradio as gr
|
|
4 |
from PIL import Image
|
5 |
from cli import iterative_refinement
|
6 |
from viz import grid_of_images_default
|
7 |
-
# from subprocess
|
8 |
-
# subprocess.call("download_models.sh", shell=True)
|
9 |
models = {
|
10 |
"convae": torch.load("convae.th", map_location="cpu"),
|
11 |
"deep_convae": torch.load("deep_convae.th", map_location="cpu"),
|
12 |
}
|
13 |
-
|
14 |
def gen(model, seed, nb_iter, nb_samples, width, height):
|
15 |
torch.manual_seed(int(seed))
|
16 |
bs = 64
|
@@ -26,9 +23,17 @@ def gen(model, seed, nb_iter, nb_samples, width, height):
|
|
26 |
grid = (grid*255).astype("uint8")
|
27 |
return Image.fromarray(grid)
|
28 |
|
|
|
|
|
|
|
|
|
|
|
29 |
iface = gr.Interface(
|
30 |
fn=gen,
|
31 |
-
inputs=[
|
|
|
|
|
|
|
32 |
outputs="image"
|
33 |
)
|
34 |
iface.launch()
|
|
|
4 |
from PIL import Image
|
5 |
from cli import iterative_refinement
|
6 |
from viz import grid_of_images_default
|
|
|
|
|
7 |
models = {
|
8 |
"convae": torch.load("convae.th", map_location="cpu"),
|
9 |
"deep_convae": torch.load("deep_convae.th", map_location="cpu"),
|
10 |
}
|
|
|
11 |
def gen(model, seed, nb_iter, nb_samples, width, height):
|
12 |
torch.manual_seed(int(seed))
|
13 |
bs = 64
|
|
|
23 |
grid = (grid*255).astype("uint8")
|
24 |
return Image.fromarray(grid)
|
25 |
|
26 |
+
text = """
|
27 |
+
Interface with ConvAE model (from [here](https://arxiv.org/pdf/1606.04345.pdf)) and DeepConvAE model (from [here](https://tel.archives-ouvertes.fr/tel-01838272/file/75406_CHERTI_2018_diffusion.pdf), Section 10.1 with `L=3`)
|
28 |
+
|
29 |
+
These models were trained on MNIST only (digits), but were found to generate new kinds of symbols, see the references for more details.
|
30 |
+
"""
|
31 |
iface = gr.Interface(
|
32 |
fn=gen,
|
33 |
+
inputs=[
|
34 |
+
gr.Markdown(text),
|
35 |
+
gr.Dropdown(list(models.keys()), value="deep_convae"), gr.Number(value=0), gr.Number(value=20), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28)
|
36 |
+
],
|
37 |
outputs="image"
|
38 |
)
|
39 |
iface.launch()
|