ZhangYuhan commited on
Commit
f8f1d3f
·
1 Parent(s): 6e3a72b

update server

Browse files
app.py CHANGED
@@ -6,7 +6,7 @@ from serve.gradio_web_i2s import *
6
  from serve.leaderboard import build_leaderboard_tab
7
  from model.model_manager import ModelManager
8
  from pathlib import Path
9
- from serve.constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR
10
 
11
 
12
  def build_combine_demo(models, elo_results_file, leaderboard_table_file):
 
6
  from serve.leaderboard import build_leaderboard_tab
7
  from model.model_manager import ModelManager
8
  from pathlib import Path
9
+ from constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR
10
 
11
 
12
  def build_combine_demo(models, elo_results_file, leaderboard_table_file):
serve/constants.py → constants.py RENAMED
@@ -1,8 +1,12 @@
1
  import os
 
 
 
2
 
3
  LOGDIR = os.getenv("LOGDIR", "./3DGen-Arena-logs/vote_log")
4
  IMAGE_DIR = os.getenv("IMAGE_DIR", f"{LOGDIR}/images")
5
  OFFLINE_DIR = "./offline"
 
6
 
7
  SERVER_PORT = os.getenv("SERVER_PORT", 7860)
8
  ROOT_PATH = os.getenv("ROOT_PATH", None)
@@ -17,4 +21,8 @@ SAVE_IMAGE = "save_image"
17
  SAVE_LOG = "save_log"
18
 
19
  NUM_SIDES = 2
20
- TEXT_PROMPT_PATH = "offline/prompts.json"
 
 
 
 
 
1
  import os
2
+ from pathlib import Path
3
+
4
+ os.chdir(Path.cwd())
5
 
6
  LOGDIR = os.getenv("LOGDIR", "./3DGen-Arena-logs/vote_log")
7
  IMAGE_DIR = os.getenv("IMAGE_DIR", f"{LOGDIR}/images")
8
  OFFLINE_DIR = "./offline"
9
+ OFFLINE_GIF_DIR = os.path.join(OFFLINE_DIR, "gifs")
10
 
11
  SERVER_PORT = os.getenv("SERVER_PORT", 7860)
12
  ROOT_PATH = os.getenv("ROOT_PATH", None)
 
21
  SAVE_LOG = "save_log"
22
 
23
  NUM_SIDES = 2
24
+ TEXT_PROMPT_PATH = "offline/prompts_110.json"
25
+ IMAGE_PROMPT_PATH = "offline/image_urls.txt"
26
+
27
+ MAX_ATTEMPTS = 5
28
+ REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN", "r8_0BaoQW0G8nWFXY8YWBCCUDurANxCtY72rarv9")
model/client.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from gradio_client import Client
3
+ from constants import MAX_ATTEMPTS
4
+
5
+ class GaussianToMeshClient:
6
+ def __init__(self):
7
+ self.client = Client("https://dylanebert-splat-to-mesh.hf.space/", upload_files=True, download_files=True)
8
+
9
+ def run(self, shape):
10
+ attempt = 1
11
+ mesh = None
12
+ while attempt <= MAX_ATTEMPTS:
13
+ try:
14
+ mesh = self.to_mesh_client.predict(shape, api_name="/run")
15
+ break
16
+ except Exception as e:
17
+ print(f"Attempt to convert Gaussian to Mesh Failed {attempt}/{MAX_ATTEMPTS}: {e}")
18
+ attempt += 1
19
+ # time.sleep(1)
20
+ return mesh
21
+
22
+ Gau2Mesh_client = GaussianToMeshClient()
model/model_config.py CHANGED
@@ -38,14 +38,44 @@ register_model_config(
38
  online_model=False
39
  )
40
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  register_model_config(
42
- model_name="mvdream",
43
- i2s_model=False,
44
  online_model=False
45
  )
46
 
 
 
 
 
 
 
47
  register_model_config(
48
- model_name="prolificdreamer",
49
- i2s_model=False,
 
 
 
 
 
 
 
 
 
 
 
 
50
  online_model=False
51
  )
 
38
  online_model=False
39
  )
40
 
41
+ # register_model_config(
42
+ # model_name="mvdream",
43
+ # i2s_model=False,
44
+ # online_model=False
45
+ # )
46
+
47
+ # register_model_config(
48
+ # model_name="prolificdreamer",
49
+ # i2s_model=False,
50
+ # online_model=False
51
+ # )
52
+
53
  register_model_config(
54
+ model_name="dreamgaussian",
55
+ i2s_model=True,
56
  online_model=False
57
  )
58
 
59
+ # register_model_config(
60
+ # model_name="wonder3d",
61
+ # i2s_model=True,
62
+ # online_model=False
63
+ # )
64
+
65
  register_model_config(
66
+ model_name="lgm",
67
+ i2s_model=True,
68
+ online_model=False
69
+ )
70
+
71
+ register_model_config(
72
+ model_name="openlrm",
73
+ i2s_model=True,
74
+ online_model=False
75
+ )
76
+
77
+ register_model_config(
78
+ model_name="triplane-gaussian",
79
+ i2s_model=True,
80
  online_model=False
81
  )
model/model_manager.py CHANGED
@@ -61,47 +61,66 @@ class ModelManager:
61
  return
62
 
63
  # @spaces.GPU(duration=120)
64
- def inference(self, prompt, model_name):
 
 
 
65
  worker = self.models_worker[model_name]
66
- result = worker.inference(prompt=prompt)
 
 
 
 
 
67
  return result
68
 
69
- def render(self, prompt, model_name):
70
  worker = self.models_worker[model_name]
71
- result = worker.render(prompt=prompt)
72
  return result
73
 
74
- def inference_parallel(self, prompt, model_A, model_B):
 
 
75
  results = []
76
  model_names = [model_A, model_B]
77
  with concurrent.futures.ThreadPoolExecutor() as executor:
78
- future_to_result = {executor.submit(self.inference, prompt, model): model
79
  for model in model_names}
80
  for future in concurrent.futures.as_completed(future_to_result):
81
  result = future.result()
82
  results.append(result)
83
  return results[0], results[1]
84
 
85
- def inference_parallel_anony(self, prompt, model_A, model_B, i2s_model):
 
 
86
  if model_A == model_B == "":
87
- model_A, model_B = random.sample(self.get_models(i2s_model=i2s_model, online_model=True), 2)
 
 
 
 
 
88
  model_names = [model_A, model_B]
 
89
  results = []
90
  with concurrent.futures.ThreadPoolExecutor() as executor:
91
- future_to_result = {executor.submit(self.inference, prompt, model): model
92
  for model in model_names}
93
  for future in concurrent.futures.as_completed(future_to_result):
94
  result = future.result()
95
  results.append(result)
96
- return results[0], results[1]
97
 
98
 
99
- def render_parallel(self, prompt, model_A, model_B):
100
  results = []
101
  model_names = [model_A, model_B]
 
102
  with concurrent.futures.ThreadPoolExecutor() as executor:
103
- future_to_result = {executor.submit(self.render, prompt, model): model
104
- for model in model_names}
105
  for future in concurrent.futures.as_completed(future_to_result):
106
  result = future.result()
107
  results.append(result)
 
61
  return
62
 
63
  # @spaces.GPU(duration=120)
64
+ def inference(self,
65
+ prompt, model_name,
66
+ offline=False, offline_idx=None):
67
+ result = None
68
  worker = self.models_worker[model_name]
69
+
70
+ if offline:
71
+ result = worker.load_offline(offline, offline_idx)
72
+ if not offline or result == None:
73
+ if worker.check_online():
74
+ result = worker.inference(prompt)
75
  return result
76
 
77
+ def render(self, shape, model_name):
78
  worker = self.models_worker[model_name]
79
+ result = worker.render(shape)
80
  return result
81
 
82
+ def inference_parallel(self,
83
+ prompt, model_A, model_B,
84
+ offline=False, offline_idx=None):
85
  results = []
86
  model_names = [model_A, model_B]
87
  with concurrent.futures.ThreadPoolExecutor() as executor:
88
+ future_to_result = {executor.submit(self.inference, prompt, model, offline, offline_idx): model
89
  for model in model_names}
90
  for future in concurrent.futures.as_completed(future_to_result):
91
  result = future.result()
92
  results.append(result)
93
  return results[0], results[1]
94
 
95
+ def inference_parallel_anony(self,
96
+ prompt, model_A, model_B,
97
+ i2s_model: bool, offline: bool =False, offline_idx: int =None):
98
  if model_A == model_B == "":
99
+ if offline and i2s_model:
100
+ model_A, model_B = random.sample(self.get_i2s_models(), 2)
101
+ elif offline and not i2s_model:
102
+ model_A, model_B = random.sample(self.get_t2s_models(), 2)
103
+ else:
104
+ model_A, model_B = random.sample(self.get_models(i2s_model=i2s_model, online_model=True), 2)
105
  model_names = [model_A, model_B]
106
+
107
  results = []
108
  with concurrent.futures.ThreadPoolExecutor() as executor:
109
+ future_to_result = {executor.submit(self.inference, prompt, model, offline, offline_idx): model
110
  for model in model_names}
111
  for future in concurrent.futures.as_completed(future_to_result):
112
  result = future.result()
113
  results.append(result)
114
+ return results[0], results[1], model_A, model_B
115
 
116
 
117
+ def render_parallel(self, shape_A, model_A, shape_B, model_B):
118
  results = []
119
  model_names = [model_A, model_B]
120
+ shapes = [shape_A, shape_B]
121
  with concurrent.futures.ThreadPoolExecutor() as executor:
122
+ future_to_result = {executor.submit(self.render, shape, model): model
123
+ for model, shape in zip(model_names, shapes)}
124
  for future in concurrent.futures.as_completed(future_to_result):
125
  result = future.result()
126
  results.append(result)
model/model_registry.py CHANGED
@@ -52,13 +52,6 @@ register_model_info(
52
  "Text-to-3D using 2D Diffusion and SDS Loss",
53
  )
54
 
