Bbmyy committed on
Commit
850b1ec
1 Parent(s): b42403b

Update space

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. model/__init__.py +0 -0
  2. model/__pycache__/__init__.cpython-310.pyc +0 -0
  3. model/__pycache__/__init__.cpython-39.pyc +0 -0
  4. model/__pycache__/matchmaker.cpython-310.pyc +0 -0
  5. model/__pycache__/model_manager.cpython-310.pyc +0 -0
  6. model/__pycache__/model_registry.cpython-310.pyc +0 -0
  7. model/__pycache__/model_registry.cpython-39.pyc +0 -0
  8. model/matchmaker.py +126 -0
  9. model/matchmaker_video.py +136 -0
  10. model/model_manager.py +187 -0
  11. model/model_registry.py +70 -0
  12. model/models/__init__.py +78 -0
  13. model/models/__pycache__/__init__.cpython-310.pyc +0 -0
  14. model/models/__pycache__/huggingface_models.cpython-310.pyc +0 -0
  15. model/models/__pycache__/openai_api_models.cpython-310.pyc +0 -0
  16. model/models/__pycache__/other_api_models.cpython-310.pyc +0 -0
  17. model/models/__pycache__/replicate_api_models.cpython-310.pyc +0 -0
  18. model/models/huggingface_models.py +59 -0
  19. model/models/openai_api_models.py +57 -0
  20. model/models/other_api_models.py +91 -0
  21. model/models/replicate_api_models.py +195 -0
  22. serve/Arial.ttf +0 -0
  23. serve/Ksort.py +411 -0
  24. serve/__init__.py +0 -0
  25. serve/__pycache__/Ksort.cpython-310.pyc +0 -0
  26. serve/__pycache__/Ksort.cpython-39.pyc +0 -0
  27. serve/__pycache__/__init__.cpython-310.pyc +0 -0
  28. serve/__pycache__/__init__.cpython-39.pyc +0 -0
  29. serve/__pycache__/constants.cpython-310.pyc +0 -0
  30. serve/__pycache__/constants.cpython-39.pyc +0 -0
  31. serve/__pycache__/gradio_web.cpython-310.pyc +0 -0
  32. serve/__pycache__/gradio_web.cpython-39.pyc +0 -0
  33. serve/__pycache__/gradio_web_bbox.cpython-310.pyc +0 -0
  34. serve/__pycache__/leaderboard.cpython-310.pyc +0 -0
  35. serve/__pycache__/log_utils.cpython-310.pyc +0 -0
  36. serve/__pycache__/log_utils.cpython-39.pyc +0 -0
  37. serve/__pycache__/update_skill.cpython-310.pyc +0 -0
  38. serve/__pycache__/upload.cpython-310.pyc +0 -0
  39. serve/__pycache__/upload.cpython-39.pyc +0 -0
  40. serve/__pycache__/utils.cpython-310.pyc +0 -0
  41. serve/__pycache__/utils.cpython-39.pyc +0 -0
  42. serve/__pycache__/vote_utils.cpython-310.pyc +0 -0
  43. serve/__pycache__/vote_utils.cpython-39.pyc +0 -0
  44. serve/button.css +24 -0
  45. serve/constants.py +63 -0
  46. serve/gradio_web.py +789 -0
  47. serve/gradio_web_bbox.py +492 -0
  48. serve/leaderboard.py +200 -0
  49. serve/log_server.py +86 -0
  50. serve/log_utils.py +142 -0
model/__init__.py ADDED
File without changes
model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (157 Bytes). View file
 
model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (146 Bytes). View file
 
model/__pycache__/matchmaker.cpython-310.pyc ADDED
Binary file (4.04 kB). View file
 
model/__pycache__/model_manager.cpython-310.pyc ADDED
Binary file (8.28 kB). View file
 
model/__pycache__/model_registry.cpython-310.pyc ADDED
Binary file (1.64 kB). View file
 
model/__pycache__/model_registry.cpython-39.pyc ADDED
Binary file (1.81 kB). View file
 
model/matchmaker.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+ from trueskill import TrueSkill
4
+ import paramiko
5
+ import io, os
6
+ import sys
7
+ import random
8
+
9
+ sys.path.append('../')
10
+ from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL
11
+ trueskill_env = TrueSkill()
12
+
13
+ ssh_matchmaker_client = None
14
+ sftp_matchmaker_client = None
15
+
16
+ def create_ssh_matchmaker_client(server, port, user, password):
17
+ global ssh_matchmaker_client, sftp_matchmaker_client
18
+ ssh_matchmaker_client = paramiko.SSHClient()
19
+ ssh_matchmaker_client.load_system_host_keys()
20
+ ssh_matchmaker_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
21
+ ssh_matchmaker_client.connect(server, port, user, password)
22
+
23
+ transport = ssh_matchmaker_client.get_transport()
24
+ transport.set_keepalive(60)
25
+
26
+ sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
27
+
28
+
29
+ def is_connected():
30
+ global ssh_matchmaker_client, sftp_matchmaker_client
31
+ if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
32
+ return False
33
+ if not ssh_matchmaker_client.get_transport().is_active():
34
+ return False
35
+ try:
36
+ sftp_matchmaker_client.listdir('.')
37
+ except Exception as e:
38
+ print(f"Error checking SFTP connection: {e}")
39
+ return False
40
+ return True
41
+
42
+
43
+ def ucb_score(trueskill_diff, t, n):
44
+ exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
45
+ ucb = -trueskill_diff + 1.0 * exploration_term
46
+ return ucb
47
+
48
+
49
+ def update_trueskill(ratings, ranks):
50
+ new_ratings = trueskill_env.rate(ratings, ranks)
51
+ return new_ratings
52
+
53
+
54
+ def serialize_rating(rating):
55
+ return {'mu': rating.mu, 'sigma': rating.sigma}
56
+
57
+
58
+ def deserialize_rating(rating_dict):
59
+ return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
60
+
61
+
62
+ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
63
+ global sftp_matchmaker_client
64
+ if not is_connected():
65
+ create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
66
+ data = {
67
+ 'ratings': [serialize_rating(r) for r in ratings],
68
+ 'comparison_counts': comparison_counts.tolist(),
69
+ 'total_comparisons': total_comparisons
70
+ }
71
+ json_data = json.dumps(data)
72
+ with sftp_matchmaker_client.open(SSH_SKILL, 'w') as f:
73
+ f.write(json_data)
74
+
75
+
76
+ def load_json_via_sftp():
77
+ global sftp_matchmaker_client
78
+ if not is_connected():
79
+ create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
80
+ with sftp_matchmaker_client.open(SSH_SKILL, 'r') as f:
81
+ data = json.load(f)
82
+ ratings = [deserialize_rating(r) for r in data['ratings']]
83
+ comparison_counts = np.array(data['comparison_counts'])
84
+ total_comparisons = data['total_comparisons']
85
+ return ratings, comparison_counts, total_comparisons
86
+
87
+
88
+ class RunningPivot(object):
89
+ running_pivot = []
90
+
91
+
92
+ def matchmaker(num_players, k_group=4, not_run=[]):
93
+ trueskill_env = TrueSkill()
94
+
95
+ ratings, comparison_counts, total_comparisons = load_json_via_sftp()
96
+
97
+ ratings = ratings[:num_players]
98
+ comparison_counts = comparison_counts[:num_players, :num_players]
99
+
100
+ # Randomly select a player
101
+ # selected_player = np.random.randint(0, num_players)
102
+ comparison_counts[RunningPivot.running_pivot, :] = float('inf')
103
+ comparison_counts[not_run, :] = float('inf')
104
+ selected_player = np.argmin(comparison_counts.sum(axis=1))
105
+
106
+ RunningPivot.running_pivot.append(selected_player)
107
+ RunningPivot.running_pivot = RunningPivot.running_pivot[-5:]
108
+ print(RunningPivot.running_pivot)
109
+
110
+ selected_trueskill_score = trueskill_env.expose(ratings[selected_player])
111
+ trueskill_scores = np.array([trueskill_env.expose(p) for p in ratings])
112
+ trueskill_diff = np.abs(trueskill_scores - selected_trueskill_score)
113
+ n = comparison_counts[selected_player]
114
+ ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
115
+
116
+ # Exclude self, select opponent with highest UCB score
117
+ ucb_scores[selected_player] = -float('inf')
118
+ ucb_scores[not_run] = -float('inf')
119
+ opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist()
120
+
121
+ # Group players
122
+ model_ids = [selected_player] + opponents
123
+
124
+ random.shuffle(model_ids)
125
+
126
+ return model_ids
model/matchmaker_video.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+ from trueskill import TrueSkill
4
+ import paramiko
5
+ import io, os
6
+ import sys
7
+ import random
8
+
9
+ sys.path.append('../')
10
+ from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_VIDEO_SKILL
11
+ trueskill_env = TrueSkill()
12
+
13
+ ssh_matchmaker_client = None
14
+ sftp_matchmaker_client = None
15
+
16
+
17
+ def create_ssh_matchmaker_client(server, port, user, password):
18
+ global ssh_matchmaker_client, sftp_matchmaker_client
19
+ ssh_matchmaker_client = paramiko.SSHClient()
20
+ ssh_matchmaker_client.load_system_host_keys()
21
+ ssh_matchmaker_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
22
+ ssh_matchmaker_client.connect(server, port, user, password)
23
+
24
+ transport = ssh_matchmaker_client.get_transport()
25
+ transport.set_keepalive(60)
26
+
27
+ sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
28
+
29
+
30
+ def is_connected():
31
+ global ssh_matchmaker_client, sftp_matchmaker_client
32
+ if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
33
+ return False
34
+ if not ssh_matchmaker_client.get_transport().is_active():
35
+ return False
36
+ try:
37
+ sftp_matchmaker_client.listdir('.')
38
+ except Exception as e:
39
+ print(f"Error checking SFTP connection: {e}")
40
+ return False
41
+ return True
42
+
43
+
44
+ def ucb_score(trueskill_diff, t, n):
45
+ exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
46
+ ucb = -trueskill_diff + 1.0 * exploration_term
47
+ return ucb
48
+
49
+
50
+ def update_trueskill(ratings, ranks):
51
+ new_ratings = trueskill_env.rate(ratings, ranks)
52
+ return new_ratings
53
+
54
+
55
+ def serialize_rating(rating):
56
+ return {'mu': rating.mu, 'sigma': rating.sigma}
57
+
58
+
59
+ def deserialize_rating(rating_dict):
60
+ return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
61
+
62
+
63
+ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
64
+ global sftp_matchmaker_client
65
+ if not is_connected():
66
+ create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
67
+ data = {
68
+ 'ratings': [serialize_rating(r) for r in ratings],
69
+ 'comparison_counts': comparison_counts.tolist(),
70
+ 'total_comparisons': total_comparisons
71
+ }
72
+ json_data = json.dumps(data)
73
+ with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'w') as f:
74
+ f.write(json_data)
75
+
76
+
77
+ def load_json_via_sftp():
78
+ global sftp_matchmaker_client
79
+ if not is_connected():
80
+ create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
81
+ with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'r') as f:
82
+ data = json.load(f)
83
+ ratings = [deserialize_rating(r) for r in data['ratings']]
84
+ comparison_counts = np.array(data['comparison_counts'])
85
+ total_comparisons = data['total_comparisons']
86
+ return ratings, comparison_counts, total_comparisons
87
+
88
+
89
+ def matchmaker_video(num_players, k_group=4):
90
+ trueskill_env = TrueSkill()
91
+
92
+ ratings, comparison_counts, total_comparisons = load_json_via_sftp()
93
+
94
+ ratings = ratings[:num_players]
95
+ comparison_counts = comparison_counts[:num_players, :num_players]
96
+
97
+ selected_player = np.argmin(comparison_counts.sum(axis=1))
98
+
99
+ selected_trueskill_score = trueskill_env.expose(ratings[selected_player])
100
+ trueskill_scores = np.array([trueskill_env.expose(p) for p in ratings])
101
+ trueskill_diff = np.abs(trueskill_scores - selected_trueskill_score)
102
+ n = comparison_counts[selected_player]
103
+ ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
104
+
105
+ # Exclude self, select opponent with highest UCB score
106
+ ucb_scores[selected_player] = -float('inf')
107
+
108
+ excluded_players_1 = [7, 10]
109
+ excluded_players_2 = [6, 8, 9]
110
+ excluded_players = excluded_players_1 + excluded_players_2
111
+ if selected_player in excluded_players_1:
112
+ for player in excluded_players:
113
+ ucb_scores[player] = -float('inf')
114
+ if selected_player in excluded_players_2:
115
+ for player in excluded_players_1:
116
+ ucb_scores[player] = -float('inf')
117
+ else:
118
+ excluded_ucb_scores = {player: ucb_scores[player] for player in excluded_players}
119
+ max_player = max(excluded_ucb_scores, key=excluded_ucb_scores.get)
120
+ if max_player in excluded_players_1:
121
+ for player in excluded_players:
122
+ if player != max_player:
123
+ ucb_scores[player] = -float('inf')
124
+ else:
125
+ for player in excluded_players_1:
126
+ ucb_scores[player] = -float('inf')
127
+
128
+
129
+ opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist()
130
+
131
+ # Group players
132
+ model_ids = [selected_player] + opponents
133
+
134
+ random.shuffle(model_ids)
135
+
136
+ return model_ids
model/model_manager.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import concurrent.futures
2
+ import random
3
+ import gradio as gr
4
+ import requests, os
5
+ import io, base64, json
6
+ import spaces
7
+ import torch
8
+ from PIL import Image
9
+ from openai import OpenAI
10
+ from .models import IMAGE_GENERATION_MODELS, VIDEO_GENERATION_MODELS, load_pipeline
11
+ from serve.upload import get_random_mscoco_prompt, get_random_video_prompt, get_ssh_random_video_prompt, get_ssh_random_image_prompt
12
+ from serve.constants import SSH_CACHE_OPENSOURCE, SSH_CACHE_ADVANCE, SSH_CACHE_PIKA, SSH_CACHE_SORA, SSH_CACHE_IMAGE
13
+
14
+
15
+ class ModelManager:
16
+ def __init__(self):
17
+ self.model_ig_list = IMAGE_GENERATION_MODELS
18
+ self.model_ie_list = [] #IMAGE_EDITION_MODELS
19
+ self.model_vg_list = VIDEO_GENERATION_MODELS
20
+ self.model_b2i_list = []
21
+ self.loaded_models = {}
22
+
23
+ def load_model_pipe(self, model_name):
24
+ if not model_name in self.loaded_models:
25
+ pipe = load_pipeline(model_name)
26
+ self.loaded_models[model_name] = pipe
27
+ else:
28
+ pipe = self.loaded_models[model_name]
29
+ return pipe
30
+
31
+ @spaces.GPU(duration=120)
32
+ def generate_image_ig(self, prompt, model_name):
33
+ pipe = self.load_model_pipe(model_name)
34
+ if 'Stable-cascade' not in model_name:
35
+ result = pipe(prompt=prompt).images[0]
36
+ else:
37
+ prior, decoder = pipe
38
+ prior.enable_model_cpu_offload()
39
+ prior_output = prior(
40
+ prompt=prompt,
41
+ height=512,
42
+ width=512,
43
+ negative_prompt='',
44
+ guidance_scale=4.0,
45
+ num_images_per_prompt=1,
46
+ num_inference_steps=20
47
+ )
48
+ decoder.enable_model_cpu_offload()
49
+ result = decoder(
50
+ image_embeddings=prior_output.image_embeddings.to(torch.float16),
51
+ prompt=prompt,
52
+ negative_prompt='',
53
+ guidance_scale=0.0,
54
+ output_type="pil",
55
+ num_inference_steps=10
56
+ ).images[0]
57
+ return result
58
+
59
+ def generate_image_ig_api(self, prompt, model_name):
60
+ pipe = self.load_model_pipe(model_name)
61
+ result = pipe(prompt=prompt)
62
+ return result
63
+
64
+ def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D):
65
+ if model_A == "" and model_B == "" and model_C == "" and model_D == "":
66
+ from .matchmaker import matchmaker
67
+ not_run = [20,21,22, 25,26, 30] #12,13,14,15,16,17,18,19,20,21,22, #23,24,
68
+ model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
69
+ print(model_ids)
70
+ model_names = [self.model_ig_list[i] for i in model_ids]
71
+ print(model_names)
72
+ else:
73
+ model_names = [model_A, model_B, model_C, model_D]
74
+
75
+ with concurrent.futures.ThreadPoolExecutor() as executor:
76
+ futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
77
+ else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
78
+ results = [future.result() for future in futures]
79
+
80
+ return results[0], results[1], results[2], results[3], \
81
+ model_names[0], model_names[1], model_names[2], model_names[3]
82
+
83
+ def generate_image_ig_cache_anony(self, model_A, model_B, model_C, model_D):
84
+ if model_A == "" and model_B == "" and model_C == "" and model_D == "":
85
+ from .matchmaker import matchmaker
86
+ not_run = [20,21,22]
87
+ model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
88
+ print(model_ids)
89
+ model_names = [self.model_ig_list[i] for i in model_ids]
90
+ print(model_names)
91
+ else:
92
+ model_names = [model_A, model_B, model_C, model_D]
93
+
94
+ root_dir = SSH_CACHE_IMAGE
95
+ local_dir = "./cache_image"
96
+ if not os.path.exists(local_dir):
97
+ os.makedirs(local_dir)
98
+ prompt, results = get_ssh_random_image_prompt(root_dir, local_dir, model_names)
99
+
100
+ return results[0], results[1], results[2], results[3], \
101
+ model_names[0], model_names[1], model_names[2], model_names[3], prompt
102
+
103
+ def generate_video_vg_parallel_anony(self, model_A, model_B, model_C, model_D):
104
+ if model_A == "" and model_B == "" and model_C == "" and model_D == "":
105
+ # model_names = random.sample([model for model in self.model_vg_list], 4)
106
+
107
+ from .matchmaker_video import matchmaker_video
108
+ model_ids = matchmaker_video(num_players=len(self.model_vg_list))
109
+ print(model_ids)
110
+ model_names = [self.model_vg_list[i] for i in model_ids]
111
+ print(model_names)
112
+ else:
113
+ model_names = [model_A, model_B, model_C, model_D]
114
+
115
+ root_dir = SSH_CACHE_OPENSOURCE
116
+ for name in model_names:
117
+ if "Runway-Gen3" in name or "Runway-Gen2" in name or "Pika-v1.0" in name:
118
+ root_dir = SSH_CACHE_ADVANCE
119
+ elif "Pika-beta" in name:
120
+ root_dir = SSH_CACHE_PIKA
121
+ elif "Sora" in name and "OpenSora" not in name:
122
+ root_dir = SSH_CACHE_SORA
123
+
124
+ local_dir = "./cache_video"
125
+ if not os.path.exists(local_dir):
126
+ os.makedirs(local_dir)
127
+ prompt, results = get_ssh_random_video_prompt(root_dir, local_dir, model_names)
128
+ cache_dir = local_dir
129
+
130
+ return results[0], results[1], results[2], results[3], \
131
+ model_names[0], model_names[1], model_names[2], model_names[3], prompt, cache_dir
132
+
133
+ def generate_image_ig_museum_parallel_anony(self, model_A, model_B, model_C, model_D):
134
+ if model_A == "" and model_B == "" and model_C == "" and model_D == "":
135
+ # model_names = random.sample([model for model in self.model_ig_list], 4)
136
+
137
+ from .matchmaker import matchmaker
138
+ model_ids = matchmaker(num_players=len(self.model_ig_list))
139
+ print(model_ids)
140
+ model_names = [self.model_ig_list[i] for i in model_ids]
141
+ print(model_names)
142
+ else:
143
+ model_names = [model_A, model_B, model_C, model_D]
144
+
145
+ prompt = get_random_mscoco_prompt()
146
+ print(prompt)
147
+
148
+ with concurrent.futures.ThreadPoolExecutor() as executor:
149
+ futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
150
+ else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
151
+ results = [future.result() for future in futures]
152
+
153
+ return results[0], results[1], results[2], results[3], \
154
+ model_names[0], model_names[1], model_names[2], model_names[3], prompt
155
+
156
+ def generate_image_ig_parallel(self, prompt, model_A, model_B):
157
+ model_names = [model_A, model_B]
158
+ with concurrent.futures.ThreadPoolExecutor() as executor:
159
+ futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("imagenhub")
160
+ else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
161
+ results = [future.result() for future in futures]
162
+ return results[0], results[1]
163
+
164
+ @spaces.GPU(duration=200)
165
+ def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
166
+ pipe = self.load_model_pipe(model_name)
167
+ result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
168
+ return result
169
+
170
+ def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
171
+ model_names = [model_A, model_B]
172
+ with concurrent.futures.ThreadPoolExecutor() as executor:
173
+ futures = [
174
+ executor.submit(self.generate_image_ie, textbox_source, textbox_target, textbox_instruct, source_image,
175
+ model) for model in model_names]
176
+ results = [future.result() for future in futures]
177
+ return results[0], results[1]
178
+
179
+ def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
180
+ if model_A == "" and model_B == "":
181
+ model_names = random.sample([model for model in self.model_ie_list], 2)
182
+ else:
183
+ model_names = [model_A, model_B]
184
+ with concurrent.futures.ThreadPoolExecutor() as executor:
185
+ futures = [executor.submit(self.generate_image_ie, textbox_source, textbox_target, textbox_instruct, source_image, model) for model in model_names]
186
+ results = [future.result() for future in futures]
187
+ return results[0], results[1], model_names[0], model_names[1]
model/model_registry.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ from typing import List
3
+
4
+ ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"])
5
+ model_info = {}
6
+
7
+ def register_model_info(
8
+ full_names: List[str], simple_name: str, link: str, description: str
9
+ ):
10
+ info = ModelInfo(simple_name, link, description)
11
+
12
+ for full_name in full_names:
13
+ model_info[full_name] = info
14
+
15
+ def get_model_info(name: str) -> ModelInfo:
16
+ if name in model_info:
17
+ return model_info[name]
18
+ else:
19
+ # To fix this, please use `register_model_info` to register your model
20
+ return ModelInfo(
21
+ name, "", "Register the description at fastchat/model/model_registry.py"
22
+ )
23
+
24
+ def get_model_description_md(model_list):
25
+ model_description_md = """
26
+ | | | | | | | | | | | |
27
+ | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
28
+ """
29
+ ct = 0
30
+ visited = set()
31
+ for i, name in enumerate(model_list):
32
+ model_source, model_name, model_type = name.split("_")
33
+ minfo = get_model_info(model_name)
34
+ if minfo.simple_name in visited:
35
+ continue
36
+ visited.add(minfo.simple_name)
37
+ # one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
38
+ one_model_md = f"{minfo.simple_name}"
39
+
40
+ if ct % 11 == 0:
41
+ model_description_md += "|"
42
+ model_description_md += f" {one_model_md} |"
43
+ if ct % 11 == 10:
44
+ model_description_md += "\n"
45
+ ct += 1
46
+ return model_description_md
47
+
48
+ def get_video_model_description_md(model_list):
49
+ model_description_md = """
50
+ | | | | | | |
51
+ | ---- | ---- | ---- | ---- | ---- | ---- |
52
+ """
53
+ ct = 0
54
+ visited = set()
55
+ for i, name in enumerate(model_list):
56
+ model_source, model_name, model_type = name.split("_")
57
+ minfo = get_model_info(model_name)
58
+ if minfo.simple_name in visited:
59
+ continue
60
+ visited.add(minfo.simple_name)
61
+ # one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
62
+ one_model_md = f"{minfo.simple_name}"
63
+
64
+ if ct % 7 == 0:
65
+ model_description_md += "|"
66
+ model_description_md += f" {one_model_md} |"
67
+ if ct % 7 == 6:
68
+ model_description_md += "\n"
69
+ ct += 1
70
+ return model_description_md
model/models/__init__.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .huggingface_models import load_huggingface_model
2
+ from .replicate_api_models import load_replicate_model
3
+ from .openai_api_models import load_openai_model
4
+ from .other_api_models import load_other_model
5
+
6
+
7
+ IMAGE_GENERATION_MODELS = [
8
+ 'replicate_SDXL_text2image',
9
+ 'replicate_SD-v3.0_text2image',
10
+ 'replicate_SD-v2.1_text2image',
11
+ 'replicate_SD-v1.5_text2image',
12
+ 'replicate_SDXL-Lightning_text2image',
13
+ 'replicate_Kandinsky-v2.0_text2image',
14
+ 'replicate_Kandinsky-v2.2_text2image',
15
+ 'replicate_Proteus-v0.2_text2image',
16
+ 'replicate_Playground-v2.0_text2image',
17
+ 'replicate_Playground-v2.5_text2image',
18
+ 'replicate_Dreamshaper-xl-turbo_text2image',
19
+ 'replicate_SDXL-Deepcache_text2image',
20
+ 'replicate_Openjourney-v4_text2image',
21
+ 'replicate_LCM-v1.5_text2image',
22
+ 'replicate_Realvisxl-v3.0_text2image',
23
+ 'replicate_Realvisxl-v2.0_text2image',
24
+ 'replicate_Pixart-Sigma_text2image',
25
+ 'replicate_SSD-1b_text2image',
26
+ 'replicate_Open-Dalle-v1.1_text2image',
27
+ 'replicate_Deepfloyd-IF_text2image',
28
+ 'huggingface_SD-turbo_text2image',
29
+ 'huggingface_SDXL-turbo_text2image',
30
+ 'huggingface_Stable-cascade_text2image',
31
+ 'openai_Dalle-2_text2image',
32
+ 'openai_Dalle-3_text2image',
33
+ 'other_Midjourney-v6.0_text2image',
34
+ 'other_Midjourney-v5.0_text2image',
35
+ "replicate_FLUX.1-schnell_text2image",
36
+ "replicate_FLUX.1-pro_text2image",
37
+ "replicate_FLUX.1-dev_text2image",
38
+ 'other_Meissonic_text2image',
39
+ "replicate_FLUX-1.1-pro_text2image",
40
+ 'replicate_SD-v3.5-large_text2image',
41
+ 'replicate_SD-v3.5-large-turbo_text2image',
42
+ ]
43
+
44
+ VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video',
45
+ 'replicate_Animate-Diff_text2video',
46
+ 'replicate_OpenSora_text2video',
47
+ 'replicate_LaVie_text2video',
48
+ 'replicate_VideoCrafter2_text2video',
49
+ 'replicate_Stable-Video-Diffusion_text2video',
50
+ 'other_Runway-Gen3_text2video',
51
+ 'other_Pika-beta_text2video',
52
+ 'other_Pika-v1.0_text2video',
53
+ 'other_Runway-Gen2_text2video',
54
+ 'other_Sora_text2video',
55
+ 'replicate_Cogvideox-5b_text2video',
56
+ 'other_KLing-v1.0_text2video',
57
+ ]
58
+
59
+
60
+ def load_pipeline(model_name):
61
+ """
62
+ Load a model pipeline based on the model name
63
+ Args:
64
+ model_name (str): The name of the model to load, should be of the form {source}_{name}_{type}
65
+ """
66
+ model_source, model_name, model_type = model_name.split("_")
67
+
68
+ if model_source == "replicate":
69
+ pipe = load_replicate_model(model_name, model_type)
70
+ elif model_source == "huggingface":
71
+ pipe = load_huggingface_model(model_name, model_type)
72
+ elif model_source == "openai":
73
+ pipe = load_openai_model(model_name, model_type)
74
+ elif model_source == "other":
75
+ pipe = load_other_model(model_name, model_type)
76
+ else:
77
+ raise ValueError(f"Model source {model_source} not supported")
78
+ return pipe
model/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.64 kB). View file
 
