radames commited on
Commit
fb910bc
·
1 Parent(s): e9f675f

fix zero add error

Browse files
Files changed (2) hide show
  1. interface/app.py +18 -7
  2. interface/model_loader.py +2 -0
interface/app.py CHANGED
@@ -47,6 +47,9 @@ async () => {
47
  }
48
  """
49
 
 
 
 
50
 
51
  def cv_to_pil(img):
52
  return Image.fromarray(cv2.cvtColor(img.astype("uint8"), cv2.COLOR_BGR2RGB))
@@ -59,7 +62,10 @@ def random_sample(model_name: str):
59
  return pil_img, model_name, latents
60
 
61
 
62
- def transform(model_state, latents_state, dxdysxsy="{}", dz=0):
 
 
 
63
  data = json.loads(dxdysxsy)
64
 
65
  model = models[model_state]
@@ -97,7 +103,8 @@ def image_click(evt: gr.SelectData):
97
  with gr.Blocks() as block:
98
  model_state = gr.State(value="cat")
99
  latents_state = gr.State({})
100
- gr.Markdown("""# UserControllableLT: User Controllable Latent Transformer
 
101
  Unofficial Gradio Demo
102
 
103
  **Author**: Yuki Endo\\
@@ -107,7 +114,8 @@ Unofficial Gradio Demo
107
  <small>
108
  Double click to add or remove stop points.
109
  <small>
110
- """)
 
111
 
112
  with gr.Row():
113
  with gr.Column():
@@ -121,10 +129,13 @@ Double click to add or remove stop points.
121
  reset_btn = gr.Button("Reset")
122
  change_style_bt = gr.Button("Change style")
123
  dxdysxsy = gr.Textbox(
124
- label="dxdysxsy", value="{}", elem_id="dxdysxsy", visible=False
 
 
 
125
  )
126
  dz = gr.Slider(
127
- minimum=-5, maximum=5, step_size=0.01, label="zoom", value=0.0
128
  )
129
  image = gr.Image(type="pil", visible=False)
130
 
@@ -164,5 +175,5 @@ Double click to add or remove stop points.
164
  random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
165
  )
166
 
167
- block.queue()
168
- block.launch()
 
47
  }
48
  """
49
 
50
+ default_dxdysxsy = json.dumps(
51
+ {"dx": 1, "dy": 0, "sx": 128, "sy": 128, "stopPoints": []}
52
+ )
53
 
54
  def cv_to_pil(img):
55
  return Image.fromarray(cv2.cvtColor(img.astype("uint8"), cv2.COLOR_BGR2RGB))
 
62
  return pil_img, model_name, latents
63
 
64
 
65
+ def transform(model_state, latents_state, dxdysxsy=default_dxdysxsy, dz=0):
66
+ if "w1" not in latents_state or "w1_initial" not in latents_state:
67
+ raise gr.Error("Generate a random sample first")
68
+
69
  data = json.loads(dxdysxsy)
70
 
71
  model = models[model_state]
 
103
  with gr.Blocks() as block:
104
  model_state = gr.State(value="cat")
105
  latents_state = gr.State({})
106
+ gr.Markdown(
107
+ """# UserControllableLT: User Controllable Latent Transformer
108
  Unofficial Gradio Demo
109
 
110
  **Author**: Yuki Endo\\
 
114
  <small>
115
  Double click to add or remove stop points.
116
  <small>
117
+ """
118
+ )
119
 
120
  with gr.Row():
121
  with gr.Column():
 
129
  reset_btn = gr.Button("Reset")
130
  change_style_bt = gr.Button("Change style")
131
  dxdysxsy = gr.Textbox(
132
+ label="dxdysxsy",
133
+ value=default_dxdysxsy,
134
+ elem_id="dxdysxsy",
135
+ visible=False,
136
  )
137
  dz = gr.Slider(
138
+ minimum=-15, maximum=15, step_size=0.01, label="zoom", value=0.0
139
  )
140
  image = gr.Image(type="pil", visible=False)
141
 
 
175
  random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
176
  )
177
 
178
+ block.queue(concurrency_count=4, max_size=20)
179
+ block.launch(show_api=False)
interface/model_loader.py CHANGED
@@ -84,6 +84,8 @@ class Model:
84
 
85
  dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32)
86
  dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
 
 
87
  dxyz[:2] = dxyz[:2] / dxy_norm
88
  vec_num = dxy_norm / 10
89
 
 
84
 
85
  dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32)
86
  dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
87
+ epsilon = 1e-8
88
+ dxy_norm = dxy_norm + epsilon
89
  dxyz[:2] = dxyz[:2] / dxy_norm
90
  vec_num = dxy_norm / 10
91