55
- register_model_info(
56
- ["dreamgaussian"],
57
- "DreamGaussian",
58
- "https://github.com/dreamgaussian/dreamgaussian",
59
- "Generative Gaussian Splatting for Efficient 3D Content Creation",
60
- )
61
-
62
  register_model_info(
63
  ["fantasia3d"],
64
  "Fantasia3D",
@@ -66,6 +59,13 @@ register_model_info(
66
  "Disentangling Geometry and Appearance for High-quality Text-to-3D Content Creation",
67
  )
68
 
 
 
 
 
 
 
 
69
  register_model_info(
70
  ["latent-nerf"],
71
  "Latent-NeRF",
@@ -87,6 +87,14 @@ register_model_info(
87
  "Disentangling 2D and Geometric Priors for High-Fidelity and Consistent 3D Generation",
88
  )
89
 
 
 
 
 
 
 
 
 
90
  register_model_info(
91
  ["mvdream"],
92
  "MVDream",
@@ -94,6 +102,20 @@ register_model_info(
94
  "Multi-view Diffusion for 3D Generation",
95
  )
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  register_model_info(
98
  ["prolificdreamer"],
99
  "ProlificDreamer",
@@ -108,6 +130,30 @@ register_model_info(
108
  "Generating Multiview-consistent Images from a Single-view Image",
109
  )
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  register_model_info(
112
  ["wonder3d"],
113
  "Wonder3D",
@@ -122,17 +168,6 @@ register_model_info(
122
  "Hierarchical 3d generation with bootstrapped diffusion prior",
123
  )
124
 
125
-
126
- # register_model_info(
127
- # [],
128
- # "",
129
- # "",
130
- # "",
131
- # )
132
-
133
-
134
- # regist image edition models
135
-
136
  register_model_info(
137
  ["zero123"],
138
  "Zero-1-to-3",
@@ -160,14 +195,6 @@ register_model_info(
160
  "https://github.com/bytedance/ImageDream",
161
  "Image-Prompt Multi-view Diffusion for 3D Generation",
162
  )
163
-
164
- register_model_info(
165
- ["lucid-dreamer"],
166
- "LucidDreamer",
167
- "https://github.com/EnVision-Research/LucidDreamer",
168
- "Towards High-Fidelity Text-to-3D Generation via Interval Score Matching",
169
- )
170
-
171
  register_model_info(
172
  ["make-it-3d"],
173
  "Make-It-3D",
@@ -210,6 +237,27 @@ register_model_info(
210
  "Large Multi-View Gaussian Model for High-Resolution 3D Content Creation",
211
  )
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  # register_model_info(
214
  # [],
215
  # "",
 
52
  "Text-to-3D using 2D Diffusion and SDS Loss",
53
  )
54
 
 
 
 
 
 
 
 
55
  register_model_info(
56
  ["fantasia3d"],
57
  "Fantasia3D",
 
59
  "Disentangling Geometry and Appearance for High-quality Text-to-3D Content Creation",
60
  )
61
 
62
+ register_model_info(
63
+ ["instant3d"],
64
+ "Instant3D",
65
+ "https://jiahao.ai/instant3d/",
66
+ "Fast Text-to-3D with Sparse-View Generation and Large Reconstruction Model",
67
+ )
68
+
69
  register_model_info(
70
  ["latent-nerf"],
71
  "Latent-NeRF",
 
87
  "Disentangling 2D and Geometric Priors for High-Fidelity and Consistent 3D Generation",
88
  )
89
 
90
+
91
+ register_model_info(
92
+ ["lucid-dreamer"],
93
+ "LucidDreamer",
94
+ "https://github.com/EnVision-Research/LucidDreamer",
95
+ "Towards High-Fidelity Text-to-3D Generation via Interval Score Matching",
96
+ )
97
+
98
  register_model_info(
99
  ["mvdream"],
100
  "MVDream",
 
102
  "Multi-view Diffusion for 3D Generation",
103
  )
104
 
105
+ register_model_info(
106
+ ["point-e"],
107
+ "Point·E",
108
+ "https://github.com/openai/point-e",
109
+ "A System for Generating 3D Point Clouds from Complex Prompts",
110
+ )
111
+
112
+ register_model_info(
113
+ ["shap-e"],
114
+ "Shap-E",
115
+ "https://github.com/openai/shap-e",
116
+ "Generating Conditional 3D Implicit Functions",
117
+ )
118
+
119
  register_model_info(
120
  ["prolificdreamer"],
121
  "ProlificDreamer",
 
130
  "Generating Multiview-consistent Images from a Single-view Image",
131
  )
132
 
133
+ register_model_info(
134
+ ["sjc"],
135
+ "Score Jacobian Chaining",
136
+ "https://pals.ttic.edu/p/score-jacobian-chaining",
137
+ "Lifting Pretrained 2D Diffusion Models for 3D Generation",
138
+ )
139
+
140
+ # register_model_info(
141
+ # [],
142
+ # "",
143
+ # "",
144
+ # "",
145
+ # )
146
+
147
+
148
+ ## regist image-to-shape generation models
149
+ register_model_info(
150
+ ["dreamgaussian"],
151
+ "DreamGaussian",
152
+ "https://github.com/dreamgaussian/dreamgaussian",
153
+ "Generative Gaussian Splatting for Efficient 3D Content Creation",
154
+ )
155
+
156
+
157
  register_model_info(
158
  ["wonder3d"],
159
  "Wonder3D",
 
168
  "Hierarchical 3d generation with bootstrapped diffusion prior",
169
  )
170
 
 
 
 
 
 
 
 
 
 
 
 
171
  register_model_info(
172
  ["zero123"],
173
  "Zero-1-to-3",
 
195
  "https://github.com/bytedance/ImageDream",
196
  "Image-Prompt Multi-view Diffusion for 3D Generation",
197
  )
 
 
 
 
 
 
 
 
198
  register_model_info(
199
  ["make-it-3d"],
200
  "Make-It-3D",
 
237
  "Large Multi-View Gaussian Model for High-Resolution 3D Content Creation",
238
  )
239
 
240
+ register_model_info(
241
+ ["gsgen"],
242
+ "GSGEN",
243
+ "https://github.com/gsgen3d/gsgen",
244
+ "Text-to-3D using Gaussian Splatting",
245
+ )
246
+
247
+ register_model_info(
248
+ ["openlrm"],
249
+ "OpenLRM",
250
+ "https://github.com/3DTopia/OpenLRM",
251
+ "Open-Source Large Reconstruction Models",
252
+ )
253
+
254
+ register_model_info(
255
+ ["hifa"],
256
+ "HiFA",
257
+ "https://github.com/JunzheJosephZhu/HiFA",
258
+ "High-fidelity Text-to-3D Generation with Advanced Diffusion Guidance",
259
+ )
260
+
261
  # register_model_info(
262
  # [],
263
  # "",
model/model_worker.py CHANGED
@@ -1,40 +1,55 @@
1
  import os
 
2
  import time
 
3
  from typing import List
4
  import replicate
 
5
 
6
- os.environ("REPLICATE_API_TOKEN", "r8_0BaoQW0G8nWFXY8YWBCCUDurANxCtY72rarv9")
 
 
 
 
 
 
 
7
 
8
  class BaseModelWorker:
9
  def __init__(self,
10
  model_name: str,
11
  i2s_model: bool,
12
  online_model: bool,
13
- model_path: str = None,
14
  ):
15
  self.model_name = model_name
16
  self.i2s_model = i2s_model
17
  self.online_model = online_model
18
- self.model_path = model_path
19
- self.model = None
20
 
21
- if self.online_model:
22
- assert not self.model_path, f"Please give model_path of {model_name}"
23
- self.model = self.load_model()
 
24
 
25
  def check_online(self) -> bool:
26
  if self.online_model and not self.model:
27
  return True
28
  else:
29
  return False
30
-
31
- def load_model(self):
32
- pass
33
 
 
 
 
 
 
 
 
34
  def inference(self, prompt):
35
  pass
36
 
37
- def render(self, shape):
38
  pass
39
 
40
  class HuggingfaceApiWorker(BaseModelWorker):
@@ -44,25 +59,25 @@ class HuggingfaceApiWorker(BaseModelWorker):
44
  i2s_model: bool,
45
  online_model: bool,
46
  model_api: str,
47
- model_path: str = None,
48
  ):
49
  super().__init__(
50
  model_name,
51
  i2s_model,
52
  online_model,
53
- model_path,
54
  )
55
- self.model_api = model_api
56
 
57
  class PointE_Worker(BaseModelWorker):
58
  def __init__(self,
59
  model_name: str,
60
  i2s_model: bool,
61
  online_model: bool,
62
- model_api: str,
63
- model_path: str = None):
64
- super().__init__(model_name, i2s_model, online_model, model_path)
65
- self.model_api = model_api
 
 
66
 
67
 
68
  class LGM_Worker(BaseModelWorker):
@@ -70,26 +85,69 @@ class LGM_Worker(BaseModelWorker):
70
  model_name: str,
71
  i2s_model: bool,
72
  online_model: bool,
73
- model_path: str = "camenduru/lgm-ply-to-glb:eb217314ab0d025370df16b8c9127f9ac1a0e4b3ffbff6b323d598d3c814d258"):
74
- super().__init__(model_name, i2s_model, online_model, model_path)
 
 
75
 
76
  def inference(self, image):
77
- output = replicate.run(
78
- self.model_path,
79
- input={"ply_file_url": image}
 
80
  )
81
- #=> .glb file url: "https://replicate.delivery/pbxt/r4iOSfk7cv2wACJL539ACB4E...
82
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
 
85
  if __name__=="__main__":
86
- input = {
87
- "ply_file_url": "https://replicate.delivery/pbxt/UvKKgNj9mT7pIVHzwerhcjkp5cMH4FS5emPVghk2qyzMRwUSA/gradio_output.ply"
88
- }
89
- print("Start...")
90
- output = replicate.run(
91
- "camenduru/lgm-ply-to-glb:eb217314ab0d025370df16b8c9127f9ac1a0e4b3ffbff6b323d598d3c814d258",
92
- input=input
93
- )
94
- print("output: ", output)
95
- #=> "https://replicate.delivery/pbxt/r4iOSfk7cv2wACJL539ACB4E...
 
 
 
 
 
 
 
1
  import os
2
+ import json
3
  import time
4
+ import kiui
5
  from typing import List
6
  import replicate
7
+ import subprocess
8
 
9
+ import sys
10
+ sys.path.append("..")
11
+
12
+ from gradio_client import Client
13
+ from constants import OFFLINE_GIF_DIR, MAX_ATTEMPTS, REPLICATE_API_TOKEN
14
+ from .client import Gau2Mesh_client
15
+
16
+ # os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
17
 
18
  class BaseModelWorker:
19
  def __init__(self,
20
  model_name: str,
21
  i2s_model: bool,
22
  online_model: bool,
23
+ model_api: str = None
24
  ):
25
  self.model_name = model_name
26
  self.i2s_model = i2s_model
27
  self.online_model = online_model
28
+ self.model_api = model_api
29
+ self.urls_json = None
30
 
31
+ urls_json_path = os.path.join(OFFLINE_GIF_DIR, f"{model_name}.json")
32
+ if os.path.exists(urls_json_path):
33
+ with open(urls_json_path, 'r') as f:
34
+ self.urls_json = json.load(f)
35
 
36
  def check_online(self) -> bool:
37
  if self.online_model and not self.model:
38
  return True
39
  else:
40
  return False
 
 
 
41
 
42
+ def load_offline(self, offline: bool, offline_idx):
43
+ ## offline
44
+ if offline and str(offline_idx) in self.urls_json.keys():
45
+ return self.urls_json[str(offline_idx)]
46
+ else:
47
+ return None
48
+
49
  def inference(self, prompt):
50
  pass
51
 
52
+ def render(self, shape, rgb_on=True, normal_on=True):
53
  pass
54
 
55
  class HuggingfaceApiWorker(BaseModelWorker):
 
59
  i2s_model: bool,
60
  online_model: bool,
61
  model_api: str,
 
62
  ):
63
  super().__init__(
64
  model_name,
65
  i2s_model,
66
  online_model,
67
+ model_api,
68
  )
 
69
 
70
  class PointE_Worker(BaseModelWorker):
71
  def __init__(self,
72
  model_name: str,
73
  i2s_model: bool,
74
  online_model: bool,
75
+ model_api: str):
76
+ super().__init__(model_name, i2s_model, online_model, model_api)
77
+
78
+ class TriplaneGaussian(BaseModelWorker):
79
+ def __init__(self, model_name: str, i2s_model: bool, online_model: bool, model_api: str = None):
80
+ super().__init__(model_name, i2s_model, online_model, model_api)
81
 
82
 
83
  class LGM_Worker(BaseModelWorker):
 
85
  model_name: str,
86
  i2s_model: bool,
87
  online_model: bool,
88
+ model_api: str = "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
89
+ ):
90
+ super().__init__(model_name, i2s_model, online_model, model_api)
91
+ self.model_client = replicate.Client(api_token=REPLICATE_API_TOKEN)
92
 
93
  def inference(self, image):
94
+
95
+ output = self.model_client.run(
96
+ self.model_api,
97
+ input={"input_image": image}
98
  )
99
+ #=> .mp4 .ply
100
+ return output[1]
101
+
102
+ def render(self, shape):
103
+ mesh = Gau2Mesh_client.run(shape)
104
+
105
+ path_normal = ""
106
+ cmd_normal = f"python -m ..kiuikit.kiui.render {mesh} --save {path_normal} \
107
+ --wogui --H 512 --W 512 --radius 3 --elevation 0 --num_azimuth 40 --front_dir='+z' --mode normal"
108
+ subprocess.run(cmd_normal, shell=True, check=True)
109
+
110
+ path_rgb = ""
111
+ cmd_rgb = f"python -m ..kiuikit.kiui.render {mesh} --save {path_rgb} \
112
+ --wogui --H 512 --W 512 --radius 3 --elevation 0 --num_azimuth 40 --front_dir='+z' --mode rgb"
113
+ subprocess.run(cmd_rgb, shell=True, check=True)
114
+
115
+ return path_normal, path_rgb
116
+
117
+ class V3D_Worker(BaseModelWorker):
118
+ def __init__(self,
119
+ model_name: str,
120
+ i2s_model: bool,
121
+ online_model: bool,
122
+ model_api: str = None):
123
+ super().__init__(model_name, i2s_model, online_model, model_api)
124
+
125
+
126
+ # model = 'LGM'
127
+ # # model = 'TriplaneGaussian'
128
+ # folder = 'glbs_full'
129
+ # form = 'glb'
130
+ # pose = '+z'
131
+
132
+ # pair = ('OpenLRM', 'meshes', 'obj', '-y')
133
+ # pair = ('TriplaneGaussian', 'glbs_full', 'glb', '-y')
134
+ # pair = ('LGM', 'glbs_full', 'glb', '+z')
135
 
136
 
137
  if __name__=="__main__":
138
+ # input = {
139
+ # "input_image": "https://replicate.delivery/pbxt/KN0hQI9pYB3NOpHLqktkkQIblwpXt0IG7qI90n5hEnmV9kvo/bird_rgba.png",
140
+ # }
141
+ # print("Start...")
142
+ # model_client = replicate.Client(api_token=REPLICATE_API_TOKEN)
143
+ # output = model_client.run(
144
+ # "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
145
+ # input=input
146
+ # )
147
+ # print("output: ", output)
148
+ #=> ['https://replicate.delivery/pbxt/toffawxRE3h6AUofI9sPtiAsoYI0v73zuGDZjZWBWAPzHKSlA/gradio_output.mp4', 'https://replicate.delivery/pbxt/oSn1XPfoJuw2UKOUIAue2iXeT7aXncVjC4QwHKU5W5x0HKSlA/gradio_output.ply']
149
+
150
+ output = ['https://replicate.delivery/pbxt/RPSTEes37lzAJav3jy1lPuzizm76WGU4IqDcFcAMxhQocjUJA/gradio_output.mp4', 'https://replicate.delivery/pbxt/2Vy8yrPO3PYiI1YJBxPXAzryR0SC0oyqW3XKPnXiuWHUuRqE/gradio_output.ply']
151
+ to_mesh_client = Client("https://dylanebert-splat-to-mesh.hf.space/", upload_files=True, download_files=True)
152
+ mesh = to_mesh_client.predict(output[1], api_name="/run")
153
+ print(mesh)
requirements.txt CHANGED
@@ -25,8 +25,6 @@ gradio_client==0.14.0
25
  h11==0.14.0
26
  httpcore==1.0.4
27
  httpx==0.27.0
28
- huggingface-cli==0.1
29
- huggingface-hub==0.22.0
30
  idna==3.6
31
  importlib_resources==6.4.0
32
  Jinja2==3.1.3
 
25
  h11==0.14.0
26
  httpcore==1.0.4
27
  httpx==0.27.0
 
 
28
  idna==3.6
29
  importlib_resources==6.4.0
30
  Jinja2==3.1.3
serve/gradio_web_i2s.py CHANGED
@@ -43,7 +43,7 @@ Find out who is the 🥇conditional image generation models! More models are goi
43
 
44
  """
45
  model_list = models.get_i2s_models()
46
- gen_func = partial(generate_i2s_multi_annoy, models.inference_parallel, models.render_parallel)
47
 
48
  state_0 = gr.State()
49
  state_1 = gr.State()
@@ -140,28 +140,29 @@ Find out who is the 🥇conditional image generation models! More models are goi
140
  model_selectors = [model_selector_left, model_selector_right]
141
  results = [normal_left, rgb_left, normal_right, rgb_right]
142
 
143
- for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
 
144
  leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
145
 
146
  leftvote_btn.click(
147
  leftvote_last_response_anony,
148
- states + model_selectors,
149
- [imagebox] + btn_list + model_selectors
150
  )
151
  rightvote_btn.click(
152
  rightvote_last_response_anony,
153
- states + model_selectors,
154
- [imagebox] + btn_list + model_selectors
155
  )
156
  tie_btn.click(
157
  tievote_last_response_anony,
158
- states + model_selectors,
159
- [imagebox] + btn_list + model_selectors
160
  )
161
  bothbad_btn.click(
162
  bothbad_vote_last_response_anony,
163
- states + model_selectors,
164
- [imagebox] + btn_list + model_selectors
165
  )
166
 
167
  sample_btn.click(
@@ -172,6 +173,10 @@ Find out who is the 🥇conditional image generation models! More models are goi
172
  )
173
 
174
  imagebox.upload(
 
 
 
 
175
  gen_func,
176
  states + [imagebox] + model_selectors,
177
  states + results + model_selectors,
@@ -187,9 +192,9 @@ Find out who is the 🥇conditional image generation models! More models are goi
187
  )
188
 
189
  send_btn.click(
190
- sample_model,
191
- states + [model_str],
192
- states + model_selectors
193
  ).then(
194
  gen_func,
195
  states + [imagebox] + model_selectors,
@@ -221,9 +226,9 @@ Find out who is the 🥇conditional image generation models! More models are goi
221
  )
222
 
223
  regenerate_btn.click(
224
- sample_model,
225
- states + [model_str],
226
- states + model_selectors
227
  ).then(
228
  gen_func,
229
  states + [imagebox] + model_selectors,
@@ -379,28 +384,29 @@ Find out who is the 🥇conditional image generation models! More models are goi
379
  api_name="model_selector_right"
380
  )
381
 
382
- for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
 
383
  leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
384
 
385
  leftvote_btn.click(
386
  leftvote_last_response_named,
387
- states + model_selectors,
388
- [imagebox] + btn_list
389
  )
390
  rightvote_btn.click(
391
  rightvote_last_response_named,
392
- states + model_selectors,
393
- [imagebox] + btn_list
394
  )
395
  tie_btn.click(
396
  tievote_last_response_named,
397
- states + model_selectors,
398
- [imagebox] + btn_list
399
  )
400
  bothbad_btn.click(
401
  bothbad_vote_last_response_named,
402
- states + model_selectors,
403
- [imagebox] + btn_list
404
  )
405
 
406
  sample_btn.click(
@@ -411,9 +417,13 @@ Find out who is the 🥇conditional image generation models! More models are goi
411
  )
412
 
413
  imagebox.upload(
 
 
 
 
414
  gen_func,
415
  states + [imagebox] + model_selectors,
416
- states + results + model_selectors,
417
  api_name="submit_btn_named"
418
  ).then(
419
  enable_mds,
@@ -426,9 +436,13 @@ Find out who is the 🥇conditional image generation models! More models are goi
426
  )
427
 
428
  send_btn.click(
 
 
 
 
429
  gen_func,
430
  states + [imagebox] + model_selectors,
431
- states + results + model_selectors,
432
  api_name="send_btn_named"
433
  ).then(
434
  enable_mds,
@@ -456,6 +470,10 @@ Find out who is the 🥇conditional image generation models! More models are goi
456
  )
457
 
458
  regenerate_btn.click(
 
 
 
 
459
  gen_func,
460
  states + [imagebox] + model_selectors,
461
  states + results + model_selectors,
@@ -487,7 +505,7 @@ def build_i2s_ui_single_model(models):
487
 
488
  """
489
  model_list = models.get_i2s_models()
490
- gen_func = partial(generate_i2s, models.inference_parallel, models.render_parallel)
491
 
492
  gr.Markdown(notice_markdown, elem_id="notice_markdown")
493
 
@@ -508,40 +526,39 @@ def build_i2s_ui_single_model(models):
508
  normal = gr.Image(width=512, label = "Normal", show_download_button=True)
509
  rgb = gr.Image(width=512, label = "RGB", show_download_button=True,)
510
 
511
- with gr.Row():
512
- imagebox = gr.Image(
513
- width=512,
514
- show_label=False,
515
- visible=True,
516
- interactive=True,
517
- elem_id="input_box",
518
- )
519
- with gr.Column():
520
- # with gr.Row():
521
- sample_btn = gr.Button(value="🎲 Sample", variant="primary")
522
- send_btn = gr.Button(value="📤 Send", variant="primary")
523
- clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
524
- regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
525
-
526
-
527
  with gr.Row(elem_id="Geometry Quality"):
528
- gr.Markdown("Geometry Quality: ")
529
  geo_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
530
  geo_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
531
  geo_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
532
 
533
  with gr.Row(elem_id="Texture Quality"):
534
- gr.Markdown("Texture Quality: ")
535
  text_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
536
  text_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
537
  text_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
538
 
539
  with gr.Row(elem_id="Alignment Quality"):
540
- gr.Markdown("Alignment Quality: ")
541
  align_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
542
  align_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
543
  align_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
  gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
546
 
547
  state = gr.State()
@@ -549,24 +566,25 @@ def build_i2s_ui_single_model(models):
549
  text_btn_list = [text_upvote_btn, text_downvote_btn, text_flag_btn]
550
  align_btn_list = [align_upvote_btn, align_downvote_btn, align_flag_btn]
551
 
552
- for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
 
553
  upvote_btn, downvote_btn, flag_btn = btn_list
554
 
555
  upvote_btn.click(
556
  upvote_last_response,
557
- [state, model_selector],
558
- [imagebox] + btn_list
559
  )
560
 
561
  downvote_btn.click(
562
  downvote_last_response,
563
- [state, model_selector],
564
- [imagebox] + btn_list
565
  )
566
  flag_btn.click(
567
  flag_last_response,
568
- [state, model_selector],
569
- [imagebox] + btn_list
570
  )
571
 
572
  sample_btn.click(
@@ -577,6 +595,10 @@ def build_i2s_ui_single_model(models):
577
  )
578
 
579
  imagebox.upload(
 
 
 
 
580
  gen_func,
581
  [state, imagebox, model_selector],
582
  [state, normal, rgb],
@@ -589,6 +611,10 @@ def build_i2s_ui_single_model(models):
589
  )
590
 
591
  send_btn.click(
 
 
 
 
592
  gen_func,
593
  [state, imagebox, model_selector],
594
  [state, normal, rgb],
@@ -613,6 +639,10 @@ def build_i2s_ui_single_model(models):
613
  )
614
 
615
  regenerate_btn.click(
 
 
 
 
616
  gen_func,
617
  [state, imagebox, model_selector],
618
  [state, normal, rgb],
 
43
 
44
  """
45
  model_list = models.get_i2s_models()
46
+ gen_func = partial(generate_i2s_multi_annoy, models.inference_parallel_anony, models.render_parallel)
47
 
48
  state_0 = gr.State()
49
  state_1 = gr.State()
 
140
  model_selectors = [model_selector_left, model_selector_right]
141
  results = [normal_left, rgb_left, normal_right, rgb_right]
142
 
143
+ for btn_list, dim_md in zip([geo_btn_list, text_btn_list, align_btn_list],
144
+ [geo_md, text_md, align_md]):
145
  leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
146
 
147
  leftvote_btn.click(
148
  leftvote_last_response_anony,
149
+ states + model_selectors + [dim_md],
150
+ states + btn_list + model_selectors
151
  )
152
  rightvote_btn.click(
153
  rightvote_last_response_anony,
154
+ states + model_selectors + [dim_md],
155
+ states + btn_list + model_selectors
156
  )
157
  tie_btn.click(
158
  tievote_last_response_anony,
159
+ states + model_selectors + [dim_md],
160
+ states + btn_list + model_selectors
161
  )
162
  bothbad_btn.click(
163
  bothbad_vote_last_response_anony,
164
+ states + model_selectors + [dim_md],
165
+ states + btn_list + model_selectors
166
  )
167
 
168
  sample_btn.click(
 
173
  )
174
 
175
  imagebox.upload(
176
+ reset_states_side_by_side_anony,
177
+ states,
178
+ states + model_selectors + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn, geo_md, text_md, align_md]
179
+ ).then(
180
  gen_func,
181
  states + [imagebox] + model_selectors,
182
  states + results + model_selectors,
 
192
  )
193
 
194
  send_btn.click(
195
+ reset_states_side_by_side_anony,
196
+ states,
197
+ states + model_selectors + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn, geo_md, text_md, align_md]
198
  ).then(
199
  gen_func,
200
  states + [imagebox] + model_selectors,
 
226
  )
227
 
228
  regenerate_btn.click(
229
+ reset_states_side_by_side_anony,
230
+ states,
231
+ states + model_selectors + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn, geo_md, text_md, align_md]
232
  ).then(
233
  gen_func,
234
  states + [imagebox] + model_selectors,
 
384
  api_name="model_selector_right"
385
  )
386
 
387
+ for btn_list, dim_md in zip([geo_btn_list, text_btn_list, align_btn_list],
388
+ [geo_md, text_md, align_md]):
389
  leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
390
 
391
  leftvote_btn.click(
392
  leftvote_last_response_named,
393
+ states + model_selectors + [dim_md],
394
+ states + btn_list
395
  )
396
  rightvote_btn.click(
397
  rightvote_last_response_named,
398
+ states + model_selectors + [dim_md],
399
+ states + btn_list
400
  )
401
  tie_btn.click(
402
  tievote_last_response_named,
403
+ states + model_selectors + [dim_md],
404
+ states + btn_list
405
  )
406
  bothbad_btn.click(
407
  bothbad_vote_last_response_named,
408
+ states + model_selectors + [dim_md],
409
+ states + btn_list
410
  )
411
 
412
  sample_btn.click(
 
417
  )
418
 
419
  imagebox.upload(
420
+ reset_states_side_by_side,
421
+ states,
422
+ states + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn, geo_md, text_md, align_md]
423
+ ).then(
424
  gen_func,
425
  states + [imagebox] + model_selectors,
426
+ states + results,
427
  api_name="submit_btn_named"
428
  ).then(
429
  enable_mds,
 
436
  )
437
 
438
  send_btn.click(
439
+ reset_states_side_by_side,
440
+ states,
441
+ states + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn, geo_md, text_md, align_md]
442
+ ).then(
443
  gen_func,
444
  states + [imagebox] + model_selectors,
445
+ states + results,
446
  api_name="send_btn_named"
447
  ).then(
448
  enable_mds,
 
470
  )
471
 
472
  regenerate_btn.click(
473
+ reset_states_side_by_side,
474
+ states,
475
+ states + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn, geo_md, text_md, align_md]
476
+ ).then(
477
  gen_func,
478
  states + [imagebox] + model_selectors,
479
  states + results + model_selectors,
 
505
 
506
  """
507
  model_list = models.get_i2s_models()
508
+ gen_func = partial(generate_i2s, models.inference, models.render)
509
 
510
  gr.Markdown(notice_markdown, elem_id="notice_markdown")
511
 
 
526
  normal = gr.Image(width=512, label = "Normal", show_download_button=True)
527
  rgb = gr.Image(width=512, label = "RGB", show_download_button=True,)
528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  with gr.Row(elem_id="Geometry Quality"):
530
+ geo_md = gr.Markdown("Geometry Quality: ")
531
  geo_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
532
  geo_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
533
  geo_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
534
 
535
  with gr.Row(elem_id="Texture Quality"):
536
+ text_md = gr.Markdown("Texture Quality: ")
537
  text_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
538
  text_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
539
  text_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
540
 
541
  with gr.Row(elem_id="Alignment Quality"):
542
+ align_md = gr.Markdown("Alignment Quality: ")
543
  align_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
544
  align_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
545
  align_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
546
 
547
+ with gr.Row():
548
+ imagebox = gr.Image(
549
+ width=512,
550
+ show_label=False,
551
+ visible=True,
552
+ interactive=True,
553
+ elem_id="input_box",
554
+ )
555
+ with gr.Column():
556
+ # with gr.Row():
557
+ sample_btn = gr.Button(value="🎲 Sample", variant="primary")
558
+ send_btn = gr.Button(value="📤 Send", variant="primary")
559
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
560
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
561
+
562
  gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
563
 
564
  state = gr.State()
 
566
  text_btn_list = [text_upvote_btn, text_downvote_btn, text_flag_btn]
567
  align_btn_list = [align_upvote_btn, align_downvote_btn, align_flag_btn]
568
 
569
+ for btn_list, dim_md in zip([geo_btn_list, text_btn_list, align_btn_list],
570
+ [geo_md, text_md, align_md]):
571
  upvote_btn, downvote_btn, flag_btn = btn_list
572
 
573
  upvote_btn.click(
574
  upvote_last_response,
575
+ [state, model_selector, dim_md],
576
+ [state] + btn_list
577
  )
578
 
579
  downvote_btn.click(
580
  downvote_last_response,
581
+ [state, model_selector, dim_md],
582
+ [state] + btn_list
583
  )
584
  flag_btn.click(
585
  flag_last_response,
586
+ [state, model_selector, dim_md],
587
+ [state] + btn_list
588
  )
589
 
590
  sample_btn.click(
 
595
  )
596
 
597
  imagebox.upload(
598
+ reset_state,
599
+ state,
600
+ [state] + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
601
+ ).then(
602
  gen_func,
603
  [state, imagebox, model_selector],
604
  [state, normal, rgb],
 
611
  )
612
 
613
  send_btn.click(
614
+ reset_state,
615
+ state,
616
+ [state] + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
617
+ ).then(
618
  gen_func,
619
  [state, imagebox, model_selector],
620
  [state, normal, rgb],
 
639
  )
640
 
641
  regenerate_btn.click(
642
+ reset_state,
643
+ state,
644
+ [state] + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
645
+ ).then(
646
  gen_func,
647
  [state, imagebox, model_selector],
648
  [state, normal, rgb],
serve/gradio_web_t2i_anony.py CHANGED
@@ -20,7 +20,7 @@ from .inference import(
20
  generate_t2s_multi,
21
  generate_t2s_multi_annoy
22
  )
23
- from .constants import TEXT_PROMPT_PATH
24
 
25
  with open(TEXT_PROMPT_PATH, 'r') as f:
26
  prompt_list = json.load(f)
@@ -41,7 +41,7 @@ Find out who is the 🥇conditional image generation models! More models are goi
41
 
42
  """
43
  model_list = models.get_t2s_models()
44
- gen_func = partial(generate_t2s_multi_annoy, models.inference_parallel, models.render_parallel)
45
 
46
  state_0 = gr.State()
47
  state_1 = gr.State()
 
20
  generate_t2s_multi,
21
  generate_t2s_multi_annoy
22
  )
23
+ from constants import TEXT_PROMPT_PATH
24
 
25
  with open(TEXT_PROMPT_PATH, 'r') as f:
26
  prompt_list = json.load(f)
 
41
 
42
  """
43
  model_list = models.get_t2s_models()
44
+ gen_func = partial(generate_t2s_multi_annoy, models.inference_parallel_anony, models.render_parallel)
45
 
46
  state_0 = gr.State()
47
  state_1 = gr.State()
serve/gradio_web_t2i_named.py CHANGED
@@ -17,7 +17,7 @@ from .inference import(
17
  sample_prompt,
18
  generate_t2s_multi
19
  )
20
- from .constants import TEXT_PROMPT_PATH
21
 
22
  with open(TEXT_PROMPT_PATH, 'r') as f:
23
  prompt_list = json.load(f)
 
17
  sample_prompt,
18
  generate_t2s_multi
19
  )
20
+ from constants import TEXT_PROMPT_PATH
21
 
22
  with open(TEXT_PROMPT_PATH, 'r') as f:
23
  prompt_list = json.load(f)
serve/gradio_web_t2i_single.py CHANGED
@@ -11,7 +11,7 @@ from .inference import(
11
  sample_prompt,
12
  generate_t2s
13
  )
14
- from .constants import TEXT_PROMPT_PATH
15
 
16
  with open(TEXT_PROMPT_PATH, 'r') as f:
17
  prompt_list = json.load(f)
 
11
  sample_prompt,
12
  generate_t2s
13
  )
14
+ from constants import TEXT_PROMPT_PATH
15
 
16
  with open(TEXT_PROMPT_PATH, 'r') as f:
17
  prompt_list = json.load(f)
serve/gradio_web_t2s.py CHANGED
@@ -43,7 +43,7 @@ Find out who is the 🥇conditional image generation models! More models are goi
43
 
44
  """
45
  model_list = models.get_t2s_models()
46
- gen_func = partial(generate_t2s_multi_annoy, models.inference_parallel, models.render_parallel)
47
 
48
 
49
  state_0 = gr.State()
@@ -135,28 +135,29 @@ Find out who is the 🥇conditional image generation models! More models are goi
135
  model_selectors = [model_selector_left, model_selector_right]
136
  results = [normal_left, rgb_left, normal_right, rgb_right]
137
 
138
- for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
 
139
  leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
140
 
141
  leftvote_btn.click(
142
  leftvote_last_response_anony,
143
- states + model_selectors,
144
- [textbox] + btn_list + model_selectors
145
  )
146
  rightvote_btn.click(
147
  rightvote_last_response_anony,
148
- states + model_selectors,
149
- [textbox] + btn_list + model_selectors
150
  )
151
  tie_btn.click(
152
  tievote_last_response_anony,
153
- states + model_selectors,
154
- [textbox] + btn_list + model_selectors
155
  )
156
  bothbad_btn.click(
157
  bothbad_vote_last_response_anony,
158
- states + model_selectors,
159
- [textbox] + btn_list + model_selectors
160
  )
161
 
162
  sample_btn.click(
@@ -167,9 +168,9 @@ Find out who is the 🥇conditional image generation models! More models are goi
167
  )
168
 
169
  textbox.submit(
170
- sample_model,
171
- states + [model_str],
172
- states + model_selectors
173
  ).then(
174
  gen_func,
175
  states + [textbox] + model_selectors,
@@ -186,9 +187,9 @@ Find out who is the 🥇conditional image generation models! More models are goi
186
  )
187
 
188
  send_btn.click(
189
- sample_model,
190
- states + [model_str],
191
- states + model_selectors
192
  ).then(
193
  gen_func,
194
  states + [textbox] + model_selectors,
@@ -220,9 +221,9 @@ Find out who is the 🥇conditional image generation models! More models are goi
220
  )
221
 
222
  regenerate_btn.click(
223
- sample_model,
224
- states + [model_str],
225
- states + model_selectors
226
  ).then(
227
  gen_func,
228
  states + [textbox] + model_selectors,
@@ -379,28 +380,29 @@ Find out who is the 🥇conditional image generation models! More models are goi
379
  api_name="model_selector_right"
380
  )
381
 
382
- for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
 
383
  leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
384
 
385
  leftvote_btn.click(
386
  leftvote_last_response_named,
387
- states + model_selectors,
388
- [textbox] + btn_list
389
  )
390
  rightvote_btn.click(
391
  rightvote_last_response_named,
392
- states + model_selectors,
393
- [textbox] + btn_list
394
  )
395
  tie_btn.click(
396
  tievote_last_response_named,
397
- states + model_selectors,
398
- [textbox] + btn_list
399
  )
400
  bothbad_btn.click(
401
  bothbad_vote_last_response_named,
402
- states + model_selectors,
403
- [textbox] + btn_list
404
  )
405
 
406
  sample_btn.click(
@@ -411,9 +413,13 @@ Find out who is the 🥇conditional image generation models! More models are goi
411
  )
412
 
413
  textbox.submit(
 
 
 
 
414
  gen_func,
415
  states + [textbox] + model_selectors,
416
- states + results + model_selectors,
417
  api_name="submit_btn_named"
418
  ).then(
419
  enable_mds,
@@ -426,9 +432,13 @@ Find out who is the 🥇conditional image generation models! More models are goi
426
  )
427
 
428
  send_btn.click(
 
 
 
 
429
  gen_func,
430
  states + [textbox] + model_selectors,
431
- states + results + model_selectors,
432
  api_name="send_btn_named"
433
  ).then(
434
  enable_mds,
@@ -456,9 +466,13 @@ Find out who is the 🥇conditional image generation models! More models are goi
456
  )
457
 
458
  regenerate_btn.click(
 
 
 
 
459
  gen_func,
460
  states + [textbox] + model_selectors,
461
- states + results + model_selectors,
462
  api_name="regenerate_btn_named"
463
  ).then(
464
  enable_mds,
@@ -487,7 +501,7 @@ def build_t2s_ui_single_model(models):
487
 
488
  """
489
  model_list = models.get_t2s_models()
490
- gen_func = partial(generate_t2s, models.inference_parallel, models.render_parallel)
491
 
492
  gr.Markdown(notice_markdown, elem_id="notice_markdown")
493
 
@@ -507,41 +521,40 @@ def build_t2s_ui_single_model(models):
507
  with gr.Row():
508
  normal = gr.Image(width=512, label = "Normal", show_download_button=True)
509
  rgb = gr.Image(width=512, label = "RGB", show_download_button=True,)
510
-
511
- with gr.Row():
512
- textbox = gr.Textbox(
513
- show_label=False,
514
- placeholder="👉 Enter your prompt or Sample a random prompt, and press ENTER",
515
- container=True,
516
- elem_id="input_box",
517
- )
518
- sample_btn = gr.Button(value="🎲 Sample", variant="primary", scale=0)
519
- send_btn = gr.Button(value="📤 Send", variant="primary", scale=0)
520
-
521
- with gr.Row():
522
- clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
523
- regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
524
-
525
 
526
  with gr.Row(elem_id="Geometry Quality"):
527
- gr.Markdown("Geometry Quality: ", elem_id="evaldim_markdown")
528
  geo_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
529
  geo_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
530
  geo_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
531
 
532
 
533
  with gr.Row(elem_id="Texture Quality"):
534
- gr.Markdown("Texture Quality: ", elem_id="evaldim_markdown")
535
  text_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
536
  text_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
537
  text_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
538
 
539
  with gr.Row(elem_id="Alignment Quality"):
540
- gr.Markdown("Alignment Quality: ", elem_id="evaldim_markdown")
541
  align_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
542
  align_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
543
  align_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
  gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
546
 
547
  state = gr.State()
@@ -549,24 +562,25 @@ def build_t2s_ui_single_model(models):
549
  text_btn_list = [text_upvote_btn, text_downvote_btn, text_flag_btn]
550
  align_btn_list = [align_upvote_btn, align_downvote_btn, align_flag_btn]
551
 
552
- for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
 
553
  upvote_btn, downvote_btn, flag_btn = btn_list
554
 
555
  upvote_btn.click(
556
  upvote_last_response,
557
- [state, model_selector],
558
- [textbox] + btn_list
559
  )
560
 
561
  downvote_btn.click(
562
  downvote_last_response,
563
- [state, model_selector],
564
- [textbox] + btn_list
565
  )
566
  flag_btn.click(
567
  flag_last_response,
568
- [state, model_selector],
569
- [textbox] + btn_list
570
  )
571
 
572
  sample_btn.click(
@@ -577,6 +591,10 @@ def build_t2s_ui_single_model(models):
577
  )
578
 
579
  textbox.submit(
 
 
 
 
580
  gen_func,
581
  [state, textbox, model_selector],
582
  [state, normal, rgb],
@@ -589,6 +607,10 @@ def build_t2s_ui_single_model(models):
589
  )
590
 
591
  send_btn.click(
 
 
 
 
592
  gen_func,
593
  [state, textbox, model_selector],
594
  [state, normal, rgb],
@@ -613,6 +635,10 @@ def build_t2s_ui_single_model(models):
613
  )
614
 
615
  regenerate_btn.click(
 
 
 
 
616
  gen_func,
617
  [state, textbox, model_selector],
618
  [state, normal, rgb],
 
43
 
44
  """
45
  model_list = models.get_t2s_models()
46
+ gen_func = partial(generate_t2s_multi_annoy, models.inference_parallel_anony, models.render_parallel)
47
 
48
 
49
  state_0 = gr.State()
 
135
  model_selectors = [model_selector_left, model_selector_right]
136
  results = [normal_left, rgb_left, normal_right, rgb_right]
137
 
138
+ for btn_list, dim_md in zip([geo_btn_list, text_btn_list, align_btn_list],
139
+ [geo_md, text_md, align_md]):
140
  leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
141
 
142
  leftvote_btn.click(
143
  leftvote_last_response_anony,
144
+ states + model_selectors + [dim_md],
145
+ states + btn_list + model_selectors
146
  )
147
  rightvote_btn.click(
148
  rightvote_last_response_anony,
149
+ states + model_selectors + [dim_md],
150
+ states + btn_list + model_selectors
151
  )
152
  tie_btn.click(
153
  tievote_last_response_anony,
154
+ states + model_selectors + [dim_md],
155
+ states + btn_list + model_selectors
156
  )
157
  bothbad_btn.click(
158
  bothbad_vote_last_response_anony,
159
+ states + model_selectors + [dim_md],
160
+ states + btn_list + model_selectors
161
  )
162
 
163
  sample_btn.click(
 
168
  )
169
 
170
  textbox.submit(
171
+ reset_states_side_by_side_anony,
172
+ states,
173
+ states + model_selectors + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn, geo_md, text_md, align_md]
174
  ).then(
175
  gen_func,
176
  states + [textbox] + model_selectors,
 
187
  )
188
 
189
  send_btn.click(
190
+ reset_states_side_by_side_anony,
191
+ states,
192
+ states + model_selectors + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn, geo_md, text_md, align_md]
193
  ).then(
194
  gen_func,
195
  states + [textbox] + model_selectors,
 
221
  )
222
 
223
  regenerate_btn.click(
224
+ reset_states_side_by_side_anony,
225
+ states,
226
+ states + model_selectors + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn, geo_md, text_md, align_md]
227
  ).then(
228
  gen_func,
229
  states + [textbox] + model_selectors,
 
380
  api_name="model_selector_right"
381
  )
382
 
383
+ for btn_list, dim_md in zip([geo_btn_list, text_btn_list, align_btn_list],
384
+ [geo_md, text_md, align_md]):
385
  leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
386
 
387
  leftvote_btn.click(
388
  leftvote_last_response_named,
389
+ states + model_selectors + [dim_md],
390
+ states + btn_list
391
  )
392
  rightvote_btn.click(
393
  rightvote_last_response_named,
394
+ states + model_selectors + [dim_md],
395
+ states + btn_list
396
  )
397
  tie_btn.click(
398
  tievote_last_response_named,
399
+ states + model_selectors + [dim_md],
400
+ states + btn_list
401
  )
402
  bothbad_btn.click(
403
  bothbad_vote_last_response_named,
404
+ states + model_selectors + [dim_md],
405
+ states + btn_list
406
  )
407
 
408
  sample_btn.click(
 
413
  )
414
 
415
  textbox.submit(
416
+ reset_states_side_by_side,
417
+ states,
418
+ states + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn, geo_md, text_md, align_md]
419
+ ).then(
420
  gen_func,
421
  states + [textbox] + model_selectors,
422
+ states + results,
423
  api_name="submit_btn_named"
424
  ).then(
425
  enable_mds,
 
432
  )
433
 
434
  send_btn.click(
435
+ reset_states_side_by_side,
436
+ states,
437
+ states + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn, geo_md, text_md, align_md]
438
+ ).then(
439
  gen_func,
440
  states + [textbox] + model_selectors,
441
+ states + results,
442
  api_name="send_btn_named"
443
  ).then(
444
  enable_mds,
 
466
  )
467
 
468
  regenerate_btn.click(
469
+ reset_states_side_by_side,
470
+ states,
471
+ states + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn, geo_md, text_md, align_md]
472
+ ).then(
473
  gen_func,
474
  states + [textbox] + model_selectors,
475
+ states + results,
476
  api_name="regenerate_btn_named"
477
  ).then(
478
  enable_mds,
 
501
 
502
  """
503
  model_list = models.get_t2s_models()
504
+ gen_func = partial(generate_t2s, models.inference, models.render)
505
 
506
  gr.Markdown(notice_markdown, elem_id="notice_markdown")
507
 
 
521
  with gr.Row():
522
  normal = gr.Image(width=512, label = "Normal", show_download_button=True)
523
  rgb = gr.Image(width=512, label = "RGB", show_download_button=True,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
 
525
  with gr.Row(elem_id="Geometry Quality"):
526
+ geo_md = gr.Markdown("Geometry Quality: ", elem_id="evaldim_markdown")
527
  geo_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
528
  geo_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
529
  geo_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
530
 
531
 
532
  with gr.Row(elem_id="Texture Quality"):
533
+ text_md = gr.Markdown("Texture Quality: ", elem_id="evaldim_markdown")
534
  text_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
535
  text_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
536
  text_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
537
 
538
  with gr.Row(elem_id="Alignment Quality"):
539
+ align_md =gr.Markdown("Alignment Quality: ", elem_id="evaldim_markdown")
540
  align_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
541
  align_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
542
  align_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
543
 
544
+ with gr.Row():
545
+ textbox = gr.Textbox(
546
+ show_label=False,
547
+ placeholder="👉 Enter your prompt or Sample a random prompt, and press ENTER",
548
+ container=True,
549
+ elem_id="input_box",
550
+ )
551
+ sample_btn = gr.Button(value="🎲 Sample", variant="primary", scale=0)
552
+ send_btn = gr.Button(value="📤 Send", variant="primary", scale=0)
553
+
554
+ with gr.Row():
555
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
556
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
557
+
558
  gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
559
 
560
  state = gr.State()
 
562
  text_btn_list = [text_upvote_btn, text_downvote_btn, text_flag_btn]
563
  align_btn_list = [align_upvote_btn, align_downvote_btn, align_flag_btn]
564
 
565
+ for btn_list, dim_md in zip([geo_btn_list, text_btn_list, align_btn_list],
566
+ [geo_md, text_md, align_md]):
567
  upvote_btn, downvote_btn, flag_btn = btn_list
568
 
569
  upvote_btn.click(
570
  upvote_last_response,
571
+ [state, model_selector, dim_md],
572
+ [state] + btn_list
573
  )
574
 
575
  downvote_btn.click(
576
  downvote_last_response,
577
+ [state, model_selector, dim_md],
578
+ [state] + btn_list
579
  )
580
  flag_btn.click(
581
  flag_last_response,
582
+ [state, model_selector, dim_md],
583
+ [state] + btn_list
584
  )
585
 
586
  sample_btn.click(
 
591
  )
592
 
593
  textbox.submit(
594
+ reset_state,
595
+ state,
596
+ [state] + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
597
+ ).then(
598
  gen_func,
599
  [state, textbox, model_selector],
600
  [state, normal, rgb],
 
607
  )
608
 
609
  send_btn.click(
610
+ reset_state,
611
+ state,
612
+ [state] + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
613
+ ).then(
614
  gen_func,
615
  [state, textbox, model_selector],
616
  [state, normal, rgb],
 
635
  )
636
 
637
  regenerate_btn.click(
638
+ reset_state,
639
+ state,
640
+ [state] + geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
641
+ ).then(
642
  gen_func,
643
  [state, textbox, model_selector],
644
  [state, normal, rgb],
serve/inference.py CHANGED
@@ -4,17 +4,27 @@ import time
4
 
5
  from .utils import *
6
  from .vote_utils import t2s_logger, t2s_multi_logger, i2s_logger, i2s_multi_logger
7
- from .constants import IMAGE_DIR, OFFLINE_DIR, TEXT_PROMPT_PATH
8
 
9
  with open(TEXT_PROMPT_PATH, 'r') as f:
10
  prompt_list = json.load(f)
11
 
 
 
 
 
 
 
 
 
 
12
 
13
  class State:
14
  def __init__(self,
15
  model_name, i2s_mode=False, offline=False,
16
  prompt=None, image=None, offline_idx=None,
17
- normal_video=None , rgb_video=None):
 
18
  self.conv_id = uuid.uuid4().hex
19
  self.model_name = model_name
20
  self.i2s_mode = i2s_mode
@@ -27,15 +37,18 @@ class State:
27
  self.normal_video = normal_video
28
  self.rgb_video = rgb_video
29
 
 
 
30
  def dict(self):
31
  base = {
32
  "conv_id": self.conv_id,
33
  "model_name": self.model_name,
34
  "i2s_mode": self.i2s_mode,
35
  "offline": self.offline,
36
- "prompt": self.prompt
 
37
  }
38
- if not self.offline and not self.offline_idx:
39
  base['offline_idx'] = self.offline_idx
40
  return base
41
 
@@ -91,6 +104,8 @@ def sample_prompt(state, model_name):
91
 
92
  state.model_name = model_name
93
  state.prompt = prompt
 
 
94
  return state, prompt
95
 
96
  def sample_prompt_side_by_side(state_0, state_1, model_name_0, model_name_1):
@@ -111,12 +126,14 @@ def sample_image(state, model_name):
111
  if state is None:
112
  state = State(model_name)
113
 
114
- idx = random.randint(0, len(prompt_list)-1)
115
- prompt = prompt_list[idx]
116
 
117
  state.model_name = model_name
118
- state.prompt = prompt
119
- return state, prompt
 
 
120
 
121
  def sample_image_side_by_side(state_0, state_1, model_name_0, model_name_1):
122
  if state_0 is None:
@@ -124,20 +141,21 @@ def sample_image_side_by_side(state_0, state_1, model_name_0, model_name_1):
124
  if state_1 is None:
125
  state_1 = State(model_name_1)
126
 
127
- idx = random.randint(0, len(prompt_list)-1)
128
- prompt = prompt_list[idx]
 
129
 
130
  state_0.offline, state_1.offline = True, True
131
  state_0.offline_idx, state_1.offline_idx = idx, idx
132
- state_0.prompt, state_1.prompt = prompt, prompt
133
- return state_0, state_1, prompt
134
 
135
  def generate_t2s(gen_func, render_func,
136
  state,
137
  text,
138
  model_name,
139
  request: gr.Request):
140
- if not text:
141
  raise gr.Warning("Prompt cannot be empty.")
142
  if not model_name:
143
  raise gr.Warning("Model name cannot be empty.")
@@ -145,11 +163,13 @@ def generate_t2s(gen_func, render_func,
145
  if state is None:
146
  state = State(model_name, i2s_mode=False, offline=False)
147
 
 
148
  ip = get_ip(request)
149
  t2s_logger.info(f"generate. ip: {ip}")
150
 
151
  state.model_name = model_name
152
  state.prompt = text
 
153
  try:
154
  idx = prompt_list.index(text)
155
  state.offline = True
@@ -158,14 +178,15 @@ def generate_t2s(gen_func, render_func,
158
  state.offline = False
159
  state.offline_idx = None
160
 
161
- if not state.offline and not state.offline_idx:
162
  start_time = time.time()
163
- normal_video = os.path.join(OFFLINE_DIR, "text2shape", model_name, "normal", f"{state.offline_idx}.mp4")
164
- rgb_video = os.path.join(OFFLINE_DIR, "text2shape", model_name, "rgb", f"{state.offline_idx}.mp4")
 
165
 
166
- state.normal_video = normal_video
167
- state.rgb_video = rgb_video
168
- yield state, normal_video, rgb_video
169
 
170
  # logger.info(f"===output===: {output}")
171
  data = {
@@ -181,13 +202,13 @@ def generate_t2s(gen_func, render_func,
181
  shape = gen_func(text, model_name)
182
  generate_time = time.time() - start_time
183
 
184
- normal_video, rgb_video = render_func(shape, model_name)
185
  finish_time = time.time()
186
  render_time = finish_time - start_time - generate_time
187
 
188
- state.normal_video = normal_video
189
- state.rgb_video = rgb_video
190
- yield state, normal_video, rgb_video
191
 
192
  # logger.info(f"===output===: {output}")
193
  data = {
@@ -217,7 +238,7 @@ def generate_t2s_multi(gen_func, render_func,
217
  text,
218
  model_name_0, model_name_1,
219
  request: gr.Request):
220
- if not text:
221
  raise gr.Warning("Prompt cannot be empty.")
222
  if not model_name_0:
223
  raise gr.Warning("Model name A cannot be empty.")
@@ -229,11 +250,13 @@ def generate_t2s_multi(gen_func, render_func,
229
  if state_1 is None:
230
  state_1 = State(model_name_1, i2s_mode=False, offline=False)
231
 
 
232
  ip = get_ip(request)
233
  t2s_multi_logger.info(f"generate. ip: {ip}")
234
 
235
  state_0.model_name, state_1.model_name = model_name_0, model_name_1
236
  state_0.prompt, state_1.prompt = text, text
 
237
  try:
238
  idx = prompt_list.index(text)
239
  state_0.offline, state_1.offline = True, True
@@ -242,18 +265,17 @@ def generate_t2s_multi(gen_func, render_func,
242
  state_0.offline, state_1.offline = False, False
243
  state_0.offline_idx, state_1.offline_idx = None, None
244
 
245
- if not state_0.offline and not state_0.offline_idx:
246
  start_time = time.time()
247
- normal_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4")
248
- rgb_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4")
249
- normal_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4")
250
- rgb_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4")
 
251
 
252
- state_0.normal_video = normal_video_0
253
- state_0.rgb_video = rgb_video_0
254
- state_1.normal_video = normal_video_1
255
- state_1.rgb_video = rgb_video_1
256
- yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1
257
 
258
  # logger.info(f"===output===: {output}")
259
  data_0 = {
@@ -277,16 +299,13 @@ def generate_t2s_multi(gen_func, render_func,
277
  shape_0, shape_1 = gen_func(text, model_name_0, model_name_1)
278
  generate_time = time.time() - start_time
279
 
280
- normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0,
281
- shape_1, model_name_1)
282
  finish_time = time.time()
283
  render_time = finish_time - start_time - generate_time
284
-
285
- state_0.normal_video = normal_video_0
286
- state_0.rgb_video = rgb_video_0
287
- state_1.normal_video = normal_video_1
288
- state_1.rgb_video = rgb_video_1
289
- yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1
290
 
291
  # logger.info(f"===output===: {output}")
292
  data_0 = {
@@ -330,18 +349,20 @@ def generate_t2s_multi_annoy(gen_func, render_func,
330
  text,
331
  model_name_0, model_name_1,
332
  request: gr.Request):
333
- if not text:
334
  raise gr.Warning("Prompt cannot be empty.")
335
  if state_0 is None:
336
  state_0 = State(model_name_0, i2s_mode=False, offline=False)
337
  if state_1 is None:
338
  state_1 = State(model_name_1, i2s_mode=False, offline=False)
339
 
 
340
  ip = get_ip(request)
341
  t2s_multi_logger.info(f"generate. ip: {ip}")
342
 
343
  state_0.model_name, state_1.model_name = model_name_0, model_name_1
344
  state_0.prompt, state_1.prompt = text, text
 
345
  try:
346
  idx = prompt_list.index(text)
347
  state_0.offline, state_1.offline = True, True
@@ -350,18 +371,19 @@ def generate_t2s_multi_annoy(gen_func, render_func,
350
  state_0.offline, state_1.offline = False, False
351
  state_0.offline_idx, state_1.offline_idx = None, None
352
 
353
- if not state_0.offline and not state_0.offline_idx:
354
  start_time = time.time()
355
- normal_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4")
356
- rgb_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4")
357
- normal_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4")
358
- rgb_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4")
 
 
359
 
360
- state_0.normal_video = normal_video_0
361
- state_0.rgb_video = rgb_video_0
362
- state_1.normal_video = normal_video_1
363
- state_1.rgb_video = rgb_video_1
364
- yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1, \
365
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
366
 
367
  # logger.info(f"===output===: {output}")
@@ -383,19 +405,17 @@ def generate_t2s_multi_annoy(gen_func, render_func,
383
  }
384
  else:
385
  start_time = time.time()
386
- shape_0, shape_1 = gen_func(text, model_name_0, model_name_1)
387
  generate_time = time.time() - start_time
388
 
389
- normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0,
390
- shape_1, model_name_1)
391
  finish_time = time.time()
392
  render_time = finish_time - start_time - generate_time
393
-
394
- state_0.normal_video = normal_video_0
395
- state_0.rgb_video = rgb_video_0
396
- state_1.normal_video = normal_video_1
397
- state_1.rgb_video = rgb_video_1
398
- yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1, \
399
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
400
 
401
  # logger.info(f"===output===: {output}")
@@ -437,7 +457,7 @@ def generate_t2s_multi_annoy(gen_func, render_func,
437
 
438
 
439
  def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Request):
440
- if not image:
441
  raise gr.Warning("Image cannot be empty.")
442
  if not model_name:
443
  raise gr.Warning("Model name cannot be empty.")
@@ -449,15 +469,17 @@ def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Re
449
 
450
  state.model_name = model_name
451
  state.image = image
 
452
 
453
- if not state.offline and not state.offline_idx:
454
  start_time = time.time()
455
- normal_video = os.path.join(OFFLINE_DIR, "image2shape", model_name, "normal", f"{state.offline_idx}.mp4")
456
- rgb_video = os.path.join(OFFLINE_DIR, "image2shape", model_name, "rgb", f"{state.offline_idx}.mp4")
 
457
 
458
- state.normal_video = normal_video
459
- state.rgb_video = rgb_video
460
- yield state, normal_video, rgb_video
461
 
462
  # logger.info(f"===output===: {output}")
463
  data = {
@@ -473,13 +495,13 @@ def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Re
473
  shape = gen_func(image, model_name)
474
  generate_time = time.time() - start_time
475
 
476
- normal_video, rgb_video = render_func(shape, model_name)
477
  finish_time = time.time()
478
  render_time = finish_time - start_time - generate_time
479
 
480
- state.normal_video = normal_video
481
- state.rgb_video = rgb_video
482
- yield state, normal_video, rgb_video
483
 
484
  # logger.info(f"===output===: {output}")
485
  data = {
@@ -513,7 +535,7 @@ def generate_i2s_multi(gen_func, render_func,
513
  image,
514
  model_name_0, model_name_1,
515
  request: gr.Request):
516
- if not image:
517
  raise gr.Warning("Image cannot be empty.")
518
  if not model_name_0:
519
  raise gr.Warning("Model name A cannot be empty.")
@@ -530,20 +552,19 @@ def generate_i2s_multi(gen_func, render_func,
530
 
531
  state_0.model_name, state_1.model_name = model_name_0, model_name_1
532
  state_0.image, state_1.image = image, image
 
533
 
534
- if not state_0.offline and not state_0.offline_idx and \
535
- not state_1.offline and not state_1.offline_idx:
536
  start_time = time.time()
537
- normal_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4")
538
- rgb_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4")
539
- normal_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4")
540
- rgb_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4")
 
541
 
542
- state_0.normal_video = normal_video_0
543
- state_0.rgb_video = rgb_video_0
544
- state_1.normal_video = normal_video_1
545
- state_1.rgb_video = rgb_video_1
546
- yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1, \
547
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
548
 
549
  # logger.info(f"===output===: {output}")
@@ -568,16 +589,13 @@ def generate_i2s_multi(gen_func, render_func,
568
  shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
569
  generate_time = time.time() - start_time
570
 
571
- normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0,
572
- shape_1, model_name_1)
573
  finish_time = time.time()
574
  render_time = finish_time - start_time - generate_time
575
 
576
- state_0.normal_video = normal_video_0
577
- state_0.rgb_video = rgb_video_0
578
- state_1.normal_video = normal_video_1
579
- state_1.rgb_video = rgb_video_1
580
- yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1
581
 
582
  # logger.info(f"===output===: {output}")
583
  data_0 = {
@@ -621,12 +639,12 @@ def generate_i2s_multi(gen_func, render_func,
621
  # save_image_file_on_log_server(output_file)
622
 
623
 
624
- def generate_i2s_multi_annoy(gen_func,
625
  state_0, state_1,
626
  image,
627
  model_name_0, model_name_1,
628
  request: gr.Request):
629
- if not image:
630
  raise gr.Warning("Image cannot be empty.")
631
  if state_0 is None:
632
  state_0 = State(model_name_0, i2s_mode=True, offline=False)
@@ -638,20 +656,24 @@ def generate_i2s_multi_annoy(gen_func,
638
 
639
  state_0.model_name, state_1.model_name = model_name_0, model_name_1
640
  state_0.image, state_1.image = image, image
 
641
 
642
- if not state_0.offline and not state_0.offline_idx and \
643
- not state_1.offline and not state_1.offline_idx:
644
  start_time = time.time()
645
- normal_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4")
646
- rgb_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4")
647
- normal_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4")
648
- rgb_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4")
649
-
650
- state_0.normal_video = normal_video_0
651
- state_0.rgb_video = rgb_video_0
652
- state_1.normal_video = normal_video_1
653
- state_1.rgb_video = rgb_video_1
654
- yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1, \
 
 
 
 
655
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
656
 
657
  # logger.info(f"===output===: {output}")
@@ -676,16 +698,13 @@ def generate_i2s_multi_annoy(gen_func,
676
  shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
677
  generate_time = time.time() - start_time
678
 
679
- normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0,
680
- shape_1, model_name_1)
681
  finish_time = time.time()
682
  render_time = finish_time - start_time - generate_time
683
 
684
- state_0.normal_video = normal_video_0
685
- state_0.rgb_video = rgb_video_0
686
- state_1.normal_video = normal_video_1
687
- state_1.rgb_video = rgb_video_1
688
- yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1, \
689
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
690
 
691
  # logger.info(f"===output===: {output}")
 
4
 
5
  from .utils import *
6
  from .vote_utils import t2s_logger, t2s_multi_logger, i2s_logger, i2s_multi_logger
7
+ from constants import IMAGE_DIR, OFFLINE_DIR, TEXT_PROMPT_PATH, IMAGE_PROMPT_PATH
8
 
9
  with open(TEXT_PROMPT_PATH, 'r') as f:
10
  prompt_list = json.load(f)
11
 
12
+ with open(IMAGE_PROMPT_PATH, 'r') as f:
13
+ lines = f.readlines()
14
+
15
+ image_list = {}
16
+ for line in lines:
17
+ idx = line.split('.png')[0].split('_')[-1]
18
+ url = line.split(')')[0].split('(')[-1]
19
+ image_list[eval(idx)] = url
20
+
21
 
22
  class State:
23
  def __init__(self,
24
  model_name, i2s_mode=False, offline=False,
25
  prompt=None, image=None, offline_idx=None,
26
+ normal_video=None , rgb_video=None,
27
+ evaluted_dims=0):
28
  self.conv_id = uuid.uuid4().hex
29
  self.model_name = model_name
30
  self.i2s_mode = i2s_mode
 
37
  self.normal_video = normal_video
38
  self.rgb_video = rgb_video
39
 
40
+ self.evaluted_dims = evaluted_dims
41
+
42
  def dict(self):
43
  base = {
44
  "conv_id": self.conv_id,
45
  "model_name": self.model_name,
46
  "i2s_mode": self.i2s_mode,
47
  "offline": self.offline,
48
+ "prompt": self.prompt,
49
+ "evaluted_dims": self.evaluted_dims,
50
  }
51
+ if self.offline:
52
  base['offline_idx'] = self.offline_idx
53
  return base
54
 
 
104
 
105
  state.model_name = model_name
106
  state.prompt = prompt
107
+ state.offline = True,
108
+ state.offline_idx = idx
109
  return state, prompt
110
 
111
  def sample_prompt_side_by_side(state_0, state_1, model_name_0, model_name_1):
 
126
  if state is None:
127
  state = State(model_name)
128
 
129
+ idx = random.sample(image_list.keys(), 1)[0]
130
+ img_url = image_list[idx]
131
 
132
  state.model_name = model_name
133
+ state.image = img_url
134
+ state.offline = True,
135
+ state.offline_idx = idx
136
+ return state, img_url
137
 
138
  def sample_image_side_by_side(state_0, state_1, model_name_0, model_name_1):
139
  if state_0 is None:
 
141
  if state_1 is None:
142
  state_1 = State(model_name_1)
143
 
144
+
145
+ idx = random.sample(image_list.keys(), 1)[0]
146
+ img_url = image_list[idx]
147
 
148
  state_0.offline, state_1.offline = True, True
149
  state_0.offline_idx, state_1.offline_idx = idx, idx
150
+ state_0.image, state_1.image = img_url, img_url
151
+ return state_0, state_1, img_url
152
 
153
  def generate_t2s(gen_func, render_func,
154
  state,
155
  text,
156
  model_name,
157
  request: gr.Request):
158
+ if not text or text.strip()=="":
159
  raise gr.Warning("Prompt cannot be empty.")
160
  if not model_name:
161
  raise gr.Warning("Model name cannot be empty.")
 
163
  if state is None:
164
  state = State(model_name, i2s_mode=False, offline=False)
165
 
166
+ text = text.strip()
167
  ip = get_ip(request)
168
  t2s_logger.info(f"generate. ip: {ip}")
169
 
170
  state.model_name = model_name
171
  state.prompt = text
172
+ state.evaluted_dims = 0
173
  try:
174
  idx = prompt_list.index(text)
175
  state.offline = True
 
178
  state.offline = False
179
  state.offline_idx = None
180
 
181
+ if state.offline and state.offline_idx:
182
  start_time = time.time()
183
+ videos = gen_func(text, model_name, offline=state.offline, offline_idx=state.offline_idx)
184
+ # normal_video = os.path.join(OFFLINE_DIR, "text2shape", model_name, "normal", f"{state.offline_idx}.mp4")
185
+ # rgb_video = os.path.join(OFFLINE_DIR, "text2shape", model_name, "rgb", f"{state.offline_idx}.mp4")
186
 
187
+ state.normal_video = videos['normal']
188
+ state.rgb_video = videos['rgb']
189
+ yield state, videos['normal'], videos['rgb']
190
 
191
  # logger.info(f"===output===: {output}")
192
  data = {
 
202
  shape = gen_func(text, model_name)
203
  generate_time = time.time() - start_time
204
 
205
+ videos = render_func(shape, model_name)
206
  finish_time = time.time()
207
  render_time = finish_time - start_time - generate_time
208
 
209
+ state.normal_video = videos['normal']
210
+ state.rgb_video = videos['rgb']
211
+ yield state, videos['normal'], videos['rgb']
212
 
213
  # logger.info(f"===output===: {output}")
214
  data = {
 
238
  text,
239
  model_name_0, model_name_1,
240
  request: gr.Request):
241
+ if not text or text.strip()=="":
242
  raise gr.Warning("Prompt cannot be empty.")
243
  if not model_name_0:
244
  raise gr.Warning("Model name A cannot be empty.")
 
250
  if state_1 is None:
251
  state_1 = State(model_name_1, i2s_mode=False, offline=False)
252
 
253
+ text = text.strip()
254
  ip = get_ip(request)
255
  t2s_multi_logger.info(f"generate. ip: {ip}")
256
 
257
  state_0.model_name, state_1.model_name = model_name_0, model_name_1
258
  state_0.prompt, state_1.prompt = text, text
259
+ state_0.evaluted_dims, state_1.evaluted_dims = 0, 0
260
  try:
261
  idx = prompt_list.index(text)
262
  state_0.offline, state_1.offline = True, True
 
265
  state_0.offline, state_1.offline = False, False
266
  state_0.offline_idx, state_1.offline_idx = None, None
267
 
268
+ if state_0.offline and state_0.offline_idx:
269
  start_time = time.time()
270
+ videos_0, videos_1 = gen_func(text, model_name_0, model_name_1, offline=state_0.offline, offline_idx=state_0.offline_idx)
271
+ # normal_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4")
272
+ # rgb_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4")
273
+ # normal_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4")
274
+ # rgb_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4")
275
 
276
+ state_0.normal_video, state_0.rgb_video = videos_0['normal'], videos_0['rgb']
277
+ state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
278
+ yield state_0, state_1,videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
 
 
279
 
280
  # logger.info(f"===output===: {output}")
281
  data_0 = {
 
299
  shape_0, shape_1 = gen_func(text, model_name_0, model_name_1)
300
  generate_time = time.time() - start_time
301
 
302
+ videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
 
303
  finish_time = time.time()
304
  render_time = finish_time - start_time - generate_time
305
+
306
+ state_0.normal_video, state_0.rgb_video = videos_0['normal'], videos_0['rgb']
307
+ state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
308
+ yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
 
 
309
 
310
  # logger.info(f"===output===: {output}")
311
  data_0 = {
 
349
  text,
350
  model_name_0, model_name_1,
351
  request: gr.Request):
352
+ if not text or text.strip()=="":
353
  raise gr.Warning("Prompt cannot be empty.")
354
  if state_0 is None:
355
  state_0 = State(model_name_0, i2s_mode=False, offline=False)
356
  if state_1 is None:
357
  state_1 = State(model_name_1, i2s_mode=False, offline=False)
358
 
359
+ text = text.strip()
360
  ip = get_ip(request)
361
  t2s_multi_logger.info(f"generate. ip: {ip}")
362
 
363
  state_0.model_name, state_1.model_name = model_name_0, model_name_1
364
  state_0.prompt, state_1.prompt = text, text
365
+ state_0.evaluted_dims, state_1.evaluted_dims = 0, 0
366
  try:
367
  idx = prompt_list.index(text)
368
  state_0.offline, state_1.offline = True, True
 
371
  state_0.offline, state_1.offline = False, False
372
  state_0.offline_idx, state_1.offline_idx = None, None
373
 
374
+ if state_0.offline and state_0.offline_idx:
375
  start_time = time.time()
376
+ videos_0, videos_1, model_name_0, model_name_1 = gen_func(text, model_name_0, model_name_1,
377
+ i2s_model=False, offline=state_0.offline, offline_idx=state_0.offline_idx)
378
+ # normal_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4")
379
+ # rgb_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4")
380
+ # normal_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4")
381
+ # rgb_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4")
382
 
383
+ state_0.model_name, state_1.model_name = model_name_0, model_name_1
384
+ state_0.normal_video, state_0.rgb_video = videos_0['normal'], videos_0['rgb']
385
+ state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
386
+ yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
 
387
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
388
 
389
  # logger.info(f"===output===: {output}")
 
405
  }
406
  else:
407
  start_time = time.time()
408
+ shape_0, shape_1, model_name_0, model_name_1 = gen_func(text, model_name_0, model_name_1)
409
  generate_time = time.time() - start_time
410
 
411
+ videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
 
412
  finish_time = time.time()
413
  render_time = finish_time - start_time - generate_time
414
+
415
+ state_0.model_name, state_1.model_name = model_name_0, model_name_1
416
+ state_0.normal_video, state_0.rgb_video = videos_0['normal'], videos_0['rgb']
417
+ state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
418
+ yield state_0, state_1, videos_0[0], videos_0[1], videos_1[0], videos_1[1], \
 
419
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
420
 
421
  # logger.info(f"===output===: {output}")
 
457
 
458
 
459
  def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Request):
460
+ if image is None:
461
  raise gr.Warning("Image cannot be empty.")
462
  if not model_name:
463
  raise gr.Warning("Model name cannot be empty.")
 
469
 
470
  state.model_name = model_name
471
  state.image = image
472
+ state.evaluted_dims = 0
473
 
474
+ if state.offline and state.offline_idx:
475
  start_time = time.time()
476
+ videos = gen_func(image, model_name, offline=state.offline, offline_idx=state.offline_idx)
477
+ # normal_video = os.path.join(OFFLINE_DIR, "image2shape", model_name, "normal", f"{state.offline_idx}.mp4")
478
+ # rgb_video = os.path.join(OFFLINE_DIR, "image2shape", model_name, "rgb", f"{state.offline_idx}.mp4")
479
 
480
+ state.normal_video = videos['normal']
481
+ state.rgb_video = videos['rgb']
482
+ yield state, videos['normal'], videos['rgb']
483
 
484
  # logger.info(f"===output===: {output}")
485
  data = {
 
495
  shape = gen_func(image, model_name)
496
  generate_time = time.time() - start_time
497
 
498
+ videos = render_func(shape, model_name)
499
  finish_time = time.time()
500
  render_time = finish_time - start_time - generate_time
501
 
502
+ state.normal_video = videos['normal']
503
+ state.rgb_video = videos['rgb']
504
+ yield state, videos['normal'], videos['rgb']
505
 
506
  # logger.info(f"===output===: {output}")
507
  data = {
 
535
  image,
536
  model_name_0, model_name_1,
537
  request: gr.Request):
538
+ if image is None:
539
  raise gr.Warning("Image cannot be empty.")
540
  if not model_name_0:
541
  raise gr.Warning("Model name A cannot be empty.")
 
552
 
553
  state_0.model_name, state_1.model_name = model_name_0, model_name_1
554
  state_0.image, state_1.image = image, image
555
+ state_0.evaluted_dims, state_1.evaluted_dims = 0, 0
556
 
557
+ if state_0.offline and state_0.offline_idx:
 
558
  start_time = time.time()
559
+ videos_0, videos_1 = gen_func(image, model_name_0, model_name_1, offline=state_0.offline, offline_idx=state_0.offline_idx)
560
+ # normal_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4")
561
+ # rgb_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4")
562
+ # normal_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4")
563
+ # rgb_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4")
564
 
565
+ state_0.normal_video, state_0.rgb_video = videos_0['normal'], videos_0['rgb']
566
+ state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
567
+ yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
 
 
568
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
569
 
570
  # logger.info(f"===output===: {output}")
 
589
  shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
590
  generate_time = time.time() - start_time
591
 
592
+ videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
 
593
  finish_time = time.time()
594
  render_time = finish_time - start_time - generate_time
595
 
596
+ state_0.normal_video, state_0.rgb_video = videos_0['normal'], videos_0['rgb']
597
+ state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
598
+ yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
 
 
599
 
600
  # logger.info(f"===output===: {output}")
601
  data_0 = {
 
639
  # save_image_file_on_log_server(output_file)
640
 
641
 
642
+ def generate_i2s_multi_annoy(gen_func, render_func,
643
  state_0, state_1,
644
  image,
645
  model_name_0, model_name_1,
646
  request: gr.Request):
647
+ if image is None:
648
  raise gr.Warning("Image cannot be empty.")
649
  if state_0 is None:
650
  state_0 = State(model_name_0, i2s_mode=True, offline=False)
 
656
 
657
  state_0.model_name, state_1.model_name = model_name_0, model_name_1
658
  state_0.image, state_1.image = image, image
659
+ state_0.evaluted_dims, state_1.evaluted_dims = 0, 0
660
 
661
+ if state_0.offline and state_0.offline_idx and state_1.offline and state_1.offline_idx:
 
662
  start_time = time.time()
663
+ videos_0, videos_1, model_name_0, model_name_1 = gen_func(image, model_name_0, model_name_1,
664
+ i2s_model=True, offline=state_0.offline, offline_idx=state_0.offline_idx)
665
+ # normal_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4")
666
+ # rgb_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4")
667
+ # normal_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4")
668
+ # rgb_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4")
669
+ print(state_0.dict())
670
+ print(state_1.dict())
671
+ print(videos_0)
672
+ print(videos_1)
673
+ state_0.model_name, state_1.model_name = model_name_0, model_name_1
674
+ state_0.normal_video, state_0.rgb_video = videos_0['normal'], videos_0['rgb']
675
+ state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
676
+ yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
677
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
678
 
679
  # logger.info(f"===output===: {output}")
 
698
  shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
699
  generate_time = time.time() - start_time
700
 
701
+ videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
 
702
  finish_time = time.time()
703
  render_time = finish_time - start_time - generate_time
704
 
705
+ state_0.normal_video, state_0.rgb_video = videos_0['normal'], videos_0['rgb']
706
+ state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
707
+ yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
 
 
708
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
709
 
710
  # logger.info(f"===output===: {output}")
serve/log_utils.py CHANGED
@@ -14,7 +14,7 @@ from pathlib import Path
14
 
15
  import requests
16
 
17
- from .constants import LOGDIR, LOG_SERVER_ADDR, SAVE_LOG
18
  from .utils import save_log_str_on_log_server
19
 
20
 
 
14
 
15
  import requests
16
 
17
+ from constants import LOGDIR, LOG_SERVER_ADDR, SAVE_LOG
18
  from .utils import save_log_str_on_log_server
19
 
20
 
serve/utils.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
7
  import gradio as gr
8
  from pathlib import Path
9
  from model.model_registry import *
10
- from .constants import LOGDIR, LOG_SERVER_ADDR, APPEND_JSON, SAVE_IMAGE, SAVE_LOG
11
  from typing import Union
12
 
13
 
@@ -118,6 +118,29 @@ def enable_buttons():
118
  def disable_buttons():
119
  return tuple(gr.update(interactive=False) for _ in range(11))
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  def clear_t2s_history():
122
  return None, "", None, None
123
 
 
7
  import gradio as gr
8
  from pathlib import Path
9
  from model.model_registry import *
10
+ from constants import LOGDIR, LOG_SERVER_ADDR, APPEND_JSON, SAVE_IMAGE, SAVE_LOG
11
  from typing import Union
12
 
13
 
 
118
  def disable_buttons():
119
  return tuple(gr.update(interactive=False) for _ in range(11))
120
 
121
+ def reset_state(state):
122
+ state.normal_video, state.rgb_video = None, None
123
+ state.evaluted_dims = 0
124
+ return (state,) + tuple(gr.update(interactive=False) for _ in range(11))
125
+
126
+ def reset_states_side_by_side(state_0, state_1):
127
+ state_0.normal_video, state_0.rgb_video = None, None
128
+ state_1.normal_video, state_1.rgb_video = None, None
129
+ state_0.evaluted_dims, state_1.evaluted_dims = 0, 0
130
+ return (state_0, state_1) \
131
+ + tuple(gr.update(visible=i>=12, interactive=False) for i in range(14)) \
132
+ + tuple(gr.update(visible=False) for _ in range(3))
133
+
134
+ def reset_states_side_by_side_anony(state_0, state_1):
135
+ state_0.model_name, state_1.model_name = "", ""
136
+ state_0.normal_video, state_0.rgb_video = None, None
137
+ state_1.normal_video, state_1.rgb_video = None, None
138
+ state_0.evaluted_dims, state_1.evaluted_dims = 0, 0
139
+ return (state_0, state_1) \
140
+ + (gr.Markdown("", visible=False), gr.Markdown("", visible=False))\
141
+ + tuple(gr.update(visible=i>=12, interactive=False) for i in range(14)) \
142
+ + tuple(gr.update(visible=False) for _ in range(3))
143
+
144
  def clear_t2s_history():
145
  return None, "", None, None
146
 
serve/vote_utils.py CHANGED
@@ -6,17 +6,18 @@ import gradio as gr
6
  from pathlib import Path
7
  from .utils import *
8
  from .log_utils import build_logger
9
- from .constants import IMAGE_DIR
10
 
11
  t2s_logger = build_logger("gradio_web_server_text2shape", "gr_web_text2shape.log") # t2s = image generation, loggers for single model direct chat
12
  t2s_multi_logger = build_logger("gradio_web_server_text2shape_multi", "gr_web_text2shape_multi.log") # t2s_multi = image generation multi, loggers for side-by-side and battle
13
  i2s_logger = build_logger("gradio_web_server_image2shape", "gr_web_image2shape.log") # i2s = image editing, loggers for single model direct chat
14
  i2s_multi_logger = build_logger("gradio_web_server_image2shape_multi", "gr_web_image2shape_multi.log") # i2s_multi = image editing multi, loggers for side-by-side and battle
15
 
16
- def vote_last_response_t2s(state, vote_type, model_selector, request: gr.Request):
17
  with open(get_conv_log_filename(), "a") as fout:
18
  data = {
19
  "tstamp": round(time.time(), 4),
 
20
  "type": vote_type,
21
  "model": model_selector,
22
  "state": state.dict(),
@@ -25,10 +26,11 @@ def vote_last_response_t2s(state, vote_type, model_selector, request: gr.Request
25
  fout.write(json.dumps(data) + "\n")
26
  append_json_item_on_log_server(data, get_conv_log_filename())
27
 
28
- def vote_last_response_t2s_multi(states, vote_type, model_selectors, request: gr.Request):
29
  with open(get_conv_log_filename(), "a") as fout:
30
  data = {
31
  "tstamp": round(time.time(), 4),
 
32
  "type": vote_type,
33
  "models": [x for x in model_selectors],
34
  "states": [x.dict() for x in states],
@@ -42,10 +44,11 @@ def vote_last_response_t2s_multi(states, vote_type, model_selectors, request: gr
42
  # state.output.save(f, 'PNG')
43
  # save_image_file_on_log_server(output_file)
44
 
45
- def vote_last_response_i2s(state, vote_type, model_selector, request: gr.Request):
46
  with open(get_conv_log_filename(), "a") as fout:
47
  data = {
48
  "tstamp": round(time.time(), 4),
 
49
  "type": vote_type,
50
  "model": model_selector,
51
  "state": state.dict(),
@@ -62,10 +65,11 @@ def vote_last_response_i2s(state, vote_type, model_selector, request: gr.Request
62
  # save_image_file_on_log_server(output_file)
63
  # save_image_file_on_log_server(source_file)
64
 
65
- def vote_last_response_i2s_multi(states, vote_type, model_selectors, request: gr.Request):
66
  with open(get_conv_log_filename(), "a") as fout:
67
  data = {
68
  "tstamp": round(time.time(), 4),
 
69
  "type": vote_type,
70
  "models": [x for x in model_selectors],
71
  "states": [x.dict() for x in states],
@@ -85,209 +89,274 @@ def vote_last_response_i2s_multi(states, vote_type, model_selectors, request: gr
85
 
86
 
87
  ## Text-to-Shape Generation (t2s) Single Model Direct Chat
88
- def upvote_last_response_t2s(state, model_selector, request: gr.Request):
89
  ip = get_ip(request)
90
- t2s_logger.info(f"upvote. ip: {ip}")
91
- vote_last_response_t2s(state, "upvote", model_selector, request)
92
- return ("",) + (disable_btn,) * 3
 
93
 
94
- def downvote_last_response_t2s(state, model_selector, request: gr.Request):
95
  ip = get_ip(request)
96
- t2s_logger.info(f"downvote. ip: {ip}")
97
- vote_last_response_t2s(state, "downvote", model_selector, request)
98
- return ("",) + (disable_btn,) * 3
 
99
 
100
- def flag_last_response_t2s(state, model_selector, request: gr.Request):
101
  ip = get_ip(request)
102
- t2s_logger.info(f"flag. ip: {ip}")
103
- vote_last_response_t2s(state, "flag", model_selector, request)
104
- return ("",) + (disable_btn,) * 3
 
105
 
106
 
107
  ## Text-to-Shape Generation Multi (t2s_multi) Side-by-Side and Battle
108
  def leftvote_last_response_t2s_named(
109
- state0, state1, model_selector0, model_selector1, request: gr.Request
110
  ):
111
- t2s_multi_logger.info(f"leftvote (named). ip: {get_ip(request)}")
112
  vote_last_response_t2s_multi(
113
- [state0, state1], "leftvote", [model_selector0, model_selector1], request
114
  )
115
- return ("",) + (disable_btn,) * 4
 
 
116
 
117
  def rightvote_last_response_t2s_named(
118
- state0, state1, model_selector0, model_selector1, request: gr.Request
119
  ):
120
- t2s_multi_logger.info(f"rightvote (named). ip: {get_ip(request)}")
121
  vote_last_response_t2s_multi(
122
- [state0, state1], "rightvote", [model_selector0, model_selector1], request
123
  )
124
- return ("",) + (disable_btn,) * 4
 
 
125
 
126
  def tievote_last_response_t2s_named(
127
- state0, state1, model_selector0, model_selector1, request: gr.Request
128
  ):
129
- t2s_multi_logger.info(f"tievote (named). ip: {get_ip(request)}")
130
  vote_last_response_t2s_multi(
131
- [state0, state1], "tievote", [model_selector0, model_selector1], request
132
  )
133
- return ("",) + (disable_btn,) * 4
 
 
134
 
135
  def bothbad_vote_last_response_t2s_named(
136
- state0, state1, model_selector0, model_selector1, request: gr.Request
137
  ):
138
- t2s_multi_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
139
  vote_last_response_t2s_multi(
140
- [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
141
  )
142
- return ("",) + (disable_btn,) * 4
 
 
143
 
144
 
145
  def leftvote_last_response_t2s_anony(
146
- state0, state1, model_selector0, model_selector1, request: gr.Request
147
  ):
148
- t2s_multi_logger.info(f"leftvote (named). ip: {get_ip(request)}")
149
  vote_last_response_t2s_multi(
150
- [state0, state1], "leftvote", [model_selector0, model_selector1], request
151
  )
152
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
153
- return ("",) + (disable_btn,) * 4 + names
 
 
 
 
 
 
154
 
155
  def rightvote_last_response_t2s_anony(
156
- state0, state1, model_selector0, model_selector1, request: gr.Request
157
  ):
158
- t2s_multi_logger.info(f"rightvote (named). ip: {get_ip(request)}")
159
  vote_last_response_t2s_multi(
160
- [state0, state1], "rightvote", [model_selector0, model_selector1], request
161
  )
162
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
163
- return ("",) + (disable_btn,) * 4 + names
 
 
 
 
 
 
164
 
165
  def tievote_last_response_t2s_anony(
166
- state0, state1, model_selector0, model_selector1, request: gr.Request
167
  ):
168
- t2s_multi_logger.info(f"tievote (named). ip: {get_ip(request)}")
169
  vote_last_response_t2s_multi(
170
- [state0, state1], "tievote", [model_selector0, model_selector1], request
171
  )
172
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
173
- return ("",) + (disable_btn,) * 4 + names
174
 
 
 
 
 
 
 
 
 
175
  def bothbad_vote_last_response_t2s_anony(
176
- state0, state1, model_selector0, model_selector1, request: gr.Request
177
  ):
178
- t2s_multi_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
179
  vote_last_response_t2s_multi(
180
- [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
181
  )
182
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
183
- return ("",) + (disable_btn,) * 4 + names
 
 
 
 
 
 
184
 
185
  ## Image-to-Shape (i2s) Single Model Direct Chat
186
- def upvote_last_response_i2s(state, model_selector, request: gr.Request):
187
  ip = get_ip(request)
188
- i2s_logger.info(f"upvote. ip: {ip}")
189
- vote_last_response_i2s(state, "upvote", model_selector, request)
190
- return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3
 
191
 
192
- def downvote_last_response_i2s(state, model_selector, request: gr.Request):
193
  ip = get_ip(request)
194
- i2s_logger.info(f"downvote. ip: {ip}")
195
- vote_last_response_i2s(state, "downvote", model_selector, request)
196
- return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3
 
197
 
198
- def flag_last_response_i2s(state, model_selector, request: gr.Request):
199
  ip = get_ip(request)
200
- i2s_logger.info(f"flag. ip: {ip}")
201
- vote_last_response_i2s(state, "flag", model_selector, request)
202
- return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3
 
203
 
204
 
205
  ## Image-to-Shape Multi (i2s_multi) Side-by-Side and Battle
206
  def leftvote_last_response_i2s_named(
207
- state0, state1, model_selector0, model_selector1, request: gr.Request
208
  ):
209
- i2s_multi_logger.info(f"leftvote (anony). ip: {get_ip(request)}")
210
  vote_last_response_i2s_multi(
211
- [state0, state1], "leftvote", [model_selector0, model_selector1], request
212
  )
213
- return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4
 
 
214
 
215
  def rightvote_last_response_i2s_named(
216
- state0, state1, model_selector0, model_selector1, request: gr.Request
217
  ):
218
- i2s_multi_logger.info(f"rightvote (anony). ip: {get_ip(request)}")
219
  vote_last_response_i2s_multi(
220
- [state0, state1], "rightvote", [model_selector0, model_selector1], request
221
  )
222
- return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4
 
 
223
 
224
  def tievote_last_response_i2s_named(
225
- state0, state1, model_selector0, model_selector1, request: gr.Request
226
  ):
227
- i2s_multi_logger.info(f"tievote (anony). ip: {get_ip(request)}")
228
  vote_last_response_i2s_multi(
229
- [state0, state1], "tievote", [model_selector0, model_selector1], request
230
  )
231
- return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4
 
 
232
 
233
  def bothbad_vote_last_response_i2s_named(
234
- state0, state1, model_selector0, model_selector1, request: gr.Request
235
  ):
236
- i2s_multi_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
237
  vote_last_response_i2s_multi(
238
- [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
239
  )
240
- return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4
 
 
241
 
242
 
243
  def leftvote_last_response_i2s_anony(
244
- state0, state1, model_selector0, model_selector1, request: gr.Request
245
  ):
246
- i2s_multi_logger.info(f"leftvote (anony). ip: {get_ip(request)}")
247
  vote_last_response_i2s_multi(
248
- [state0, state1], "leftvote", [model_selector0, model_selector1], request
249
  )
250
- # names = (
251
- # "### Model A: " + state0.model_name,
252
- # "### Model B: " + state1.model_name,
253
- # )
254
- # names = (state0.model_name, state1.model_name)
255
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
256
- return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4 + names
 
257
 
 
258
  def rightvote_last_response_i2s_anony(
259
- state0, state1, model_selector0, model_selector1, request: gr.Request
260
  ):
261
- i2s_multi_logger.info(f"rightvote (anony). ip: {get_ip(request)}")
262
  vote_last_response_i2s_multi(
263
- [state0, state1], "rightvote", [model_selector0, model_selector1], request
264
  )
265
- # names = (
266
- # "### Model A: " + state0.model_name,
267
- # "### Model B: " + state1.model_name,
268
- # )
269
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
270
- return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4 + names
 
 
 
271
 
272
  def tievote_last_response_i2s_anony(
273
- state0, state1, model_selector0, model_selector1, request: gr.Request
274
  ):
275
- i2s_multi_logger.info(f"tievote (anony). ip: {get_ip(request)}")
276
  vote_last_response_i2s_multi(
277
- [state0, state1], "tievote", [model_selector0, model_selector1], request
278
  )
279
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
280
- return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4 + names
 
 
 
 
 
 
 
281
 
282
  def bothbad_vote_last_response_i2s_anony(
283
- state0, state1, model_selector0, model_selector1, request: gr.Request
284
  ):
285
- i2s_multi_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
286
  vote_last_response_i2s_multi(
287
- [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
288
  )
289
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
290
- return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4 + names
 
 
 
 
 
 
 
291
 
292
 
293
  share_js = """
 
6
  from pathlib import Path
7
  from .utils import *
8
  from .log_utils import build_logger
9
+ from constants import IMAGE_DIR
10
 
11
  t2s_logger = build_logger("gradio_web_server_text2shape", "gr_web_text2shape.log") # t2s = image generation, loggers for single model direct chat
12
  t2s_multi_logger = build_logger("gradio_web_server_text2shape_multi", "gr_web_text2shape_multi.log") # t2s_multi = image generation multi, loggers for side-by-side and battle
13
  i2s_logger = build_logger("gradio_web_server_image2shape", "gr_web_image2shape.log") # i2s = image editing, loggers for single model direct chat
14
  i2s_multi_logger = build_logger("gradio_web_server_image2shape_multi", "gr_web_image2shape_multi.log") # i2s_multi = image editing multi, loggers for side-by-side and battle
15
 
16
+ def vote_last_response_t2s(state, dim, vote_type, model_selector, request: gr.Request):
17
  with open(get_conv_log_filename(), "a") as fout:
18
  data = {
19
  "tstamp": round(time.time(), 4),
20
+ "dim": dim,
21
  "type": vote_type,
22
  "model": model_selector,
23
  "state": state.dict(),
 
26
  fout.write(json.dumps(data) + "\n")
27
  append_json_item_on_log_server(data, get_conv_log_filename())
28
 
29
+ def vote_last_response_t2s_multi(states, dim, vote_type, model_selectors, request: gr.Request):
30
  with open(get_conv_log_filename(), "a") as fout:
31
  data = {
32
  "tstamp": round(time.time(), 4),
33
+ "dim": dim,
34
  "type": vote_type,
35
  "models": [x for x in model_selectors],
36
  "states": [x.dict() for x in states],
 
44
  # state.output.save(f, 'PNG')
45
  # save_image_file_on_log_server(output_file)
46
 
47
+ def vote_last_response_i2s(state, dim, vote_type, model_selector, request: gr.Request):
48
  with open(get_conv_log_filename(), "a") as fout:
49
  data = {
50
  "tstamp": round(time.time(), 4),
51
+ "dim": dim,
52
  "type": vote_type,
53
  "model": model_selector,
54
  "state": state.dict(),
 
65
  # save_image_file_on_log_server(output_file)
66
  # save_image_file_on_log_server(source_file)
67
 
68
+ def vote_last_response_i2s_multi(states, dim, vote_type, model_selectors, request: gr.Request):
69
  with open(get_conv_log_filename(), "a") as fout:
70
  data = {
71
  "tstamp": round(time.time(), 4),
72
+ "dim": dim,
73
  "type": vote_type,
74
  "models": [x for x in model_selectors],
75
  "states": [x.dict() for x in states],
 
89
 
90
 
91
  ## Text-to-Shape Generation (t2s) Single Model Direct Chat
92
+ def upvote_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
93
  ip = get_ip(request)
94
+ t2s_logger.info(f"upvote [{dim_md}]. ip: {ip}")
95
+ vote_last_response_t2s(state, dim_md, "upvote", model_selector, request)
96
+ state.evaluted_dims += 1
97
+ return (state,) + (disable_btn,) * 3
98
 
99
+ def downvote_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
100
  ip = get_ip(request)
101
+ t2s_logger.info(f"downvote [{dim_md}]. ip: {ip}")
102
+ vote_last_response_t2s(state, dim_md, "downvote", model_selector, request)
103
+ state.evaluted_dims += 1
104
+ return (state,) + (disable_btn,) * 3
105
 
106
+ def flag_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
107
  ip = get_ip(request)
108
+ t2s_logger.info(f"flag [{dim_md}]. ip: {ip}")
109
+ vote_last_response_t2s(state, dim_md, "flag", model_selector, request)
110
+ state.evaluted_dims += 1
111
+ return (state,) + (disable_btn,) * 3
112
 
113
 
114
  ## Text-to-Shape Generation Multi (t2s_multi) Side-by-Side and Battle
115
  def leftvote_last_response_t2s_named(
116
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
117
  ):
118
+ t2s_multi_logger.info(f"leftvote [{dim_md}] (named). ip: {get_ip(request)}")
119
  vote_last_response_t2s_multi(
120
+ [state0, state1], dim_md, "leftvote", [model_selector0, model_selector1], request
121
  )
122
+ state0.evaluted_dims += 1
123
+ state1.evaluted_dims += 1
124
+ return (state0, state1) + (disable_btn,) * 4
125
 
126
  def rightvote_last_response_t2s_named(
127
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
128
  ):
129
+ t2s_multi_logger.info(f"rightvote [{dim_md}] (named). ip: {get_ip(request)}")
130
  vote_last_response_t2s_multi(
131
+ [state0, state1], dim_md, "rightvote", [model_selector0, model_selector1], request
132
  )
133
+ state0.evaluted_dims += 1
134
+ state1.evaluted_dims += 1
135
+ return (state0, state1) + (disable_btn,) * 4
136
 
137
  def tievote_last_response_t2s_named(
138
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
139
  ):
140
+ t2s_multi_logger.info(f"tievote [{dim_md}] (named). ip: {get_ip(request)}")
141
  vote_last_response_t2s_multi(
142
+ [state0, state1], dim_md, "tievote", [model_selector0, model_selector1], request
143
  )
144
+ state0.evaluted_dims += 1
145
+ state1.evaluted_dims += 1
146
+ return (state0, state1) + (disable_btn,) * 4
147
 
148
  def bothbad_vote_last_response_t2s_named(
149
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
150
  ):
151
+ t2s_multi_logger.info(f"bothbad_vote [{dim_md}] (named). ip: {get_ip(request)}")
152
  vote_last_response_t2s_multi(
153
+ [state0, state1], dim_md, "bothbad_vote", [model_selector0, model_selector1], request
154
  )
155
+ state0.evaluted_dims += 1
156
+ state1.evaluted_dims += 1
157
+ return (state0, state1) + (disable_btn,) * 4
158
 
159
 
160
  def leftvote_last_response_t2s_anony(
161
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
162
  ):
163
+ t2s_multi_logger.info(f"leftvote [{dim_md}] (anony). ip: {get_ip(request)}")
164
  vote_last_response_t2s_multi(
165
+ [state0, state1], dim_md, "leftvote", [model_selector0, model_selector1], request
166
  )
167
+
168
+ state0.evaluted_dims += 1
169
+ state1.evaluted_dims += 1
170
+ if state0.evaluted_dims == state1.evaluted_dims == 3:
171
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
172
+ return (state0, state1) + (disable_btn,) * 4 + names
173
+ else:
174
+ return (state0, state1) + (disable_btn,) * 4 + ("", "")
175
 
176
  def rightvote_last_response_t2s_anony(
177
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
178
  ):
179
+ t2s_multi_logger.info(f"rightvote [{dim_md}] (anony). ip: {get_ip(request)}")
180
  vote_last_response_t2s_multi(
181
+ [state0, state1], dim_md, "rightvote", [model_selector0, model_selector1], request
182
  )
183
+
184
+ state0.evaluted_dims += 1
185
+ state1.evaluted_dims += 1
186
+ if state0.evaluted_dims == state1.evaluted_dims == 3:
187
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
188
+ return (state0, state1) + (disable_btn,) * 4 + names
189
+ else:
190
+ return (state0, state1) + (disable_btn,) * 4 + ("", "")
191
 
192
  def tievote_last_response_t2s_anony(
193
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
194
  ):
195
+ t2s_multi_logger.info(f"tievote [{dim_md}] (anony). ip: {get_ip(request)}")
196
  vote_last_response_t2s_multi(
197
+ [state0, state1], dim_md, "tievote", [model_selector0, model_selector1], request
198
  )
 
 
199
 
200
+ state0.evaluted_dims += 1
201
+ state1.evaluted_dims += 1
202
+ if state0.evaluted_dims == state1.evaluted_dims == 3:
203
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
204
+ return (state0, state1) + (disable_btn,) * 4 + names
205
+ else:
206
+ return (state0, state1) + (disable_btn,) * 4 + ("", "")
207
+
208
  def bothbad_vote_last_response_t2s_anony(
209
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
210
  ):
211
+ t2s_multi_logger.info(f"bothbad_vote [{dim_md}] (anony). ip: {get_ip(request)}")
212
  vote_last_response_t2s_multi(
213
+ [state0, state1], dim_md, "bothbad_vote", [model_selector0, model_selector1], request
214
  )
215
+
216
+ state0.evaluted_dims += 1
217
+ state1.evaluted_dims += 1
218
+ if state0.evaluted_dims == state1.evaluted_dims == 3:
219
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
220
+ return (state0, state1) + (disable_btn,) * 4 + names
221
+ else:
222
+ return (state0, state1) + (disable_btn,) * 4 + ("", "")
223
 
224
  ## Image-to-Shape (i2s) Single Model Direct Chat
225
+ def upvote_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
226
  ip = get_ip(request)
227
+ i2s_logger.info(f"upvote [{dim_md}]. ip: {ip}")
228
+ vote_last_response_i2s(state, dim_md, "upvote", model_selector, request)
229
+ state.evaluted_dims += 1
230
+ return (state,) + (disable_btn,) * 3
231
 
232
+ def downvote_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
233
  ip = get_ip(request)
234
+ i2s_logger.info(f"downvote [{dim_md}]. ip: {ip}")
235
+ vote_last_response_i2s(state, dim_md, "downvote", model_selector, request)
236
+ state.evaluted_dims += 1
237
+ return (state,) + (disable_btn,) * 3
238
 
239
+ def flag_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
240
  ip = get_ip(request)
241
+ i2s_logger.info(f"flag [{dim_md}]. ip: {ip}")
242
+ vote_last_response_i2s(state, dim_md, "flag", model_selector, request)
243
+ state.evaluted_dims += 1
244
+ return (state,) + (disable_btn,) * 3
245
 
246
 
247
  ## Image-to-Shape Multi (i2s_multi) Side-by-Side and Battle
248
  def leftvote_last_response_i2s_named(
249
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
250
  ):
251
+ i2s_multi_logger.info(f"leftvote [{dim_md}] (named). ip: {get_ip(request)}")
252
  vote_last_response_i2s_multi(
253
+ [state0, state1], dim_md, "leftvote", [model_selector0, model_selector1], request
254
  )
255
+ state0.evaluted_dims += 1
256
+ state1.evaluted_dims += 1
257
+ return (state0, state1) + (disable_btn,) * 4
258
 
259
  def rightvote_last_response_i2s_named(
260
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
261
  ):
262
+ i2s_multi_logger.info(f"rightvote [{dim_md}] (named). ip: {get_ip(request)}")
263
  vote_last_response_i2s_multi(
264
+ [state0, state1], dim_md, "rightvote", [model_selector0, model_selector1], request
265
  )
266
+ state0.evaluted_dims += 1
267
+ state1.evaluted_dims += 1
268
+ return (state0, state1) + (disable_btn,) * 4
269
 
270
  def tievote_last_response_i2s_named(
271
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
272
  ):
273
+ i2s_multi_logger.info(f"tievote [{dim_md}] (named). ip: {get_ip(request)}")
274
  vote_last_response_i2s_multi(
275
+ [state0, state1], dim_md, "tievote", [model_selector0, model_selector1], request
276
  )
277
+ state0.evaluted_dims += 1
278
+ state1.evaluted_dims += 1
279
+ return (state0, state1) + (disable_btn,) * 4
280
 
281
  def bothbad_vote_last_response_i2s_named(
282
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
283
  ):
284
+ i2s_multi_logger.info(f"bothbad_vote [{dim_md}] (named). ip: {get_ip(request)}")
285
  vote_last_response_i2s_multi(
286
+ [state0, state1], dim_md, "bothbad_vote", [model_selector0, model_selector1], request
287
  )
288
+ state0.evaluted_dims += 1
289
+ state1.evaluted_dims += 1
290
+ return (state0, state1) + (disable_btn,) * 4
291
 
292
 
293
  def leftvote_last_response_i2s_anony(
294
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
295
  ):
296
+ i2s_multi_logger.info(f"leftvote [{dim_md}] (anony). ip: {get_ip(request)}")
297
  vote_last_response_i2s_multi(
298
+ [state0, state1], dim_md, "leftvote", [model_selector0, model_selector1], request
299
  )
300
+
301
+ state0.evaluted_dims += 1
302
+ state1.evaluted_dims += 1
303
+ if state0.evaluted_dims == state1.evaluted_dims == 3:
304
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
305
+ return (state0, state1) + (disable_btn,) * 4 + names
306
+ else:
307
+ return (state0, state1) + (disable_btn,) * 4 + ("", "")
308
 
309
+
310
  def rightvote_last_response_i2s_anony(
311
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
312
  ):
313
+ i2s_multi_logger.info(f"rightvote [{dim_md}] (anony). ip: {get_ip(request)}")
314
  vote_last_response_i2s_multi(
315
+ [state0, state1], dim_md, "rightvote", [model_selector0, model_selector1], request
316
  )
317
+
318
+ state0.evaluted_dims += 1
319
+ state1.evaluted_dims += 1
320
+ if state0.evaluted_dims == state1.evaluted_dims == 3:
321
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
322
+ return (state0, state1) + (disable_btn,) * 4 + names
323
+ else:
324
+ return (state0, state1) + (disable_btn,) * 4 + ("", "")
325
+
326
 
327
  def tievote_last_response_i2s_anony(
328
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
329
  ):
330
+ i2s_multi_logger.info(f"tievote [{dim_md}] (anony). ip: {get_ip(request)}")
331
  vote_last_response_i2s_multi(
332
+ [state0, state1], dim_md, "tievote", [model_selector0, model_selector1], request
333
  )
334
+
335
+ state0.evaluted_dims += 1
336
+ state1.evaluted_dims += 1
337
+ if state0.evaluted_dims == state1.evaluted_dims == 3:
338
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
339
+ return (state0, state1) + (disable_btn,) * 4 + names
340
+ else:
341
+ return (state0, state1) + (disable_btn,) * 4 + ("", "")
342
+
343
 
344
  def bothbad_vote_last_response_i2s_anony(
345
+ state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
346
  ):
347
+ i2s_multi_logger.info(f"bothbad_vote [{dim_md}] (anony). ip: {get_ip(request)}")
348
  vote_last_response_i2s_multi(
349
+ [state0, state1], dim_md, "bothbad_vote", [model_selector0, model_selector1], request
350
  )
351
+
352
+ state0.evaluted_dims += 1
353
+ state1.evaluted_dims += 1
354
+ if state0.evaluted_dims == state1.evaluted_dims == 3:
355
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
356
+ return (state0, state1) + (disable_btn,) * 4 + names
357
+ else:
358
+ return (state0, state1) + (disable_btn,) * 4 + ("", "")
359
+
360
 
361
 
362
  share_js = """