model/models/__pycache__/huggingface_models.cpython-310.pyc ADDED
Binary file (1.48 kB). View file
 
model/models/__pycache__/openai_api_models.cpython-310.pyc ADDED
Binary file (1.6 kB). View file
 
model/models/__pycache__/other_api_models.cpython-310.pyc ADDED
Binary file (2.56 kB). View file
 
model/models/__pycache__/replicate_api_models.cpython-310.pyc ADDED
Binary file (6.31 kB). View file
 
model/models/huggingface_models.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline
2
+ from diffusers import AutoPipelineForText2Image
3
+ from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
4
+ import torch
5
+
6
+
7
+ def load_huggingface_model(model_name, model_type):
8
+ if model_name == "SD-turbo":
9
+ pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
10
+ pipe = pipe.to("cuda")
11
+ elif model_name == "SDXL-turbo":
12
+ pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
13
+ pipe = pipe.to("cuda")
14
+ elif model_name == "Stable-cascade":
15
+ prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16)
16
+ decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
17
+ pipe = [prior, decoder]
18
+ else:
19
+ raise NotImplementedError
20
+ # if model_name == "SD-turbo":
21
+ # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo")
22
+ # elif model_name == "SDXL-turbo":
23
+ # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
24
+ # else:
25
+ # raise NotImplementedError
26
+ # pipe = pipe.to("cpu")
27
+ return pipe
28
+
29
+
30
+ if __name__ == "__main__":
31
+ # for name in ["SD-turbo", "SDXL-turbo"]: #"SD-turbo", "SDXL-turbo"
32
+ # pipe = load_huggingface_model(name, "text2image")
33
+
34
+ # for name in ["IF-I-XL-v1.0"]:
35
+ # pipe = load_huggingface_model(name, 'text2image')
36
+ # pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
37
+
38
+ prompt = 'draw a tiger'
39
+ pipe = load_huggingface_model('Stable-cascade', "text2image")
40
+ prior, decoder = pipe
41
+ prior.enable_model_cpu_offload()
42
+ prior_output = prior(
43
+ prompt=prompt,
44
+ height=512,
45
+ width=512,
46
+ negative_prompt='',
47
+ guidance_scale=4.0,
48
+ num_images_per_prompt=1,
49
+ num_inference_steps=20
50
+ )
51
+ decoder.enable_model_cpu_offload()
52
+ result = decoder(
53
+ image_embeddings=prior_output.image_embeddings.to(torch.float16),
54
+ prompt=prompt,
55
+ negative_prompt='',
56
+ guidance_scale=0.0,
57
+ output_type="pil",
58
+ num_inference_steps=10
59
+ ).images[0]
model/models/openai_api_models.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ from PIL import Image
3
+ import requests
4
+ import io
5
+ import os
6
+ import base64
7
+
8
+
9
+ class OpenaiModel():
10
+ def __init__(self, model_name, model_type):
11
+ self.model_name = model_name
12
+ self.model_type = model_type
13
+
14
+ def __call__(self, *args, **kwargs):
15
+ if self.model_type == "text2image":
16
+ assert "prompt" in kwargs, "prompt is required for text2image model"
17
+
18
+ client = OpenAI()
19
+
20
+ if 'Dalle-3' in self.model_name:
21
+ client = OpenAI()
22
+ response = client.images.generate(
23
+ model="dall-e-3",
24
+ prompt=kwargs["prompt"],
25
+ size="1024x1024",
26
+ quality="standard",
27
+ n=1,
28
+ )
29
+ elif 'Dalle-2' in self.model_name:
30
+ client = OpenAI()
31
+ response = client.images.generate(
32
+ model="dall-e-2",
33
+ prompt=kwargs["prompt"],
34
+ size="512x512",
35
+ quality="standard",
36
+ n=1,
37
+ )
38
+ else:
39
+ raise NotImplementedError
40
+
41
+ result_url = response.data[0].url
42
+ response = requests.get(result_url)
43
+ result = Image.open(io.BytesIO(response.content))
44
+ return result
45
+ else:
46
+ raise ValueError("model_type must be text2image or image2image")
47
+
48
+
49
+ def load_openai_model(model_name, model_type):
50
+ return OpenaiModel(model_name, model_type)
51
+
52
+
53
+ if __name__ == "__main__":
54
+ pipe = load_openai_model('Dalle-3', 'text2image')
55
+ result = pipe(prompt='draw a tiger')
56
+ print(result)
57
+
model/models/other_api_models.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import os
4
+ from PIL import Image
5
+ import io, time
6
+
7
+
8
class OtherModel():
    """Wrapper around the Midjourney HTTP proxy (text-to-image only).

    Submits a job to the proxy, then polls once per second until the
    rendered image is available.
    """

    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type
        # Endpoint that accepts generation jobs.
        self.image_url = "https://www.xdai.online/mj/submit/imagine"
        self.key = os.environ.get('MIDJOURNEY_KEY')
        # Endpoint to poll for the rendered image, keyed by job id.
        self.get_image_url = "https://www.xdai.online/mj/image/"
        # Default number of submission attempts per call.
        self.repeat_num = 5

    def __call__(self, *args, **kwargs):
        """Generate an image for kwargs['prompt'] and return it as a PIL Image.

        Raises:
            AssertionError: if 'prompt' is missing.
            NotImplementedError: for unknown Midjourney versions.
            ValueError: for unsupported model types or exhausted retries.
        """
        if self.model_type == "text2image":
            assert "prompt" in kwargs, "prompt is required for text2image model"
            # Append the version flag understood by the proxy.
            if self.model_name == "Midjourney-v6.0":
                version = "6.0"
            elif self.model_name == "Midjourney-v5.0":
                version = "5.0"
            else:
                raise NotImplementedError
            data = {
                "base64Array": [],
                "notifyHook": "",
                "prompt": "{} --v {}".format(kwargs["prompt"], version),
                "state": "",
                "botType": "MID_JOURNEY",
            }

            headers = {
                "Authorization": "Bearer {}".format(self.key),
                "Content-Type": "application/json"
            }
            # Retry counters are locals now; the original mutated
            # self.repeat_num, leaking retry state across calls.
            submit_attempts = self.repeat_num
            while 1:
                response = requests.post(self.image_url, data=json.dumps(data), headers=headers)
                if response.status_code == 200:
                    print("Submit success!")
                    response_json = json.loads(response.content.decode('utf-8'))
                    img_id = response_json["result"]
                    result_url = self.get_image_url + img_id
                    print(result_url)
                    # Rendering can take minutes; poll up to 800 times.
                    poll_attempts = 800
                    while 1:
                        time.sleep(1)
                        img_response = requests.get(result_url)
                        if img_response.status_code == 200:
                            result = Image.open(io.BytesIO(img_response.content))
                            # The proxy returns a 2x2 grid; keep the
                            # top-left quadrant only.
                            width, height = result.size
                            new_width = width // 2
                            new_height = height // 2
                            result = result.crop((0, 0, new_width, new_height))
                            return result
                        else:
                            poll_attempts = poll_attempts - 1
                            if poll_attempts == 0:
                                raise ValueError("Image request failed.")
                            continue
                else:
                    submit_attempts = submit_attempts - 1
                    if submit_attempts == 0:
                        raise ValueError("API request failed.")
                    continue
        elif self.model_type == "text2video":
            # No text2video backend here yet; validate input and return None.
            assert "prompt" in kwargs, "prompt is required for text2video model"
        else:
            raise ValueError("model_type must be text2image")
79
+
80
+
81
def load_other_model(model_name, model_type):
    """Factory helper: build an OtherModel for the given name and type."""
    model = OtherModel(model_name, model_type)
    return model
83
+
84
if __name__ == "__main__":
    import http.client
    import json

    # Manual smoke test against the live Midjourney proxy
    # (requires MIDJOURNEY_KEY in the environment).
    pipe = load_other_model("Midjourney-v5.0", "text2image")
    result = pipe(prompt="An Impressionist illustration depicts a river winding through a meadow")
    print(result)
    exit()
model/models/replicate_api_models.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import replicate
2
+ from PIL import Image
3
+ import requests
4
+ import io
5
+ import os
6
+ import base64
7
+
8
# Display name -> Replicate model slug (optionally pinned to a version hash).
# Slugs without a hash resolve to the model's latest version at call time.
Replicate_MODEl_NAME_MAP = {
    "SDXL": "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
    "SD-v3.0": "stability-ai/stable-diffusion-3",
    "SD-v2.1": "stability-ai/stable-diffusion:ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4",
    "SD-v1.5": "stability-ai/stable-diffusion:b3d14e1cd1f9470bbb0bb68cac48e5f483e5be309551992cc33dc30654a82bb7",
    "SDXL-Lightning": "bytedance/sdxl-lightning-4step:5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
    "Kandinsky-v2.0": "ai-forever/kandinsky-2:3c6374e7a9a17e01afe306a5218cc67de55b19ea536466d6ea2602cfecea40a9",
    "Kandinsky-v2.2": "ai-forever/kandinsky-2.2:ad9d7879fbffa2874e1d909d1d37d9bc682889cc65b31f7bb00d2362619f194a",
    "Proteus-v0.2": "lucataco/proteus-v0.2:06775cd262843edbde5abab958abdbb65a0a6b58ca301c9fd78fa55c775fc019",
    "Playground-v2.0": "playgroundai/playground-v2-1024px-aesthetic:42fe626e41cc811eaf02c94b892774839268ce1994ea778eba97103fe1ef51b8",
    "Playground-v2.5": "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
    "Dreamshaper-xl-turbo": "lucataco/dreamshaper-xl-turbo:0a1710e0187b01a255302738ca0158ff02a22f4638679533e111082f9dd1b615",
    "SDXL-Deepcache": "lucataco/sdxl-deepcache:eaf678fb34006669e9a3c6dd5971e2279bf20ee0adeced464d7b6d95de16dc93",
    "Openjourney-v4": "prompthero/openjourney:ad59ca21177f9e217b9075e7300cf6e14f7e5b4505b87b9689dbd866e9768969",
    "LCM-v1.5": "fofr/latent-consistency-model:683d19dc312f7a9f0428b04429a9ccefd28dbf7785fef083ad5cf991b65f406f",
    "Realvisxl-v3.0": "fofr/realvisxl-v3:33279060bbbb8858700eb2146350a98d96ef334fcf817f37eb05915e1534aa1c",

    "Realvisxl-v2.0": "lucataco/realvisxl-v2.0:7d6a2f9c4754477b12c14ed2a58f89bb85128edcdd581d24ce58b6926029de08",
    "Pixart-Sigma": "cjwbw/pixart-sigma:5a54352c99d9fef467986bc8f3a20205e8712cbd3df1cbae4975d6254c902de1",
    "SSD-1b": "lucataco/ssd-1b:b19e3639452c59ce8295b82aba70a231404cb062f2eb580ea894b31e8ce5bbb6",
    "Open-Dalle-v1.1": "lucataco/open-dalle-v1.1:1c7d4c8dec39c7306df7794b28419078cb9d18b9213ab1c21fdc46a1deca0144",
    "Deepfloyd-IF": "andreasjansson/deepfloyd-if:fb84d659df149f4515c351e394d22222a94144aa1403870c36025c8b28846c8d",

    # Text-to-video backends.
    "Zeroscope-v2-xl": "anotherjesse/zeroscope-v2-xl:9f747673945c62801b13b84701c783929c0ee784e4748ec062204894dda1a351",
    # "Damo-Text-to-Video": "cjwbw/damo-text-to-video:1e205ea73084bd17a0a3b43396e49ba0d6bc2e754e9283b2df49fad2dcf95755",
    "Animate-Diff": "lucataco/animate-diff:beecf59c4aee8d81bf04f0381033dfa10dc16e845b4ae00d281e2fa377e48a9f",
    "OpenSora": "camenduru/open-sora:8099e5722ba3d5f408cd3e696e6df058137056268939337a3fbe3912e86e72ad",
    "LaVie": "cjwbw/lavie:0bca850c4928b6c30052541fa002f24cbb4b677259c461dd041d271ba9d3c517",
    "VideoCrafter2": "lucataco/video-crafter:7757c5775e962c618053e7df4343052a21075676d6234e8ede5fa67c9e43bce0",
    "Stable-Video-Diffusion": "sunfjun/stable-video-diffusion:d68b6e09eedbac7a49e3d8644999d93579c386a083768235cabca88796d70d82",
    "FLUX.1-schnell": "black-forest-labs/flux-schnell",
    "FLUX.1-pro": "black-forest-labs/flux-pro",
    "FLUX.1-dev": "black-forest-labs/flux-dev",
    "FLUX-1.1-pro": "black-forest-labs/flux-1.1-pro",
    "SD-v3.5-large": "stability-ai/stable-diffusion-3.5-large",
    "SD-v3.5-large-turbo": "stability-ai/stable-diffusion-3.5-large-turbo",
}
45
+
46
+
47
class ReplicateModel():
    """Wrapper around replicate.run for both text2image and text2video models.

    model_name must be a key of Replicate_MODEl_NAME_MAP; model_type is
    'text2image' (returns a PIL Image) or 'text2video' (returns a URL).
    """

    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type

    def __call__(self, *args, **kwargs):
        """Run the selected model on kwargs['prompt'].

        Returns:
            PIL.Image for text2image; a result URL (string) for text2video.
        Raises:
            AssertionError: if 'prompt' is missing.
            ValueError: for unsupported model types.
        """
        if self.model_type == "text2image":
            assert "prompt" in kwargs, "prompt is required for text2image model"
            output = replicate.run(
                f"{Replicate_MODEl_NAME_MAP[self.model_name]}",
                input={
                    "width": 512,
                    "height": 512,
                    "prompt": kwargs["prompt"]
                },
            )
            # Openjourney returns an iterator; take its first item.
            if 'Openjourney' in self.model_name:
                for item in output:
                    result_url = item
                    break
            elif isinstance(output, list):
                result_url = output[0]
            else:
                result_url = output
            print(self.model_name, result_url)
            # Download the hosted result and decode it into a PIL Image.
            response = requests.get(result_url)
            result = Image.open(io.BytesIO(response.content))
            return result

        elif self.model_type == "text2video":
            assert "prompt" in kwargs, "prompt is required for text2image model"
            # Per-model input payloads (NOTE: 'input' shadows the builtin).
            if self.model_name == "Zeroscope-v2-xl":
                input = {
                    "fps": 24,
                    "width": 512,
                    "height": 512,
                    "prompt": kwargs["prompt"],
                    "guidance_scale": 17.5,
                    # "negative_prompt": "very blue, dust, noisy, washed out, ugly, distorted, broken",
                    "num_frames": 48,
                }
            elif self.model_name == "Damo-Text-to-Video":
                input={
                    "fps": 8,
                    "prompt": kwargs["prompt"],
                    "num_frames": 16,
                    "num_inference_steps": 50
                }
            elif self.model_name == "Animate-Diff":
                input={
                    "path": "toonyou_beta3.safetensors",
                    "seed": 255224557,
                    "steps": 25,
                    "prompt": kwargs["prompt"],
                    "n_prompt": "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth",
                    "motion_module": "mm_sd_v14",
                    "guidance_scale": 7.5
                }
            elif self.model_name == "OpenSora":
                input={
                    "seed": 1234,
                    "prompt": kwargs["prompt"],
                }
            elif self.model_name == "LaVie":
                input={
                    "width": 512,
                    "height": 512,
                    "prompt": kwargs["prompt"],
                    "quality": 9,
                    "video_fps": 8,
                    "interpolation": False,
                    "sample_method": "ddpm",
                    "guidance_scale": 7,
                    "super_resolution": False,
                    "num_inference_steps": 50
                }
            elif self.model_name == "VideoCrafter2":
                input={
                    "fps": 24,
                    "seed": 64045,
                    "steps": 40,
                    "width": 512,
                    "height": 512,
                    "prompt": kwargs["prompt"],
                }
            elif self.model_name == "Stable-Video-Diffusion":
                # SVD is image-to-video: first render a still with SD-v2.1,
                # then animate that image.
                text2image_name = "SD-v2.1"
                output = replicate.run(
                    f"{Replicate_MODEl_NAME_MAP[text2image_name]}",
                    input={
                        "width": 512,
                        "height": 512,
                        "prompt": kwargs["prompt"]
                    },
                )
                if isinstance(output, list):
                    image_url = output[0]
                else:
                    image_url = output
                print(image_url)

                input={
                    "cond_aug": 0.02,
                    "decoding_t": 14,
                    "input_image": "{}".format(image_url),
                    "video_length": "14_frames_with_svd",
                    "sizing_strategy": "maintain_aspect_ratio",
                    "motion_bucket_id": 127,
                    "frames_per_second": 6
                }

            # NOTE(review): an unlisted model name reaches this call with
            # 'input' unbound (UnboundLocalError) — confirm all callers pass
            # only the names handled above.
            output = replicate.run(
                f"{Replicate_MODEl_NAME_MAP[self.model_name]}",
                input=input,
            )
            if isinstance(output, list):
                result_url = output[0]
            else:
                result_url = output
            print(self.model_name)
            print(result_url)
            # response = requests.get(result_url)
            # result = Image.open(io.BytesIO(response.content))

            # for event in handler.iter_events(with_logs=True):
            #     if isinstance(event, fal_client.InProgress):
            #         print('Request in progress')
            #         print(event.logs)

            # result = handler.get()
            # print("result video: ====")
            # print(result)
            # result_url = result['video']['url']
            # return result_url
            return result_url
        else:
            raise ValueError("model_type must be text2image or image2image")
184
+
185
+
186
def load_replicate_model(model_name, model_type):
    """Factory helper: build a ReplicateModel for the given name and type."""
    model = ReplicateModel(model_name, model_type)
    return model
188
+
189
+
190
if __name__ == "__main__":
    # Manual smoke test for the text-to-video path.
    # NOTE(review): 'zeroscope-v2-xl' (lowercase) does not match the map key
    # 'Zeroscope-v2-xl' — this demo likely raises KeyError; confirm.
    model_name = 'replicate_zeroscope-v2-xl_text2video'
    model_source, model_name, model_type = model_name.split("_")
    pipe = load_replicate_model(model_name, model_type)
    prompt = "Clown fish swimming in a coral reef, beautiful, 8k, perfect, award winning, national geographic"
    result = pipe(prompt=prompt)
serve/Arial.ttf ADDED
Binary file (155 kB). View file
 
serve/Ksort.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image, ImageDraw, ImageFont, ImageOps
3
+ import os
4
+ from .constants import KSORT_IMAGE_DIR
5
+ from .constants import COLOR1, COLOR2, COLOR3, COLOR4
6
+ from .vote_utils import save_any_image
7
+ from .utils import disable_btn, enable_btn, invisible_btn
8
+ from .upload import create_remote_directory, upload_ssh_all, upload_ssh_data
9
+ import json
10
+
11
+
12
def reset_level(Top_btn):
    """Map a 'Top N' button label to its zero-based rank level (0-3).

    Raises:
        ValueError: if the label is not one of 'Top 1'..'Top 4'
            (the original crashed with UnboundLocalError instead).
    """
    levels = {"Top 1": 0, "Top 2": 1, "Top 3": 2, "Top 4": 3}
    try:
        level = levels[Top_btn]
    except KeyError:
        raise ValueError(f"unknown level button: {Top_btn}") from None
    return level
+ return level
22
+
23
+
24
def reset_rank(windows, rank, vote_level):
    """Write vote_level into the rank slot of the chosen model window.

    Unrecognised window names leave rank untouched.
    """
    slot = {"Model A": 0, "Model B": 1, "Model C": 2, "Model D": 3}.get(windows)
    if slot is not None:
        rank[slot] = vote_level
    return rank
34
+
35
+
36
def reset_btn_rank(windows, rank, btn, vote_level):
    """Record the rank chosen via a numbered button for a model window.

    Returns the updated (rank, vote_level) pair; vote_level tracks the
    last pressed button as a zero-based level. Unknown windows or buttons
    leave the corresponding value unchanged.
    """
    window_slot = {"Model A": 0, "Model B": 1, "Model C": 2, "Model D": 3}
    button_level = {"1": 0, "2": 1, "3": 2, "4": 3}
    if windows in window_slot and btn in button_level:
        rank[window_slot[windows]] = button_level[btn]
    if btn in button_level:
        vote_level = button_level[btn]
    return (rank, vote_level)
78
+
79
+
80
def reset_vote_text(rank):
    """Render the rank list as a display string: 1-based numbers, 'None'
    for unset slots, each entry followed by a space."""
    parts = []
    for level in rank:
        parts.append("None" if level is None else str(level + 1))
    return "".join(part + " " for part in parts)
89
+
90
+
91
def clear_rank(rank, vote_level):
    """Reset every rank slot to None (in place) and the vote level to 0."""
    rank[:] = [None] * len(rank)
    return rank, 0
96
+
97
+
98
def revote_windows(generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level):
    """Clear the current ranking (in place) while echoing back the four
    generated images unchanged."""
    rank[:] = [None] * len(rank)
    return generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, 0
103
+
104
+
105
def reset_submit(rank):
    """Enable the submit button only when every window has been ranked."""
    if any(slot is None for slot in rank):
        return disable_btn
    return enable_btn
110
+
111
+
112
def reset_mode(mode):
    """Toggle the voting UI between 'Best' (single pick) and 'Rank' modes.

    Returns a tuple of gr.update objects for the bound widgets (5 best-mode
    controls, 16 rank-mode controls, 3 shared buttons) plus a hidden
    textbox whose value is the other mode's name.
    NOTE(review): the counts must stay in sync with the component lists
    wired up in the gradio UI — confirm before changing.
    """
    if mode == "Best":
        # Switch to Rank mode: hide the 5 best-mode controls, show the 16
        # rank controls and enable all 3 shared buttons.
        return (gr.update(visible=False, interactive=False),) * 5 + \
            (gr.update(visible=True, interactive=True),) * 16 + \
            (gr.update(visible=True, interactive=True),) * 3 + \
            (gr.Textbox(value="Rank", visible=False, interactive=False),)
    elif mode == "Rank":
        # Switch back to Best mode; only the third shared button stays
        # interactive.
        return (gr.update(visible=True, interactive=True),) * 5 + \
            (gr.update(visible=False, interactive=False),) * 16 + \
            (gr.update(visible=True, interactive=False),) * 2 + \
            (gr.update(visible=True, interactive=True),) + \
            (gr.Textbox(value="Best", visible=False, interactive=False),)
    else:
        raise ValueError("Undefined mode")
127
+
128
+
129
def reset_chatbot(mode, generate_ig0, generate_ig1, generate_ig2, generate_ig3):
    """Return the four generated images unchanged (mode is ignored)."""
    return (generate_ig0, generate_ig1, generate_ig2, generate_ig3)
131
+
132
+
133
def get_json_filename(conv_id):
    """Return the path of the per-conversation vote-info JSON, creating
    parent directories as needed.

    The file itself is not created; callers append JSON lines to it.
    """
    output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/json/'
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "information.json")
    # name = os.path.join(KSORT_IMAGE_DIR, f"{conv_id}/json/information.json")
    print(output_file)
    return output_file
141
+
142
+
143
def get_img_filename(conv_id, i):
    """Return the path for the i-th saved image of a conversation,
    creating parent directories as needed."""
    output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/image/'
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{i}.jpg")
    print(output_file)
    return output_file
150
+
151
+
152
def vote_submit(states, textbox, rank, request: gr.Request):
    """Persist one vote locally: save each model's output image and append
    a JSON record (model names, ranking, prompt) to the conversation log.

    states: per-model state objects carrying conv_id, model_name, output.
    """
    conv_id = states[0].conv_id

    for i in range(len(states)):
        output_file = get_img_filename(conv_id, i)
        save_any_image(states[i].output, output_file)
    # Append mode: one JSON line per vote.
    with open(get_json_filename(conv_id), "a") as fout:
        data = {
            "models_name": [x.model_name for x in states],
            "img_rank": [x for x in rank],
            "prompt": [textbox],
        }
        fout.write(json.dumps(data) + "\n")
+ fout.write(json.dumps(data) + "\n")
165
+
166
+
167
def vote_ssh_submit(states, textbox, rank, user_name, user_institution):
    """Upload one image-arena vote (images + metadata) to the remote log
    server over SSH, then update the TrueSkill ratings."""
    conv_id = states[0].conv_id
    output_dir = create_remote_directory(conv_id)
    # upload_image(states, output_dir)

    data = {
        "models_name": [x.model_name for x in states],
        "img_rank": [x for x in rank],
        "prompt": [textbox],
        "user_info": {"name": [user_name], "institution": [user_institution]},
    }
    output_file = os.path.join(output_dir, "result.json")
    # upload_informance(data, output_file)
    upload_ssh_all(states, output_dir, data, output_file)

    # Lazy import — presumably to avoid a circular import at module load;
    # confirm before hoisting to the top of the file.
    from .update_skill import update_skill
    update_skill(rank, [x.model_name for x in states])
+ update_skill(rank, [x.model_name for x in states])
184
+
185
+
186
def vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_institution):
    """Upload one video-arena vote (metadata only; videos referenced by
    path) to the remote log server, then update video TrueSkill ratings."""
    conv_id = states[0].conv_id
    output_dir = create_remote_directory(conv_id, video=True)

    data = {
        "models_name": [x.model_name for x in states],
        "video_rank": [x for x in rank],
        "prompt": [textbox],
        "prompt_path": [prompt_path],
        "video_path": [x.output for x in states],
        "user_info": {"name": [user_name], "institution": [user_institution]},
    }
    output_file = os.path.join(output_dir, "result.json")

    upload_ssh_data(data, output_file)

    # Lazy import — presumably to avoid a circular import at module load;
    # confirm before hoisting to the top of the file.
    from .update_skill_video import update_skill_video
    update_skill_video(rank, [x.model_name for x in states])
+ update_skill_video(rank, [x.model_name for x in states])
204
+
205
+
206
def submit_response_igm(
    state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, user_name, user_institution, request: gr.Request
):
    """Handle a 'Best'-mode image vote: upload it, disable the vote
    buttons, and reveal the four model names.

    Returns 6 disabled buttons, 4 Markdown labels, and 1 more disabled
    button (order must match the gradio outputs wiring).
    """
    # vote_submit([state0, state1, state2, state3], textbox, rank, request)
    vote_ssh_submit([state0, state1, state2, state3], textbox, rank, user_name, user_institution)
    if model_selector0 == "":
        # Anonymous battle: names were hidden, so reveal the display part
        # of each 'source_name_type' identifier now.
        return (disable_btn,) * 6 + (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True)
        ) + (disable_btn,)
    else:
        # Named battle: the selectors already show the names.
        return (disable_btn,) * 6 + (
            gr.Markdown(state0.model_name, visible=True),
            gr.Markdown(state1.model_name, visible=True),
            gr.Markdown(state2.model_name, visible=True),
            gr.Markdown(state3.model_name, visible=True)
        ) + (disable_btn,)
+ ) + (disable_btn,)
225
+
226
+
227
def submit_response_vg(
    state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, user_name, user_institution, request: gr.Request
):
    """Handle a 'Best'-mode video vote: upload it, disable the vote
    buttons, and reveal the four model names (same output shape as
    submit_response_igm)."""
    vote_video_ssh_submit([state0, state1, state2, state3], textbox, prompt_path, rank, user_name, user_institution)
    if model_selector0 == "":
        # Anonymous battle: reveal the display part of each identifier.
        return (disable_btn,) * 6 + (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True)
        ) + (disable_btn,)
    else:
        return (disable_btn,) * 6 + (
            gr.Markdown(state0.model_name, visible=True),
            gr.Markdown(state1.model_name, visible=True),
            gr.Markdown(state2.model_name, visible=True),
            gr.Markdown(state3.model_name, visible=True)
        ) + (disable_btn,)
+ ) + (disable_btn,)
245
+
246
+
247
def submit_response_rank_igm(
    state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, right_vote_text, user_name, user_institution, request: gr.Request
):
    """Handle a 'Rank'-mode image vote.

    Only submits when right_vote_text == "right" (rank already validated);
    otherwise re-enables the controls so the user can fix the ranking.
    Returns 16 + 3 button updates, the "wrong" reset flag, and 4 Markdown
    name labels (order must match the gradio outputs wiring).
    """
    print(rank)
    if right_vote_text == "right":
        # vote_submit([state0, state1, state2, state3], textbox, rank, request)
        vote_ssh_submit([state0, state1, state2, state3], textbox, rank, user_name, user_institution)
        if model_selector0 == "":
            # Anonymous battle: reveal the display part of each identifier.
            return (disable_btn,) * 16 + (disable_btn,) * 3 + ("wrong",) + (
                gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
                gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
                gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True),
                gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True)
            )
        else:
            return (disable_btn,) * 16 + (disable_btn,) * 3 + ("wrong",) + (
                gr.Markdown(state0.model_name, visible=True),
                gr.Markdown(state1.model_name, visible=True),
                gr.Markdown(state2.model_name, visible=True),
                gr.Markdown(state3.model_name, visible=True)
            )
    else:
        # Invalid ranking: keep everything interactive, hide name labels.
        return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
+ return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
270
+
271
+
272
def submit_response_rank_vg(
    state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, right_vote_text, user_name, user_institution, request: gr.Request
):
    """Handle a 'Rank'-mode video vote (same contract and output shape as
    submit_response_rank_igm, but uploads via the video path)."""
    print(rank)
    if right_vote_text == "right":
        vote_video_ssh_submit([state0, state1, state2, state3], textbox, prompt_path, rank, user_name, user_institution)
        if model_selector0 == "":
            # Anonymous battle: reveal the display part of each identifier.
            return (disable_btn,) * 16 + (disable_btn,) * 3 + ("wrong",) + (
                gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
                gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
                gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True),
                gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True)
            )
        else:
            return (disable_btn,) * 16 + (disable_btn,) * 3 + ("wrong",) + (
                gr.Markdown(state0.model_name, visible=True),
                gr.Markdown(state1.model_name, visible=True),
                gr.Markdown(state2.model_name, visible=True),
                gr.Markdown(state3.model_name, visible=True)
            )
    else:
        # Invalid ranking: keep everything interactive, hide name labels.
        return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
+ return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
294
+
295
+
296
def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox):
    """Parse a typed rank string and annotate each image with its rank.

    Each digit 1-4 in vote_textbox ranks the corresponding image; the
    image gets a colored border and a 'Top N' label.

    Returns:
        [img0..img3] + [rank display string] + ["right"] + [rank list]
        on success; on malformed input the unmodified images with
        "error rank"/"wrong" and an all-None rank.
    """
    rank_list = [char for char in vote_textbox if char.isdigit()]
    generate_ig = [generate_ig0, generate_ig1, generate_ig2, generate_ig3]
    rank = [None, None, None, None]
    if len(rank_list) != 4:
        return generate_ig + ["error rank"] + ["wrong"] + [rank]
    # Border color and label text per rank digit (replaces the four
    # duplicated if/elif chains of the original).
    colors = {'1': COLOR1, '2': COLOR2, '3': COLOR3, '4': COLOR4}
    labels = {'1': Top1_text, '2': Top2_text, '3': Top3_text, '4': Top4_text}
    chatbot = []
    for num, digit in enumerate(rank_list):
        if digit not in colors:
            return generate_ig + ["error rank"] + ["wrong"] + [rank]
        base_image = Image.fromarray(generate_ig[num]).convert("RGBA")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        base_image = base_image.resize((512, 512), Image.LANCZOS)
        border_size = 10  # Size of the border
        base_image = ImageOps.expand(base_image, border=border_size, fill=colors[digit])

        draw = ImageDraw.Draw(base_image)
        font = ImageFont.truetype("./serve/Arial.ttf", 66)
        text_position = (180, 25)
        draw.text(text_position, labels[digit], font=font, fill=colors[digit])
        base_image = base_image.convert("RGB")
        chatbot.append(base_image.copy())
    rank_str = "".join(d + " " for d in rank_list)
    rank = [int(x) for x in rank_list]

    return chatbot + [rank_str] + ["right"] + [rank]
+ return chatbot + [rank_str] + ["right"] + [rank]
344
+
345
+
346
def text_response_rank_vg(vote_textbox):
    """Validate a typed 4-digit rank string for the video arena.

    Returns [display string, "right", rank list] when the text contains
    exactly four digits, each in 1-4; otherwise
    ["error rank", "wrong", [None]*4].
    """
    digits = [c for c in vote_textbox if c.isdigit()]
    failure = ["error rank"] + ["wrong"] + [[None, None, None, None]]
    if len(digits) != 4:
        return failure
    if any(d not in ('1', '2', '3', '4') for d in digits):
        return failure
    display = "".join(d + " " for d in digits)
    return [display] + ["right"] + [[int(d) for d in digits]]
+ return [rank_str] + ["right"] + [rank]
363
+
364
+
365
def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text):
    """Frame an image array with the color of its vote level (0-3) and
    stamp the matching 'Top N' label onto it.

    Returns the annotated image as an RGB PIL Image.
    """
    # Color and label per vote level (replaces the duplicated if/elif
    # chains of the original).
    colors = (COLOR1, COLOR2, COLOR3, COLOR4)
    texts = (Top1_text, Top2_text, Top3_text, Top4_text)
    base_image = Image.fromarray(image).convert("RGBA")
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    base_image = base_image.resize((512, 512), Image.LANCZOS)
    border_color = colors[vote_level]
    border_size = 10  # Size of the border
    base_image = ImageOps.expand(base_image, border=border_size, fill=border_color)

    draw = ImageDraw.Draw(base_image)
    font = ImageFont.truetype("./serve/Arial.ttf", 66)

    text_position = (180, 25)
    draw.text(text_position, texts[vote_level], font=font, fill=border_color)

    base_image = base_image.convert("RGB")
    return base_image
+ return base_image
398
+
399
+
400
def add_green_border(image):
    """Surround *image* with a 10-pixel solid green frame."""
    green = (0, 255, 0)  # RGB for green
    return ImageOps.expand(image, border=10, fill=green)
+ return img_with_border
405
+
406
+
407
def check_textbox(textbox):
    """Return True for a non-empty prompt string, False otherwise."""
    return textbox != ""
+ return True
serve/__init__.py ADDED
File without changes
serve/__pycache__/Ksort.cpython-310.pyc ADDED
Binary file (11.5 kB). View file
 
serve/__pycache__/Ksort.cpython-39.pyc ADDED
Binary file (11.7 kB). View file
 
serve/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (157 Bytes). View file
 
serve/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (146 Bytes). View file
 
serve/__pycache__/constants.cpython-310.pyc ADDED
Binary file (1.85 kB). View file
 
serve/__pycache__/constants.cpython-39.pyc ADDED
Binary file (1.84 kB). View file
 
serve/__pycache__/gradio_web.cpython-310.pyc ADDED
Binary file (14.2 kB). View file
 
serve/__pycache__/gradio_web.cpython-39.pyc ADDED
Binary file (14 kB). View file
 
serve/__pycache__/gradio_web_bbox.cpython-310.pyc ADDED
Binary file (14.5 kB). View file
 
serve/__pycache__/leaderboard.cpython-310.pyc ADDED
Binary file (7.97 kB). View file
 
serve/__pycache__/log_utils.cpython-310.pyc ADDED
Binary file (4.06 kB). View file
 
serve/__pycache__/log_utils.cpython-39.pyc ADDED
Binary file (4.05 kB). View file
 
serve/__pycache__/update_skill.cpython-310.pyc ADDED
Binary file (4 kB). View file
 
serve/__pycache__/upload.cpython-310.pyc ADDED
Binary file (8.22 kB). View file
 
serve/__pycache__/upload.cpython-39.pyc ADDED
Binary file (8.28 kB). View file
 
serve/__pycache__/utils.cpython-310.pyc ADDED
Binary file (8.54 kB). View file
 
serve/__pycache__/utils.cpython-39.pyc ADDED
Binary file (9.14 kB). View file
 
serve/__pycache__/vote_utils.cpython-310.pyc ADDED
Binary file (30.4 kB). View file
 
serve/__pycache__/vote_utils.cpython-39.pyc ADDED
Binary file (33.7 kB). View file
 
serve/button.css ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* style.css */
.custom-button {
    background-color: greenyellow; /* background color */
    color: white; /* text color */
    border: none; /* no border */
    padding: 10px 20px; /* inner padding */
    text-align: center; /* center the label */
    text-decoration: none; /* no underline */
    display: inline-block; /* inline block */
    font-size: 16px; /* font size */
    margin: 4px 2px; /* outer margin */
    cursor: pointer; /* pointer cursor */
    border-radius: 5px; /* rounded corners */
}

.custom-button:hover {
    background-color: darkgreen; /* background color on hover */
}
/* css = """
#warning {background: red;}

.feedback {font-size: 24px !important;}
.feedback textarea {font-size: 24px !important;}
""" */
+ """ */
serve/constants.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Local log/output directories; all overridable via environment variables.
LOGDIR = os.getenv("LOGDIR", "./ksort-logs/vote_log")
IMAGE_DIR = os.getenv("IMAGE_DIR", f"{LOGDIR}/images")
KSORT_IMAGE_DIR = os.getenv("KSORT_IMAGE_DIR", f"{LOGDIR}/ksort_images")
VIDEO_DIR = os.getenv("VIDEO_DIR", f"{LOGDIR}/videos")

# Gradio server configuration.
SERVER_PORT = os.getenv("SERVER_PORT", 7860)
ROOT_PATH = os.getenv("ROOT_PATH", None)
ELO_RESULTS_DIR = os.getenv("ELO_RESULTS_DIR", "./arena_elo/results/latest")

# Remote HTTP log server.
LOG_SERVER = os.getenv("LOG_SERVER", "http://127.0.0.1:22005")
LOG_SERVER_SUBDOMAIN = os.getenv("LOG_SERVER_SUBDOMAIN", "/ksort-logs")
LOG_SERVER_ADDR = os.getenv("LOG_SERVER_ADDR", f"{LOG_SERVER}{LOG_SERVER_SUBDOMAIN}")
# LOG SERVER API ENDPOINTS
APPEND_JSON = "append_json"
SAVE_IMAGE = "save_image"
SAVE_VIDEO = "save_video"
SAVE_LOG = "save_log"

SSH_MSCOCO = os.getenv("SSH_MSCOCO", "/root/MSCOCO_prompt")


# SSH upload target; real credentials are expected from the environment.
SSH_SERVER = os.getenv("SSH_SERVER", "default_value")
SSH_PASSWORD = os.getenv("SSH_PASSWORD", "default_value")
SSH_PORT = os.getenv("SSH_PORT", "default_value")
SSH_USER = os.getenv("SSH_USER", "default_value")
SSH_LOG = os.getenv("SSH_LOG", "/home/zhendongucb/ksort/ksort_log/log")
# NOTE(review): this reads env var "SSH_LOG", not "SSH_VIDEO_LOG" —
# looks like a copy-paste bug; confirm before changing.
SSH_VIDEO_LOG = os.getenv("SSH_LOG", "/home/zhendongucb/ksort/ksort_log/video_log")
SSH_SKILL = os.getenv("SSH_SKILL", "/home/zhendongucb/ksort/TrueSkill/trueskill_data.json")
SSH_VIDEO_SKILL = os.getenv("SSH_VIDEO_SKILL", "/home/zhendongucb/ksort/TrueSkill/trueskill_video_data.json")

# Remote caches of pre-generated videos/images.
SSH_CACHE_OPENSOURCE = os.getenv("SSH_CACHE_OPENSOURCE", "/home/zhendongucb/ksort/ksort_video_cache/Kaiyuan/")
SSH_CACHE_ADVANCE = os.getenv("SSH_CACHE_ADVANCE", "/home/zhendongucb/ksort/ksort_video_cache/Advance/")
SSH_CACHE_PIKA = os.getenv("SSH_CACHE_PIKA", "/home/zhendongucb/ksort/ksort_video_cache/Pika-Beta/")
SSH_CACHE_SORA = os.getenv("SSH_CACHE_SORA", "/home/zhendongucb/ksort/ksort_video_cache/Sora/")

SSH_CACHE_IMAGE = os.getenv("SSH_CACHE_IMAGE", "/home/zhendongucb/ksort/ksort_image_cache/")

# Rank badge colors (RGB) for Top 1..Top 4; earlier palettes kept below
# for reference.
# COLOR1=(128, 214, 255)
# COLOR2=(237, 247, 152)
# COLOR3=(250, 181, 122)
# COLOR4=(240, 104, 104)

# COLOR1=(112, 161, 215)
# COLOR2=(161, 222, 147)
# COLOR3=(247, 244, 139)
# COLOR4=(244, 124, 124)

COLOR1=(168, 230, 207)
COLOR2=(253, 255, 171)
COLOR3=(255, 211, 182)
COLOR4=(255, 170, 165)

# COLOR1=(255, 212, 96)
# COLOR2=(240, 123, 63)
# COLOR3=(234, 84, 85)
# COLOR4=(45, 64, 89)

# COLOR1=(255, 189, 57)
# COLOR2=(230, 28, 93)
# COLOR3=(147, 0, 119)
# COLOR4=(58, 0, 136)
+ # COLOR4=(58, 0, 136)
serve/gradio_web.py ADDED
@@ -0,0 +1,789 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import *
2
+ from .vote_utils import (
3
+ upvote_last_response_ig as upvote_last_response,
4
+ downvote_last_response_ig as downvote_last_response,
5
+ flag_last_response_ig as flag_last_response,
6
+ leftvote_last_response_igm as leftvote_last_response,
7
+ left1vote_last_response_igm as left1vote_last_response,
8
+ rightvote_last_response_igm as rightvote_last_response,
9
+ right1vote_last_response_igm as right1vote_last_response,
10
+ tievote_last_response_igm as tievote_last_response,
11
+ bothbad_vote_last_response_igm as bothbad_vote_last_response,
12
+ share_click_igm as share_click,
13
+ generate_ig,
14
+ generate_ig_museum,
15
+ generate_igm,
16
+ generate_igm_museum,
17
+ generate_igm_annoy,
18
+ generate_igm_annoy_museum,
19
+ generate_igm_cache_annoy,
20
+ share_js
21
+ )
22
+ from .Ksort import (
23
+ add_foreground,
24
+ reset_level,
25
+ reset_rank,
26
+ revote_windows,
27
+ submit_response_igm,
28
+ submit_response_rank_igm,
29
+ reset_submit,
30
+ clear_rank,
31
+ reset_mode,
32
+ reset_chatbot,
33
+ reset_btn_rank,
34
+ reset_vote_text,
35
+ text_response_rank_igm,
36
+ check_textbox,
37
+ )
38
+
39
+ from functools import partial
40
+ from .upload import get_random_mscoco_prompt
41
+ from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD
42
+ from serve.upload import get_random_mscoco_prompt, create_ssh_client
43
+ from serve.update_skill import create_ssh_skill_client
44
+ from model.matchmaker import create_ssh_matchmaker_client
45
def set_ssh():
    """Open every SSH channel the app needs: upload, TrueSkill, matchmaker.

    All three clients are created against the same server, using the
    credentials exported by serve.constants (environment-driven).
    """
    credentials = (SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    for open_client in (
        create_ssh_client,
        create_ssh_skill_client,
        create_ssh_matchmaker_client,
    ):
        open_client(*credentials)
49
+
50
def build_side_by_side_ui_anony(models):
    """Build the anonymous 4-way side-by-side text-to-image arena tab.

    Lays out four anonymized model outputs (A-D), two voting modes
    (single "X is Best"/Tie buttons, and per-model Rank buttons 1-4),
    prompt entry with random-prompt and cached-sample draws, optional
    contributor identity, and wires all Gradio event chains.

    models: object exposing `model_ig_list` plus the generation callables
        used below (`generate_image_ig_parallel_anony`,
        `generate_image_ig_cache_anony`) — structure defined elsewhere.
    """
    notice_markdown = """
# ⚔️ K-Sort Arena (Text-to-Image Generation) ⚔️
## 📜 Rules
- Input a prompt for four anonymized models (e.g., SD, SDXL, OpenJourney for Text-guided Image Generation) and vote on their outputs.
- Two voting modes available: Rank Mode and Best Mode. Switch freely between modes. Please note that ties are always allowed. In ranking mode, users can input rankings like 1 3 3 1. Any invalid rankings, such as 1 4 4 1, will be automatically corrected during post-processing.
- Users are encouraged to make evaluations based on subjective preferences. Evaluation criteria: Alignment (50%) + Aesthetics (50%).
- Alignment includes: Entity Matching (30%) + Style Matching (20%);
- Aesthetics includes: Photorealism (30%) + Light and Shadow (10%) + Absence of Artifacts (10%).

## 👇 Generating now!
- Note: Due to the API's image safety checks, errors may occur. If this happens, please re-enter a different prompt.
- At times, high API concurrency can cause congestion, potentially resulting in a generation time of up to 1.5 minutes per image. Thank you for your patience.
"""

    model_list = models.model_ig_list

    # Per-model hidden conversation state for the four anonymized slots.
    state0 = gr.State()
    state1 = gr.State()
    state2 = gr.State()
    state3 = gr.State()

    # Partially-applied generators: live generation vs cached random samples.
    gen_func = partial(generate_igm_annoy, models.generate_image_ig_parallel_anony)
    gen_cache_func = partial(generate_igm_cache_annoy, models.generate_image_ig_cache_anony)

    # gen_func_random = partial(generate_igm_annoy_museum, models.generate_image_ig_museum_parallel_anony)

    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    # --- layout: anonymized outputs A-D and (hidden) model-name reveals ---
    with gr.Group(elem_id="share-region-anony"):
        with gr.Accordion("🔍 Expand to see all Arena players", open=False):
            model_description_md = get_model_description_md(model_list)
            gr.Markdown(model_description_md, elem_id="model_description_markdown")
        with gr.Row():
            with gr.Column():
                chatbot_left = gr.Image(width=512, label = "Model A")
            with gr.Column():
                chatbot_left1 = gr.Image(width=512, label = "Model B")
            with gr.Column():
                chatbot_right = gr.Image(width=512, label = "Model C")
            with gr.Column():
                chatbot_right1 = gr.Image(width=512, label = "Model D")

        # Markdowns stay invisible until a vote is submitted, then reveal
        # which model produced each image.
        with gr.Row():
            with gr.Column():
                model_selector_left = gr.Markdown("", visible=False)
            with gr.Column():
                model_selector_left1 = gr.Markdown("", visible=False)
            with gr.Column():
                model_selector_right = gr.Markdown("", visible=False)
            with gr.Column():
                model_selector_right1 = gr.Markdown("", visible=False)
        with gr.Row():
            slow_warning = gr.Markdown("", elem_id="notice_markdown")

    # --- Best-mode vote buttons (single winner, or a tie) ---
    with gr.Row(elem_classes="row"):
        with gr.Column(scale=1, min_width=10):
            leftvote_btn = gr.Button(
                value="A is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
            )
        with gr.Column(scale=1, min_width=10):
            left1vote_btn = gr.Button(
                value="B is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
            )
        with gr.Column(scale=1, min_width=10):
            rightvote_btn = gr.Button(
                value="C is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
            )
        with gr.Column(scale=1, min_width=10):
            right1vote_btn = gr.Button(
                value="D is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
            )
        with gr.Column(scale=1, min_width=10):
            tie_btn = gr.Button(
                value="🤝 Tie", visible=False, interactive=False, elem_id="btncolor2", elem_classes="best-button"
            )

    # --- Rank-mode button grid: one block of ranks 1-4 per model A-D ---
    with gr.Row():
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale=1, min_width=10):
                    A1_btn = gr.Button(
                        value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    A2_btn = gr.Button(
                        value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    A3_btn = gr.Button(
                        value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    A4_btn = gr.Button(
                        value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
                    )
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale=1, min_width=10):
                    B1_btn = gr.Button(
                        value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    B2_btn = gr.Button(
                        value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    B3_btn = gr.Button(
                        value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    B4_btn = gr.Button(
                        value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
                    )
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale=1, min_width=10):
                    C1_btn = gr.Button(
                        value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    C2_btn = gr.Button(
                        value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    C3_btn = gr.Button(
                        value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    C4_btn = gr.Button(
                        value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
                    )
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale=1, min_width=10):
                    D1_btn = gr.Button(
                        value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    D2_btn = gr.Button(
                        value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    D3_btn = gr.Button(
                        value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    D4_btn = gr.Button(
                        value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
                    )

    # --- Rank entry textbox + submit / mode-switch controls ---
    with gr.Row():
        vote_textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your rank (you can use buttons above, or directly type here, e.g. 1 2 3 4)",
            container=True,
            elem_id="input_box",
            visible=False,
        )
        vote_submit_btn = gr.Button(value="Submit", visible=False, interactive=False, variant="primary", scale=0, elem_id="btnpink", elem_classes="submit-button")
        vote_mode_btn = gr.Button(value="🔄 Mode", visible=False, interactive=False, variant="primary", scale=0, elem_id="btnpink", elem_classes="submit-button")

    # --- prompt entry and sampling controls ---
    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
            container=True,
            elem_id="input_box",
        )
        # send_btn = gr.Button(value="Send", variant="primary", scale=0, elem_id="btnblue", elem_classes="send-button")
        # draw_btn = gr.Button(value="🎲 Random sample", variant="primary", scale=0, elem_id="btnblue", elem_classes="send-button")
        send_btn = gr.Button(value="Send", variant="primary", scale=0, elem_id="btnblue")
        draw_btn = gr.Button(value="🎲 Random Prompt", variant="primary", scale=0, elem_id="btnblue")
    with gr.Row():
        cache_btn = gr.Button(value="🎲 Random Sample", interactive=True)
    with gr.Row():
        clear_btn = gr.Button(value="🎲 New Round", interactive=False)
        # regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
        # share_btn = gr.Button(value="📷 Share")

    # --- optional contributor identity, shown on the leaderboard ---
    with gr.Blocks():
        with gr.Row(elem_id="centered-text"): #
            user_info = gr.Markdown("User information (to appear on the contributor leaderboard)", visible=True, elem_id="centered-text") #, elem_id="centered-text"
            # with gr.Blocks():
            #     name = gr.Markdown("Name", visible=True)
            user_name = gr.Textbox(show_label=False,placeholder="👉 Enter your name (optional)", elem_classes="custom-width")
            # with gr.Blocks():
            #     institution = gr.Markdown("Institution", visible=True)
            user_institution = gr.Textbox(show_label=False,placeholder="👉 Enter your affiliation (optional)", elem_classes="custom-width")
            #gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    # Click-through example prompts (image column stays hidden).
    dummy_img_output = gr.Image(width=512, visible=False)
    gr.Examples(
        examples=[["A train crossing a bridge that is going over a body of water.", os.path.join("./examples", "example1.jpg")],
                  ["The man in the business suit wears a striped blue and white tie.", os.path.join("./examples", "example2.jpg")],
                  ["A skier stands on a small ledge in the snow.",os.path.join("./examples", "example3.jpg")],
                  ["The bathroom with green tile and a red shower curtain.", os.path.join("./examples", "example4.jpg")]],
        inputs = [textbox, dummy_img_output])

    # Widget groups toggled together when enabling/disabling the ordering
    # controls and the voting controls.
    order_btn_list = [textbox, send_btn, draw_btn, cache_btn, clear_btn]
    vote_order_list = [leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
        A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \
        vote_textbox, vote_submit_btn, vote_mode_btn]

    # Hidden copies of the raw generated images, used when re-rendering
    # the visible images with rank overlays.
    generate_ig0 = gr.Image(width=512, label = "generate A", visible=False, interactive=False)
    generate_ig1 = gr.Image(width=512, label = "generate B", visible=False, interactive=False)
    generate_ig2 = gr.Image(width=512, label = "generate C", visible=False, interactive=False)
    generate_ig3 = gr.Image(width=512, label = "generate D", visible=False, interactive=False)
    dummy_left_model = gr.State("")
    dummy_left1_model = gr.State("")
    dummy_right_model = gr.State("")
    dummy_right1_model = gr.State("")

    # Canned rank vectors for Best mode: position = model A..D, lower value
    # = better; e.g. "A is Best" -> [0, 3, 3, 3]. `ig_rank` holds the
    # user-built ranking in Rank mode (None = not yet assigned).
    ig_rank = [None, None, None, None]
    bastA_rank = [0, 3, 3, 3]
    bastB_rank = [3, 0, 3, 3]
    bastC_rank = [3, 3, 0, 3]
    bastD_rank = [3, 3, 3, 0]
    tie_rank = [0, 0, 0, 0]
    bad_rank = [3, 3, 3, 3]
    rank = gr.State(ig_rank)
    rankA = gr.State(bastA_rank)
    rankB = gr.State(bastB_rank)
    rankC = gr.State(bastC_rank)
    rankD = gr.State(bastD_rank)
    rankTie = gr.State(tie_rank)
    rankBad = gr.State(bad_rank)
    # Hidden label/bookkeeping widgets consumed by the Ksort helpers.
    Top1_text = gr.Textbox(value="Top 1", visible=False, interactive=False)
    Top2_text = gr.Textbox(value="Top 2", visible=False, interactive=False)
    Top3_text = gr.Textbox(value="Top 3", visible=False, interactive=False)
    Top4_text = gr.Textbox(value="Top 4", visible=False, interactive=False)
    window1_text = gr.Textbox(value="Model A", visible=False, interactive=False)
    window2_text = gr.Textbox(value="Model B", visible=False, interactive=False)
    window3_text = gr.Textbox(value="Model C", visible=False, interactive=False)
    window4_text = gr.Textbox(value="Model D", visible=False, interactive=False)
    vote_level = gr.Number(value=0, visible=False, interactive=False)
    # Top1_btn.click(reset_level, inputs=[Top1_text], outputs=[vote_level])
    # Top2_btn.click(reset_level, inputs=[Top2_text], outputs=[vote_level])
    # Top3_btn.click(reset_level, inputs=[Top3_text], outputs=[vote_level])
    # Top4_btn.click(reset_level, inputs=[Top4_text], outputs=[vote_level])
    vote_mode = gr.Textbox(value="Rank", visible=False, interactive=False)
    right_vote_text = gr.Textbox(value="wrong", visible=False, interactive=False)
    cache_mode = gr.Textbox(value="True", visible=False, interactive=False)

    # --- event wiring ---
    # Prompt submitted via ENTER: lock ordering controls, generate four
    # images, then enable whichever vote controls match the current mode.
    # NOTE(review): api_name "submit_btn_annony" is reused by several
    # endpoints below — confirm Gradio tolerates duplicate api_name values.
    textbox.submit(
        disable_order_buttons,
        inputs=[textbox],
        outputs=order_btn_list
    ).then(
        gen_func,
        inputs=[state0, state1, state2, state3, textbox, model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
        outputs=[state0, state1, state2, state3, generate_ig0, generate_ig1, generate_ig2, generate_ig3, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \
            model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
        api_name="submit_btn_annony"
    ).then(
        enable_vote_mode_buttons,
        inputs=[vote_mode, textbox],
        outputs=vote_order_list
    )

    # Same pipeline triggered by the Send button.
    send_btn.click(
        disable_order_buttons,
        inputs=[textbox],
        outputs=order_btn_list
    ).then(
        gen_func,
        inputs=[state0, state1, state2, state3, textbox, model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
        outputs=[state0, state1, state2, state3, generate_ig0, generate_ig1, generate_ig2, generate_ig3, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \
            model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
        api_name="send_btn_annony"
    ).then(
        enable_vote_mode_buttons,
        inputs=[vote_mode, textbox],
        outputs=vote_order_list
    )

    # Cached-sample pipeline: no live generation; also fills the prompt box.
    # NOTE(review): passes [textbox, cache_mode] to disable_order_buttons
    # while the other triggers pass only [textbox] — confirm the handler
    # accepts the optional second input.
    cache_btn.click(
        disable_order_buttons,
        inputs=[textbox, cache_mode],
        outputs=order_btn_list
    ).then(
        gen_cache_func,
        inputs=[state0, state1, state2, state3, model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
        outputs=[state0, state1, state2, state3, generate_ig0, generate_ig1, generate_ig2, generate_ig3, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \
            model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, textbox],
        api_name="send_btn_annony"
    ).then(
        enable_vote_mode_buttons,
        inputs=[vote_mode, textbox],
        outputs=vote_order_list
    )

    # Fill the prompt box with a random MSCOCO prompt.
    draw_btn.click(
        get_random_mscoco_prompt,
        inputs=None,
        outputs=[textbox],
        api_name="draw_btn_annony"
    )

    # New round: wipe states/images, re-enable ordering, clear the ranking,
    # and hide all vote controls again.
    clear_btn.click(
        clear_history_side_by_side_anony,
        inputs=None,
        outputs=[state0, state1, state2, state3, textbox, vote_textbox, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \
            model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
        api_name="clear_btn_annony"
    ).then(
        enable_order_buttons,
        inputs=None,
        outputs=order_btn_list
    ).then(
        clear_rank,
        inputs=[rank, vote_level],
        outputs=[rank, vote_level]
    ).then(
        disable_vote_mode_buttons,
        inputs=None,
        outputs=vote_order_list
    )

    # regenerate_btn.click(
    #     gen_func,
    #     inputs=[state0, state1, state2, state3, textbox, model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
    #     outputs=[state0, state1, state2, state3, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \
    #         model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
    #     api_name="regenerate_btn_annony"
    # ).then(
    #     enable_best_buttons,
    #     inputs=None,
    #     outputs=btn_list
    # )

    # Toggle between Rank mode and Best mode: restore clean images, then
    # flip which vote widgets are visible/enabled.
    vote_mode_btn.click(
        reset_chatbot,
        inputs=[vote_mode, generate_ig0, generate_ig1, generate_ig2, generate_ig3],
        outputs=[chatbot_left, chatbot_left1, chatbot_right, chatbot_right1]
    ).then(
        reset_mode,
        inputs=[vote_mode],
        outputs=[leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
            A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \
            vote_textbox, vote_submit_btn, vote_mode_btn, vote_mode]
    )

    # Typed-rank submission (ENTER in the rank textbox): lock vote controls,
    # validate/normalize the typed ranking, then record the vote.
    vote_textbox.submit(
        disable_vote,
        inputs=None,
        outputs=[vote_submit_btn, vote_mode_btn, \
            A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn]
    ).then(
        text_response_rank_igm,
        inputs=[generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox],
        outputs=[chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, vote_textbox, right_vote_text, rank]
    ).then(
        submit_response_rank_igm,
        inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rank, right_vote_text, user_name, user_institution],
        outputs=[A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \
            vote_textbox, vote_submit_btn, vote_mode_btn, right_vote_text, \
            model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
        api_name="submit_btn_annony"
    )
    # Same chain triggered by the Submit button.
    vote_submit_btn.click(
        disable_vote,
        inputs=None,
        outputs=[vote_submit_btn, vote_mode_btn, \
            A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn]
    ).then(
        text_response_rank_igm,
        inputs=[generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox],
        outputs=[chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, vote_textbox, right_vote_text, rank]
    ).then(
        submit_response_rank_igm,
        inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rank, right_vote_text, user_name, user_institution],
        outputs=[A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \
            vote_textbox, vote_submit_btn, vote_mode_btn, right_vote_text, \
            model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
        api_name="submit_btn_annony"
    )

    # Revote_btn.click(
    #     revote_windows,
    #     inputs=[generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level],
    #     outputs=[chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, rank, vote_level]
    # ).then(
    #     reset_submit,
    #     inputs = [rank],
    #     outputs = [Submit_btn]
    # )
    # Submit_btn.click(
    #     submit_response_igm,
    #     inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, rank],
    #     outputs=[textbox, Top1_btn, Top2_btn, Top3_btn, Top4_btn, Revote_btn, Submit_btn, \
    #         model_selector_left, model_selector_left1, model_selector_right, model_selector_right1]
    # )

    # chatbot_left.select(add_foreground, inputs=[generate_ig0, vote_level, Top1_text, Top2_text, Top3_text, Top4_text], outputs=[chatbot_left]).then(
    #     reset_rank,
    #     inputs = [window1_text, rank, vote_level],
    #     outputs = [rank]
    # ).then(
    #     reset_submit,
    #     inputs = [rank],
    #     outputs = [Submit_btn]
    # )
    # chatbot_left1.select(add_foreground, inputs=[generate_ig1, vote_level, Top1_text, Top2_text, Top3_text, Top4_text], outputs=[chatbot_left1]).then(
    #     reset_rank,
    #     inputs = [window2_text, rank, vote_level],
    #     outputs = [rank]
    # ).then(
    #     reset_submit,
    #     inputs = [rank],
    #     outputs = [Submit_btn]
    # )
    # chatbot_right.select(add_foreground, inputs=[generate_ig2, vote_level, Top1_text, Top2_text, Top3_text, Top4_text], outputs=[chatbot_right]).then(
    #     reset_rank,
    #     inputs = [window3_text, rank, vote_level],
    #     outputs = [rank]
    # ).then(
    #     reset_submit,
    #     inputs = [rank],
    #     outputs = [Submit_btn]
    # )
    # chatbot_right1.select(add_foreground, inputs=[generate_ig3, vote_level, Top1_text, Top2_text, Top3_text, Top4_text], outputs=[chatbot_right1]).then(
    #     reset_rank,
    #     inputs = [window4_text, rank, vote_level],
    #     outputs = [rank]
    # ).then(
    #     reset_submit,
    #     inputs = [rank],
    #     outputs = [Submit_btn]
    # )


    # Best-mode votes: each button submits the corresponding canned rank
    # vector (rankA..rankD, or rankTie) directly.
    leftvote_btn.click(
        submit_response_igm,
        inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rankA, user_name, user_institution],
        outputs=[textbox, leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
            model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, \
            vote_mode_btn]
    )
    left1vote_btn.click(
        submit_response_igm,
        inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rankB, user_name, user_institution],
        outputs=[textbox, leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
            model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, \
            vote_mode_btn]
    )
    rightvote_btn.click(
        submit_response_igm,
        inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rankC, user_name, user_institution],
        outputs=[textbox, leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
            model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, \
            vote_mode_btn]
    )
    right1vote_btn.click(
        submit_response_igm,
        inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rankD, user_name, user_institution],
        outputs=[textbox, leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
            model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, \
            vote_mode_btn]
    )
    tie_btn.click(
        submit_response_igm,
        inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rankTie, user_name, user_institution],
        outputs=[textbox, leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
            model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, \
            vote_mode_btn]
    )

    # Rank-mode buttons: every Xn button records rank n for model X, redraws
    # that model's image with its rank overlay, then re-evaluates whether the
    # ranking is complete (enables Submit) and mirrors it into the textbox.
    A1_btn.click(
        reset_btn_rank,
        inputs=[window1_text, rank, A1_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig0, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_left]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
    A2_btn.click(
        reset_btn_rank,
        inputs=[window1_text, rank, A2_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig0, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_left]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
    A3_btn.click(
        reset_btn_rank,
        inputs=[window1_text, rank, A3_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig0, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_left]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
    A4_btn.click(
        reset_btn_rank,
        inputs=[window1_text, rank, A4_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig0, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_left]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )

    B1_btn.click(
        reset_btn_rank,
        inputs=[window2_text, rank, B1_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig1, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_left1]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
    B2_btn.click(
        reset_btn_rank,
        inputs=[window2_text, rank, B2_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig1, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_left1]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
    B3_btn.click(
        reset_btn_rank,
        inputs=[window2_text, rank, B3_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig1, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_left1]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
    B4_btn.click(
        reset_btn_rank,
        inputs=[window2_text, rank, B4_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig1, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_left1]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )

    C1_btn.click(
        reset_btn_rank,
        inputs=[window3_text, rank, C1_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig2, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_right]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
    C2_btn.click(
        reset_btn_rank,
        inputs=[window3_text, rank, C2_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig2, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_right]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
    C3_btn.click(
        reset_btn_rank,
        inputs=[window3_text, rank, C3_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig2, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_right]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
    C4_btn.click(
        reset_btn_rank,
        inputs=[window3_text, rank, C4_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig2, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_right]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )

    D1_btn.click(
        reset_btn_rank,
        inputs=[window4_text, rank, D1_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig3, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_right1]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
    D2_btn.click(
        reset_btn_rank,
        inputs=[window4_text, rank, D2_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig3, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_right1]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
    D3_btn.click(
        reset_btn_rank,
        inputs=[window4_text, rank, D3_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig3, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_right1]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
    D4_btn.click(
        reset_btn_rank,
        inputs=[window4_text, rank, D4_btn, vote_level],
        outputs=[rank, vote_level]
    ).then(
        add_foreground,
        inputs=[generate_ig3, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
        outputs=[chatbot_right1]
    ).then(
        reset_submit,
        inputs = [rank],
        outputs = [vote_submit_btn]
    ).then(
        reset_vote_text,
        inputs = [rank],
        outputs = [vote_textbox]
    )
serve/gradio_web_bbox.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ import numpy as np
4
+ from gradio import processing_utils
5
+ from PIL import Image, ImageDraw, ImageFont
6
+
7
+ from serve.utils import *
8
+ from serve.vote_utils import (
9
+ upvote_last_response_ig as upvote_last_response,
10
+ downvote_last_response_ig as downvote_last_response,
11
+ flag_last_response_ig as flag_last_response,
12
+ leftvote_last_response_igm as leftvote_last_response,
13
+ left1vote_last_response_igm as left1vote_last_response,
14
+ rightvote_last_response_igm as rightvote_last_response,
15
+ right1vote_last_response_igm as right1vote_last_response,
16
+ tievote_last_response_igm as tievote_last_response,
17
+ bothbad_vote_last_response_igm as bothbad_vote_last_response,
18
+ share_click_igm as share_click,
19
+ generate_ig,
20
+ generate_ig_museum,
21
+ generate_igm,
22
+ generate_igm_museum,
23
+ generate_igm_annoy,
24
+ generate_igm_annoy_museum,
25
+ generate_igm_cache_annoy,
26
+ share_js
27
+ )
28
+ from serve.Ksort import (
29
+ add_foreground,
30
+ reset_level,
31
+ reset_rank,
32
+ revote_windows,
33
+ submit_response_igm,
34
+ submit_response_rank_igm,
35
+ reset_submit,
36
+ clear_rank,
37
+ reset_mode,
38
+ reset_chatbot,
39
+ reset_btn_rank,
40
+ reset_vote_text,
41
+ text_response_rank_igm,
42
+ check_textbox,
43
+ )
44
+
45
+ from functools import partial
46
+ from serve.upload import get_random_mscoco_prompt
47
+ from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD
48
+ from serve.upload import get_random_mscoco_prompt, create_ssh_client
49
+ from serve.update_skill import create_ssh_skill_client
50
+ from model.matchmaker import create_ssh_matchmaker_client
51
+
52
def set_ssh():
    """Open the three SSH channels the app needs, all against the same host.

    One client each for: result/file upload (serve.upload), skill-score
    updates (serve.update_skill), and matchmaker state (model.matchmaker).
    Credentials come from serve.constants.
    """
    create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
56
+
57
def binarize(x):
    """Return a uint8 mask: 255 wherever *x* is nonzero, 0 elsewhere."""
    nonzero = np.not_equal(x, 0)
    return nonzero.astype('uint8') * 255
59
+
60
def sized_center_crop(img, cropx, cropy):
    """Slice a (cropy, cropx) window out of the center of *img* (H x W x ...)."""
    height, width = img.shape[:2]
    x0 = width // 2 - cropx // 2
    y0 = height // 2 - cropy // 2
    return img[y0:y0 + cropy, x0:x0 + cropx]
65
+
66
def sized_center_fill(img, fill, cropx, cropy):
    """Paste *fill* into the centered (cropy, cropx) window of *img*.

    Mutates *img* in place and returns it.
    """
    height, width = img.shape[:2]
    x0 = width // 2 - cropx // 2
    y0 = height // 2 - cropy // 2
    img[y0:y0 + cropy, x0:x0 + cropx] = fill
    return img
72
+
73
def sized_center_mask(img, cropx, cropy):
    """Dim *img* to 20% brightness everywhere except a centered crop window.

    Returns a new uint8 array; the centered (cropy, cropx) region keeps its
    original pixel values.
    """
    height, width = img.shape[:2]
    x0 = width // 2 - cropx // 2
    y0 = height // 2 - cropy // 2
    keep = img[y0:y0 + cropy, x0:x0 + cropx].copy()
    dimmed = (img * 0.2).astype('uint8')
    dimmed[y0:y0 + cropy, x0:x0 + cropx] = keep
    return dimmed
81
+
82
def center_crop(img, HW=None, tgt_size=(512, 512)):
    """Center-crop *img* to a square of side *HW* and resize to *tgt_size*.

    When *HW* is None the side defaults to min(height, width). Returns a
    numpy array.
    """
    if HW is None:
        HW = min(img.shape[0], img.shape[1])
    square = sized_center_crop(img, HW, HW)
    resized = Image.fromarray(square).resize(tgt_size)
    return np.array(resized)
90
+
91
+
92
+
93
def draw_box(boxes=[], labels=[], img=None):
    """Draw labelled bounding boxes onto *img* and return a PIL image.

    Each entry of *boxes* is an (x1, y1, x2, y2) tuple; labels[i] is rendered
    on a filled strip along the bottom edge of box i. With no boxes and no
    image, returns None; with boxes but no image, a blank white 512x512
    canvas is used.
    """
    if img is None and len(boxes) == 0:
        return None

    # PIL's drawing API needs an Image object; accept numpy input too.
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)

    if img is None:
        img = Image.new('RGB', (512, 512), (255, 255, 255))

    palette = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]

    pen = ImageDraw.Draw(img)
    font = ImageFont.load_default()
    # The default bitmap font may not expose size_in_points; fall back to 10.
    font_size = getattr(font, 'size_in_points', 10)

    for idx, (x1, y1, x2, y2) in enumerate(boxes):
        color = palette[idx % len(palette)]
        pen.rectangle([x1, y1, x2, y2], outline=color, width=4)
        caption = labels[idx]
        # Filled strip sized roughly to the caption width, hugging the box bottom.
        strip = [
            x1,
            y2 - int(font_size * 1.2),
            x1 + int((len(caption) + 0.8) * font_size * 0.6),
            y2,
        ]
        pen.rectangle(strip, outline=color, fill=color, width=4)
        pen.text(
            [x1 + int(font_size * 0.2), y2 - int(font_size * 1.2)],
            caption,
            font=font,
            fill=(255, 255, 255)
        )
    return img
133
+
134
+
135
def draw(input, grounding_texts, new_image_trigger, state):
    """Parse the sketch-pad payload into bounding boxes and render a preview.

    Called on every sketch_pad / grounding-instruction change. Detects the
    newest stroke by diffing the current binarized canvas against the last
    stored mask, converts it to a bounding box, labels it with the
    semicolon-separated *grounding_texts*, and draws all boxes on the
    background via draw_box.

    Returns [box_image, new_image_trigger, image_scale, state], or None when
    no background image is available.

    Fixes vs. the original: the debug prints ran `.shape` / `len()` BEFORE
    the None checks, so an empty sketch pad raised AttributeError (and the
    `layers` default of 1 made `len(layers)` raise TypeError); prints are now
    guarded and `layers` defaults to an empty list. `type(x) ==` checks were
    replaced with `isinstance`.
    """
    # Gradio's ImageEditor delivers a dict; a bare array is also accepted.
    if isinstance(input, dict):
        background = input.get('background', None)
        layers = input.get('layers', [])  # list default so len() below is safe
        composite = input.get('composite', None)
        # Debug traces, guarded so a missing image cannot crash the callback.
        if background is not None:
            print("background.shape", background.shape)
        print("len(layers)", len(layers))
        if composite is not None:
            print("composite.shape", composite.shape)
    else:
        background = input
        layers = []
        composite = None

    if background is None:
        print("background is None")
        return None

    # Prefer the composite (background + drawn layers); fall back to background.
    if composite is None:
        print("composite is None")
        image = background
    else:
        image = composite

    mask = binarize(image)

    if not isinstance(mask, np.ndarray):
        mask = np.array(mask)

    # An all-black canvas means the pad was cleared: drop accumulated boxes.
    if mask.sum() == 0:
        state = {}

    if 'boxes' not in state:
        state['boxes'] = []

    if 'masks' not in state or len(state['masks']) == 0:
        state['masks'] = []
        last_mask = np.zeros_like(mask)
    else:
        last_mask = state['masks'][-1]

    if isinstance(mask, np.ndarray) and mask.size > 1:
        # NOTE(review): uint8 subtraction wraps where the new mask is darker
        # than the old one; kept as-is to preserve existing behavior.
        diff_mask = mask - last_mask
    else:
        diff_mask = np.zeros([])

    # Bounding box of the newly painted region.
    if diff_mask.sum() > 0:
        x1x2 = np.where(diff_mask.max(0) != 0)[0]
        y1y2 = np.where(diff_mask.max(1) != 0)[0]
        y1, y2 = y1y2.min(), y1y2.max()
        x1, x2 = x1x2.min(), x1x2.max()

        # Ignore tiny strokes (noise / accidental clicks).
        if (x2 - x1 > 5) and (y2 - y1 > 5):
            state['masks'].append(mask.copy())
            state['boxes'].append((x1, y1, x2, y2))

    # One label per box; pad with generic names when the user typed too few.
    grounding_texts = [x.strip() for x in grounding_texts.split(';')]
    grounding_texts = [x for x in grounding_texts if len(x) > 0]
    if len(grounding_texts) < len(state['boxes']):
        grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(state['boxes']))]

    box_image = draw_box(state['boxes'], grounding_texts, background)

    # Inpainting mode: re-embed the preview into the original full image.
    if box_image is not None and state.get('inpaint_hw', None):
        inpaint_hw = state['inpaint_hw']
        box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw)))
        original_image = state['original_image'].copy()
        box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw)

    return [box_image, new_image_trigger, 1.0, state]
214
+
215
def build_side_by_side_bbox_ui_anony(models):
    """Build the anonymous 4-way bbox-to-image arena tab.

    Lays out: a sketch pad + parsed-box preview, prompt and grounding-text
    inputs, four anonymized model output panes (A-D), best-of-4 vote buttons,
    per-model rank buttons (A1..D4), a free-text rank box, and hidden State /
    Textbox components that carry voting bookkeeping between callbacks.

    `models` must expose `model_b2i_list` (the bbox-to-image model registry).
    Only component construction and the sketch-pad callbacks are wired here;
    the vote/generate event wiring lives elsewhere in this module.
    """
    notice_markdown = """
# ⚔️ Control-Ability-Arena (Bbox-to-Image Generation) ⚔️
## 📜 Rules
- Input a prompt for four anonymized models and vote on their outputs.
- Two voting modes available: Rank Mode and Best Mode. Switch freely between modes. Please note that ties are always allowed. In ranking mode, users can input rankings like 1 3 3 1. Any invalid rankings, such as 1 4 4 1, will be automatically corrected during post-processing.
- Users are encouraged to make evaluations based on subjective preferences. Evaluation criteria: Alignment (50%) + Aesthetics (50%).
- Alignment includes: Entity Matching (30%) + Style Matching (20%);
- Aesthetics includes: Photorealism (30%) + Light and Shadow (10%) + Absence of Artifacts (10%).

## 👇 Generating now!
- Note: Due to the API's image safety checks, errors may occur. If this happens, please re-enter a different prompt.
- At times, high API concurrency can cause congestion, potentially resulting in a generation time of up to 1.5 minutes per image. Thank you for your patience.
"""
    model_list = models.model_b2i_list

    # Sketch-pad parsing state plus one per-model conversation state.
    state = gr.State({})
    state0 = gr.State()
    state1 = gr.State()
    state2 = gr.State()
    state3 = gr.State()

    # gen_func = partial(generate_igm_annoy, models.generate_image_ig_parallel_anony)
    # gen_cache_func = partial(generate_igm_cache_annoy, models.generate_image_ig_cache_anony)

    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    # Hidden triggers used by the sketch-pad callbacks.
    with gr.Row():
        sketch_pad_trigger = gr.Number(value=0, visible=False)
        sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
        image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)

    # Drawing area: editable pad on the left, parsed box preview on the right.
    with gr.Row():
        sketch_pad = gr.ImageEditor(
            label="Sketch Pad",
            type="numpy",
            crop_size="1:1",
            width=512,
            height=512
        )
        out_imagebox = gr.Image(
            type="pil",
            label="Parsed Sketch Pad",
            width=512,
            height=512
        )

    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
            container=True,
            elem_id="input_box",
        )
        send_btn = gr.Button(value="Send", variant="primary", scale=0, elem_id="btnblue")

    with gr.Row():
        grounding_instruction = gr.Textbox(
            label="Grounding instruction (Separated by semicolon)",
            placeholder="👉 Enter your Grounding instruction (e.g. a cat; a dog; a bird; a fish)",
        )

    with gr.Group(elem_id="share-region-anony"):
        with gr.Accordion("🔍 Expand to see all Arena players", open=False):
            # model_description_md = get_model_description_md(model_list)
            gr.Markdown("", elem_id="model_description_markdown")

        # Four anonymized output panes.
        with gr.Row():
            with gr.Column():
                chatbot_left = gr.Image(width=512, label = "Model A")
            with gr.Column():
                chatbot_left1 = gr.Image(width=512, label = "Model B")
            with gr.Column():
                chatbot_right = gr.Image(width=512, label = "Model C")
            with gr.Column():
                chatbot_right1 = gr.Image(width=512, label = "Model D")

        # Model names revealed only after voting.
        with gr.Row():
            with gr.Column():
                model_selector_left = gr.Markdown("", visible=False)
            with gr.Column():
                model_selector_left1 = gr.Markdown("", visible=False)
            with gr.Column():
                model_selector_right = gr.Markdown("", visible=False)
            with gr.Column():
                model_selector_right1 = gr.Markdown("", visible=False)
        with gr.Row():
            slow_warning = gr.Markdown("", elem_id="notice_markdown")

    # Best-mode vote buttons (one winner, or tie).
    with gr.Row(elem_classes="row"):
        with gr.Column(scale=1, min_width=10):
            leftvote_btn = gr.Button(
                value="A is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
            )
        with gr.Column(scale=1, min_width=10):
            left1vote_btn = gr.Button(
                value="B is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
            )
        with gr.Column(scale=1, min_width=10):
            rightvote_btn = gr.Button(
                value="C is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
            )
        with gr.Column(scale=1, min_width=10):
            right1vote_btn = gr.Button(
                value="D is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
            )
        with gr.Column(scale=1, min_width=10):
            tie_btn = gr.Button(
                value="🤝 Tie", visible=False, interactive=False, elem_id="btncolor2", elem_classes="best-button"
            )

    # Rank-mode buttons: for each model (A-D) a rank of 1-4.
    with gr.Row():
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale=1, min_width=10):
                    A1_btn = gr.Button(
                        value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    A2_btn = gr.Button(
                        value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    A3_btn = gr.Button(
                        value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    A4_btn = gr.Button(
                        value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
                    )
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale=1, min_width=10):
                    B1_btn = gr.Button(
                        value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    B2_btn = gr.Button(
                        value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    B3_btn = gr.Button(
                        value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    B4_btn = gr.Button(
                        value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
                    )
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale=1, min_width=10):
                    C1_btn = gr.Button(
                        value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    C2_btn = gr.Button(
                        value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    C3_btn = gr.Button(
                        value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    C4_btn = gr.Button(
                        value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
                    )
        with gr.Blocks():
            with gr.Row():
                with gr.Column(scale=1, min_width=10):
                    D1_btn = gr.Button(
                        value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    D2_btn = gr.Button(
                        value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    D3_btn = gr.Button(
                        value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
                    )
                with gr.Column(scale=1, min_width=10):
                    D4_btn = gr.Button(
                        value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
                    )
    # Free-text ranking entry plus submit / mode-toggle controls.
    with gr.Row():
        vote_textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your rank (you can use buttons above, or directly type here, e.g. 1 2 3 4)",
            container=True,
            elem_id="input_box",
            visible=False,
        )
        vote_submit_btn = gr.Button(value="Submit", visible=False, interactive=False, variant="primary", scale=0, elem_id="btnpink", elem_classes="submit-button")
        vote_mode_btn = gr.Button(value="🔄 Mode", visible=False, interactive=False, variant="primary", scale=0, elem_id="btnpink", elem_classes="submit-button")

    with gr.Row():
        clear_btn = gr.Button(value="🎲 New Round", interactive=False)
        # regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
        # share_btn = gr.Button(value="📷 Share")
    # Optional contributor identity, shown on the contributor leaderboard.
    with gr.Blocks():
        with gr.Row(elem_id="centered-text"): #
            user_info = gr.Markdown("User information (to appear on the contributor leaderboard)", visible=True, elem_id="centered-text") #, elem_id="centered-text"
            # with gr.Blocks():
            # name = gr.Markdown("Name", visible=True)
            user_name = gr.Textbox(show_label=False,placeholder="👉 Enter your name (optional)", elem_classes="custom-width")
            # with gr.Blocks():
            # institution = gr.Markdown("Institution", visible=True)
            user_institution = gr.Textbox(show_label=False,placeholder="👉 Enter your affiliation (optional)", elem_classes="custom-width")

    # Re-parse boxes whenever the sketch or the grounding text changes.
    sketch_pad.change(
        draw,
        inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
        outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
        queue=False,
    )
    grounding_instruction.change(
        draw,
        inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
        outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
        queue=False,
    )

    order_btn_list = [textbox, send_btn, clear_btn]
    vote_order_list = [leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
        A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \
        vote_textbox, vote_submit_btn, vote_mode_btn]

    # Hidden copies of the generated images, used by the vote callbacks.
    generate_ig0 = gr.Image(width=512, label = "generate A", visible=False, interactive=False)
    generate_ig1 = gr.Image(width=512, label = "generate B", visible=False, interactive=False)
    generate_ig2 = gr.Image(width=512, label = "generate C", visible=False, interactive=False)
    generate_ig3 = gr.Image(width=512, label = "generate D", visible=False, interactive=False)
    dummy_left_model = gr.State("")
    dummy_left1_model = gr.State("")
    dummy_right_model = gr.State("")
    dummy_right1_model = gr.State("")

    # Preset rank vectors (0 = best, 3 = worst) used by the one-click votes.
    ig_rank = [None, None, None, None]
    bastA_rank = [0, 3, 3, 3]
    bastB_rank = [3, 0, 3, 3]
    bastC_rank = [3, 3, 0, 3]
    bastD_rank = [3, 3, 3, 0]
    tie_rank = [0, 0, 0, 0]
    bad_rank = [3, 3, 3, 3]
    rank = gr.State(ig_rank)
    rankA = gr.State(bastA_rank)
    rankB = gr.State(bastB_rank)
    rankC = gr.State(bastC_rank)
    rankD = gr.State(bastD_rank)
    rankTie = gr.State(tie_rank)
    rankBad = gr.State(bad_rank)
    # Hidden label carriers consumed by the ranking callbacks.
    Top1_text = gr.Textbox(value="Top 1", visible=False, interactive=False)
    Top2_text = gr.Textbox(value="Top 2", visible=False, interactive=False)
    Top3_text = gr.Textbox(value="Top 3", visible=False, interactive=False)
    Top4_text = gr.Textbox(value="Top 4", visible=False, interactive=False)
    window1_text = gr.Textbox(value="Model A", visible=False, interactive=False)
    window2_text = gr.Textbox(value="Model B", visible=False, interactive=False)
    window3_text = gr.Textbox(value="Model C", visible=False, interactive=False)
    window4_text = gr.Textbox(value="Model D", visible=False, interactive=False)
    vote_level = gr.Number(value=0, visible=False, interactive=False)
    # Top1_btn.click(reset_level, inputs=[Top1_text], outputs=[vote_level])
    # Top2_btn.click(reset_level, inputs=[Top2_text], outputs=[vote_level])
    # Top3_btn.click(reset_level, inputs=[Top3_text], outputs=[vote_level])
    # Top4_btn.click(reset_level, inputs=[Top4_text], outputs=[vote_level])
    vote_mode = gr.Textbox(value="Rank", visible=False, interactive=False)
    right_vote_text = gr.Textbox(value="wrong", visible=False, interactive=False)
    cache_mode = gr.Textbox(value="True", visible=False, interactive=False)
485
+
486
+
487
+
488
+
489
if __name__ == "__main__":
    # Standalone demo harness. The UI builder requires a registry object
    # exposing `model_b2i_list`; the original code called it with no argument,
    # which raised TypeError at startup. A minimal stub keeps the demo
    # launchable without the full ModelManager.
    class _StubModels:
        model_b2i_list = []

    with gr.Blocks() as demo:
        build_side_by_side_bbox_ui_anony(_StubModels())
    demo.launch()
serve/leaderboard.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Live monitor of the website statistics and leaderboard.
3
+
4
+ Dependency:
5
+ sudo apt install pkg-config libicu-dev
6
+ pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate
7
+ """
8
+
9
+ import argparse
10
+ import ast
11
+ import pickle
12
+ import os
13
+ import threading
14
+ import time
15
+
16
+ import gradio as gr
17
+ import numpy as np
18
+ import pandas as pd
19
+ import json
20
+ from datetime import datetime
21
+
22
+
23
+ # def make_leaderboard_md(elo_results):
24
+ # leaderboard_md = f"""
25
+ # # 🏆 Chatbot Arena Leaderboard
26
+ # | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
27
+
28
+ # This leaderboard is based on the following three benchmarks.
29
+ # - [Chatbot Arena](https://lmsys.org/blog/2023-05-03-arena/) - a crowdsourced, randomized battle platform. We use 100K+ user votes to compute Elo ratings.
30
+ # - [MT-Bench](https://arxiv.org/abs/2306.05685) - a set of challenging multi-turn questions. We use GPT-4 to grade the model responses.
31
+ # - [MMLU](https://arxiv.org/abs/2009.03300) (5-shot) - a test to measure a model's multitask accuracy on 57 tasks.
32
+
33
+ # 💻 Code: The Arena Elo ratings are computed by this [notebook]({notebook_url}). The MT-bench scores (single-answer grading on a scale of 10) are computed by [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). The MMLU scores are mostly computed by [InstructEval](https://github.com/declare-lab/instruct-eval). Higher values are better for all benchmarks. Empty cells mean not available. Last updated: November, 2023.
34
+ # """
35
+ # return leaderboard_md
36
+
37
def make_leaderboard_md():
    """Markdown header for the text-to-image leaderboard tab."""
    return """
    # 🏆 K-Sort Arena Leaderboard (Text-to-Image Generation)
    """
42
+
43
+
44
def make_leaderboard_video_md():
    """Markdown header for the text-to-video leaderboard tab."""
    return """
    # 🏆 K-Sort Arena Leaderboard (Text-to-Video Generation)
    """
49
+
50
+
51
def model_hyperlink(model_name, link):
    """Render *model_name* as a dotted-underline HTML anchor pointing at *link*."""
    template = (
        '<a target="_blank" href="{url}" '
        'style="color: var(--link-text-color); '
        'text-decoration: underline;text-decoration-style: dotted;">'
        '{label}</a>'
    )
    return template.format(url=link, label=model_name)
53
+
54
+
55
def make_arena_leaderboard_md(total_models, total_votes, last_updated):
    """Summary line for a leaderboard tab: totals plus last-update date.

    Each vote ranks four outputs, which is equivalent to C(4,2) = 6 pairwise
    comparisons — hence the *6 factor shown next to the vote count.
    """
    pairwise_comparisons = total_votes * 6
    return f"""
    Total models: **{total_models}** (anonymized), Total votes: **{total_votes}** (equivalent to **{pairwise_comparisons}** pairwise comparisons)
    \n Last updated: {last_updated}
    """
65
+
66
+
67
def make_disclaimer_md():
    """Return the disclaimer as raw HTML: a hidden modal, a click-away
    overlay, and the inline link that opens them (plain-JS show/hide via
    inline onclick handlers — no gradio events involved)."""
    disclaimer_md = '''
    <div id="modal" style="display:none; position:fixed; top:50%; left:50%; transform:translate(-50%, -50%); padding:20px; background:white; box-shadow:0 0 10px rgba(0,0,0,0.5); z-index:1000;">
    <p style="font-size:24px;"><strong>Disclaimer</strong></p>
    <p style="font-size:18px;"><b>Purpose and Scope</b></b></p>
    <p><b>This platform is designed for academic use, providing a space for evaluating and comparing Visual Generation Models. The information and services provided are intended for research and educational purposes only.</b></p>

    <p style="font-size:18px;"><b>Privacy and Data Protection</b></p>
    <p><b>While users may voluntarily submit their names and institutional affiliations, this information is not required and is collected solely for the purpose of academic recognition. Personal information submitted to this platform will be handled with care and used solely for the intended academic purposes. We are committed to protecting your privacy, and we will not share personal data with third parties without explicit consent.</b></p>

    <p style="font-size:18px;"><b>Source of Models</b></p>
    <p><b>All models evaluated and displayed on this platform are obtained from official sources, including but not limited to official repositories and Replicate.</b></p>

    <p style="font-size:18px;"><b>Limitations of Liability</b></p>
    <p><b>The platform and its administrators do not assume any legal liability for the use or interpretation of the information provided. The evaluations and comparisons are for academic purposes. Users should verify the information independently and must not use the platform for any illegal, harmful, violent, racist, or sexual purposes.</b></p>

    <p style="font-size:18px;"><b>Modification of Terms</b></p>
    <p><b>We reserve the right to modify these terms at any time. Users will be notified of significant changes through updates on the platform.</b></p>

    <p style="font-size:18px;"><b>Contact Information</b></p>
    <p><b>For any questions or to report issues, please contact us at [email protected].</b></p>
    </div>
    <div id="overlay" style="display:none; position:fixed; top:0; left:0; width:100%; height:100%; background:rgba(0,0,0,0.5); z-index:999;" onclick="document.getElementById('modal').style.display='none'; document.getElementById('overlay').style.display='none'"></div>
    <p> This platform is designed for academic usage, for details please refer to <a href="#" id="open_link" onclick="document.getElementById('modal').style.display='block'; document.getElementById('overlay').style.display='block'">disclaimer</a>.</p>
    '''
    return disclaimer_md
93
+
94
+
95
def make_arena_leaderboard_data(results):
    """Convert *results* (a list of per-model score records) into a DataFrame
    suitable for gr.Dataframe.

    The redundant function-local `import pandas as pd` was removed: pandas is
    already imported as `pd` at module level.
    """
    return pd.DataFrame(results)
99
+
100
+
101
def build_leaderboard_tab(score_result_file = 'sorted_score_list.json'):
    """Render the text-to-image leaderboard tab from a precomputed score file.

    The JSON file is produced offline and must contain the keys
    'sorted_score_list', 'total_models', 'total_votes' and 'last_updated'.
    """
    with open(score_result_file, "r") as json_file:
        data = json.load(json_file)
    score_results = data["sorted_score_list"]
    total_models = data["total_models"]
    total_votes = data["total_votes"]
    last_updated = data["last_updated"]

    md = make_leaderboard_md()
    md_1 = gr.Markdown(md, elem_id="leaderboard_markdown")

    # with gr.Tab("Arena Score", id=0):
    # Totals line, then the score table itself.
    md = make_arena_leaderboard_md(total_models, total_votes, last_updated)
    gr.Markdown(md, elem_id="leaderboard_markdown")
    md = make_arena_leaderboard_data(score_results)
    gr.Dataframe(md)

    gr.Markdown(
        """
        - Note: When σ is large (we use the '*' labeling), it indicates that the model did not receive enough votes and its ranking is in the process of being updated.
        """,
        elem_id="sigma_note_markdown",
    )

    gr.Markdown(
        """ ### The leaderboard is regularly updated and continuously incorporates new models.
        """,
        elem_id="leaderboard_markdown",
    )
    with gr.Blocks():
        # NOTE(review): the function object is passed, not its return value
        # (`make_disclaimer_md` vs `make_disclaimer_md()`). Gradio accepts a
        # callable as a dynamic value, so this may be intentional — confirm.
        gr.HTML(make_disclaimer_md)
    from .utils import acknowledgment_md, html_code
    with gr.Blocks():
        gr.Markdown(acknowledgment_md)
135
+
136
+
137
def build_leaderboard_video_tab(score_result_file = 'sorted_score_list_video.json'):
    """Render the text-to-video leaderboard tab from a precomputed score file.

    Same file schema as build_leaderboard_tab: 'sorted_score_list',
    'total_models', 'total_votes', 'last_updated'. Adds a Sora-specific
    caveat because Sora entries were scored from its official sample videos.
    """
    with open(score_result_file, "r") as json_file:
        data = json.load(json_file)
    score_results = data["sorted_score_list"]
    total_models = data["total_models"]
    total_votes = data["total_votes"]
    last_updated = data["last_updated"]

    md = make_leaderboard_video_md()
    md_1 = gr.Markdown(md, elem_id="leaderboard_markdown")
    # with gr.Blocks():
    # gr.HTML(make_disclaimer_md)

    # with gr.Tab("Arena Score", id=0):
    # Totals line, then the score table itself.
    md = make_arena_leaderboard_md(total_models, total_votes, last_updated)
    gr.Markdown(md, elem_id="leaderboard_markdown")
    md = make_arena_leaderboard_data(score_results)
    gr.Dataframe(md)

    notice_markdown_sora = """
    - Note: When σ is large (we use the '*' labeling), it indicates that the model did not receive enough votes and its ranking is in the process of being updated.
    - Note: As Sora's video generation function is not publicly available, we used sample videos from their official website. This may lead to a biased assessment of Sora's capabilities, as these samples likely represent Sora's best outputs. Therefore, Sora's position on our leaderboard should be considered as its upper bound. We are working on methods to conduct more comprehensive and fair comparisons in the future.
    """

    gr.Markdown(notice_markdown_sora, elem_id="notice_markdown_sora")

    gr.Markdown(
        """ ### The leaderboard is regularly updated and continuously incorporates new models.
        """,
        elem_id="leaderboard_markdown",
    )
    from .utils import acknowledgment_md, html_code
    with gr.Blocks():
        gr.Markdown(acknowledgment_md)
171
+
172
+
173
def build_leaderboard_contributor(file = 'contributor.json'):
    """Render the contributor leaderboard tab.

    *file* must contain the keys 'contributor' (list of records for the
    table) and 'last_updated'.
    """
    with open(file, "r") as json_file:
        data = json.load(json_file)
    score_results = data["contributor"]
    last_updated = data["last_updated"]

    md = f"""
    # 🏆 Contributor Leaderboard
    The submission of user information is entirely optional. This information is used solely for contribution statistics. We respect and safeguard users' privacy choices.
    To maintain a clean and concise leaderboard, please ensure consistency in submitted names and affiliations. For example, use 'Berkeley' consistently rather than alternating with 'UC Berkeley'.
    - Votes*: Each image vote counts as one Vote*, while each video vote counts as two Votes* due to the increased effort involved.
    \n Last updated: {last_updated}
    """

    md_1 = gr.Markdown(md, elem_id="leaderboard_markdown")

    # md = make_arena_leaderboard_md(total_models, total_votes, last_updated)
    # gr.Markdown(md, elem_id="leaderboard_markdown")

    md = make_arena_leaderboard_data(score_results)
    gr.Dataframe(md)

    gr.Markdown(
        """ ### The leaderboard is regularly updated.
        """,
        elem_id="leaderboard_markdown",
    )
serve/log_server.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form, APIRouter
2
+ from typing import Optional
3
+ import json
4
+ import os
5
+ import aiofiles
6
+ from .log_utils import build_logger
7
+ from .constants import LOG_SERVER_SUBDOMAIN, APPEND_JSON, SAVE_IMAGE, SAVE_VIDEO, SAVE_LOG
8
+
9
# Local logger for this module; remote forwarding is disabled because this
# process is the log server itself (forwarding would loop back here).
logger = build_logger("log_server", "log_server.log", add_remote_handler=False)

# All endpoints below are registered relative to the log-server subdomain prefix.
app = APIRouter(prefix=LOG_SERVER_SUBDOMAIN)
12
+
13
@app.post(f"/{APPEND_JSON}")
async def append_json(json_str: str = Form(...), file_name: str = Form(...)):
    """Append one JSON object (sent as a form-encoded string) to a JSONL file."""
    # Round-trip through json to validate the payload before writing anything.
    record = json.loads(json_str)

    # Create the target directory on demand (dirname is "" for bare names).
    parent_dir = os.path.dirname(file_name)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    serialized = json.dumps(record) + "\n"
    async with aiofiles.open(file_name, mode='a') as out_file:
        await out_file.write(serialized)

    logger.info(f"Appended 1 JSON object to {file_name}")
    return {"message": "JSON data appended successfully"}
28
+
29
@app.post(f"/{SAVE_IMAGE}")
async def save_image(image: UploadFile = File(...), image_path: str = Form(...)):
    """Persist an uploaded image at the caller-provided destination path."""
    # 'image_path' must include the target file name and extension.
    target_dir = os.path.dirname(image_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # Pull the whole upload into memory, then write it out in binary mode.
    payload = await image.read()
    async with aiofiles.open(image_path, mode='wb') as out_file:
        await out_file.write(payload)

    logger.info(f"Image saved successfully at {image_path}")
    return {"message": f"Image saved successfully at {image_path}"}
42
+
43
@app.post(f"/{SAVE_VIDEO}")
async def save_video(video: UploadFile = File(...), video_path: str = Form(...)):
    """
    Saves an uploaded video to the specified path.

    Args:
        video: The uploaded video file.
        video_path: Destination path, including file name and extension.
    """
    # Note: 'video_path' should include the file name and extension for the video to be saved.
    if os.path.dirname(video_path):
        os.makedirs(os.path.dirname(video_path), exist_ok=True)
    async with aiofiles.open(video_path, mode='wb') as f:
        content = await video.read()  # Read the content of the uploaded video
        await f.write(content)  # Write the video content to a file
    logger.info(f"Video saved successfully at {video_path}")
    # Bug fix: the response previously said "Image saved successfully" for videos
    # (copy-paste from save_image).
    return {"message": f"Video saved successfully at {video_path}"}
56
+
57
@app.post(f"/{SAVE_LOG}")
async def save_log(message: str = Form(...), log_path: str = Form(...)):
    """
    Save a log message to a specified log file on the server.

    Args:
        message: The log line to append (a newline is added).
        log_path: Destination log file path on the server.
    """
    # Ensure the directory for the log file exists
    if os.path.dirname(log_path):
        os.makedirs(os.path.dirname(log_path), exist_ok=True)

    # Append the log message to the specified log file
    async with aiofiles.open(log_path, mode='a') as f:
        await f.write(f"{message}\n")

    # Typo fix: was "Romote log message ...".
    logger.info(f"Remote log message saved to {log_path}")
    return {"message": f"Log message saved successfully to {log_path}"}
72
+
73
+
74
@app.get("/read_file")  # was f"/read_file": f-string with no placeholders
async def read_file(file_name: str):
    """
    Reads the content of a specified file and returns it.

    SECURITY NOTE(review): 'file_name' comes straight from the query string
    and is not confined to any base directory, so this endpoint can read any
    file the process can access (path traversal). Confirm this router is only
    reachable by trusted callers, or restrict reads to a whitelisted root.
    """
    if not os.path.exists(file_name):
        return {"message": f"File {file_name} does not exist."}

    async with aiofiles.open(file_name, mode='r') as f:
        content = await f.read()

    logger.info(f"Read file {file_name}")
    return {"file_name": file_name, "content": content}
serve/log_utils.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Common utilities.
3
+ """
4
+ from asyncio import AbstractEventLoop
5
+ import json
6
+ import logging
7
+ import logging.handlers
8
+ import os
9
+ import platform
10
+ import sys
11
+ from typing import AsyncGenerator, Generator
12
+ import warnings
13
+ from pathlib import Path
14
+
15
+ import requests
16
+
17
+ from .constants import LOGDIR, LOG_SERVER_ADDR, SAVE_LOG
18
+ from .utils import save_log_str_on_log_server
19
+
20
+
21
# Shared TimedRotatingFileHandler, created by build_logger on first use
# (stays None when LOGDIR is empty and file logging is disabled).
handler = None
# Loggers that already received the shared file handler, so build_logger
# does not attach it twice.
visited_loggers = set()


# NOTE: LOGDIR, LOG_SERVER_ADDR and SAVE_LOG come from .constants (imported above).
26
+
27
class APIHandler(logging.Handler):
    """Custom logging handler that forwards formatted records to the log-server API.

    Attributes:
        apiUrl: Endpoint URL of the remote log server (stored for reference;
            the request itself is issued by save_log_str_on_log_server).
        log_path: Server-side file the forwarded messages are appended to.
    """

    def __init__(self, apiUrl, log_path, *args, **kwargs):
        # Modernized from the legacy super(APIHandler, self).__init__ form.
        super().__init__(*args, **kwargs)
        self.apiUrl = apiUrl
        self.log_path = log_path

    def emit(self, record):
        """Format the record and push it to the log server without raising."""
        log_entry = self.format(record)
        try:
            save_log_str_on_log_server(log_entry, self.log_path)
        except requests.RequestException as e:
            # Logging must never crash the app; report transport failures to stderr.
            print(f"Error sending log to API: {e}", file=sys.stderr)
41
+
42
def build_logger(logger_name, logger_filename, add_remote_handler=False):
    """Create a logger writing to the console and a daily-rotating local file.

    Module-wide side effects: on first call it configures the root logger and
    permanently redirects sys.stdout / sys.stderr through the "stdout" and
    "stderr" loggers, so plain print() output is captured as well.

    Args:
        logger_name: Name of the logger to create or fetch.
        logger_filename: File name for the local log under LOGDIR.
        add_remote_handler: If True, also mirror records (including the
            stdout/stderr loggers) to the remote log-server API.

    Returns:
        The configured logging.Logger instance.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers (only on the very first call).
    if not logging.getLogger().handlers:
        # Bug fix: compare the full version tuple instead of only the minor
        # component (sys.version_info[1] >= 9), which would misbehave for any
        # major version other than 3.
        if sys.version_info >= (3, 9):
            # basicConfig(encoding=...) needs Python >= 3.9; UTF-8 mainly
            # matters on Windows.
            logging.basicConfig(level=logging.INFO, encoding="utf-8")
        else:
            if platform.system() == "Windows":
                warnings.warn(
                    "If you are running on Windows, "
                    "we recommend you use Python >= 3.9 for UTF-8 encoding."
                )
            logging.basicConfig(level=logging.INFO)
        logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers so print()s land in the logs too.
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    if add_remote_handler:
        # Add APIHandler to forward records to the remote log server.
        api_url = f"{LOG_SERVER_ADDR}/{SAVE_LOG}"

        remote_logger_filename = str(Path(logger_filename).stem + "_remote.log")
        api_handler = APIHandler(apiUrl=api_url, log_path=f"{LOGDIR}/{remote_logger_filename}")
        api_handler.setFormatter(formatter)
        logger.addHandler(api_handler)

        stdout_logger.addHandler(api_handler)
        stderr_logger.addHandler(api_handler)

    # If LOGDIR is empty, don't write logs to a local file.
    if LOGDIR != "":
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        # Daily-rotating (when="D"), UTC-stamped file handler shared by all
        # loggers built by this function (kept in the module-level `handler`).
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when="D", utc=True, encoding="utf-8"
        )
        handler.setFormatter(formatter)

        for l in [stdout_logger, stderr_logger, logger]:
            if l in visited_loggers:
                continue
            visited_loggers.add(l)
            l.addHandler(handler)

    return logger
107
+
108
+
109
class StreamToLogger(object):
    """
    File-like object that forwards everything written to it to a logger.

    Complete lines are logged immediately; a trailing partial line is kept in
    an internal buffer until the next write() completes it or flush() is called.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""

    def __getattr__(self, attr):
        # Delegate anything we don't implement (encoding, isatty, ...) to the
        # real stdout captured at construction time.
        return getattr(self.terminal, attr)

    def write(self, buf):
        pending = self.linebuf + buf
        self.linebuf = ""
        # splitlines(True) keeps the terminators, which lets us tell finished
        # lines apart from a trailing fragment that must stay buffered.
        for piece in pending.splitlines(True):
            if piece.endswith("\n"):
                clean = piece.encode("utf-8", "ignore").decode("utf-8")
                self.logger.log(self.log_level, clean.rstrip())
            else:
                self.linebuf += piece

    def flush(self):
        # Emit whatever partial line is still buffered.
        if self.linebuf:
            clean = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
            self.logger.log(self.log_level, clean.rstrip())
            self.linebuf = ""