Spaces:
Runtime error
Runtime error
Bbmyy
commited on
Commit
•
850b1ec
1
Parent(s):
b42403b
Update space
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- model/__init__.py +0 -0
- model/__pycache__/__init__.cpython-310.pyc +0 -0
- model/__pycache__/__init__.cpython-39.pyc +0 -0
- model/__pycache__/matchmaker.cpython-310.pyc +0 -0
- model/__pycache__/model_manager.cpython-310.pyc +0 -0
- model/__pycache__/model_registry.cpython-310.pyc +0 -0
- model/__pycache__/model_registry.cpython-39.pyc +0 -0
- model/matchmaker.py +126 -0
- model/matchmaker_video.py +136 -0
- model/model_manager.py +187 -0
- model/model_registry.py +70 -0
- model/models/__init__.py +78 -0
- model/models/__pycache__/__init__.cpython-310.pyc +0 -0
- model/models/__pycache__/huggingface_models.cpython-310.pyc +0 -0
- model/models/__pycache__/openai_api_models.cpython-310.pyc +0 -0
- model/models/__pycache__/other_api_models.cpython-310.pyc +0 -0
- model/models/__pycache__/replicate_api_models.cpython-310.pyc +0 -0
- model/models/huggingface_models.py +59 -0
- model/models/openai_api_models.py +57 -0
- model/models/other_api_models.py +91 -0
- model/models/replicate_api_models.py +195 -0
- serve/Arial.ttf +0 -0
- serve/Ksort.py +411 -0
- serve/__init__.py +0 -0
- serve/__pycache__/Ksort.cpython-310.pyc +0 -0
- serve/__pycache__/Ksort.cpython-39.pyc +0 -0
- serve/__pycache__/__init__.cpython-310.pyc +0 -0
- serve/__pycache__/__init__.cpython-39.pyc +0 -0
- serve/__pycache__/constants.cpython-310.pyc +0 -0
- serve/__pycache__/constants.cpython-39.pyc +0 -0
- serve/__pycache__/gradio_web.cpython-310.pyc +0 -0
- serve/__pycache__/gradio_web.cpython-39.pyc +0 -0
- serve/__pycache__/gradio_web_bbox.cpython-310.pyc +0 -0
- serve/__pycache__/leaderboard.cpython-310.pyc +0 -0
- serve/__pycache__/log_utils.cpython-310.pyc +0 -0
- serve/__pycache__/log_utils.cpython-39.pyc +0 -0
- serve/__pycache__/update_skill.cpython-310.pyc +0 -0
- serve/__pycache__/upload.cpython-310.pyc +0 -0
- serve/__pycache__/upload.cpython-39.pyc +0 -0
- serve/__pycache__/utils.cpython-310.pyc +0 -0
- serve/__pycache__/utils.cpython-39.pyc +0 -0
- serve/__pycache__/vote_utils.cpython-310.pyc +0 -0
- serve/__pycache__/vote_utils.cpython-39.pyc +0 -0
- serve/button.css +24 -0
- serve/constants.py +63 -0
- serve/gradio_web.py +789 -0
- serve/gradio_web_bbox.py +492 -0
- serve/leaderboard.py +200 -0
- serve/log_server.py +86 -0
- serve/log_utils.py +142 -0
model/__init__.py
ADDED
File without changes
|
model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (157 Bytes). View file
|
|
model/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (146 Bytes). View file
|
|
model/__pycache__/matchmaker.cpython-310.pyc
ADDED
Binary file (4.04 kB). View file
|
|
model/__pycache__/model_manager.cpython-310.pyc
ADDED
Binary file (8.28 kB). View file
|
|
model/__pycache__/model_registry.cpython-310.pyc
ADDED
Binary file (1.64 kB). View file
|
|
model/__pycache__/model_registry.cpython-39.pyc
ADDED
Binary file (1.81 kB). View file
|
|
model/matchmaker.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import json
|
3 |
+
from trueskill import TrueSkill
|
4 |
+
import paramiko
|
5 |
+
import io, os
|
6 |
+
import sys
|
7 |
+
import random
|
8 |
+
|
9 |
+
sys.path.append('../')
|
10 |
+
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL
|
11 |
+
trueskill_env = TrueSkill()
|
12 |
+
|
13 |
+
ssh_matchmaker_client = None
|
14 |
+
sftp_matchmaker_client = None
|
15 |
+
|
16 |
+
def create_ssh_matchmaker_client(server, port, user, password):
|
17 |
+
global ssh_matchmaker_client, sftp_matchmaker_client
|
18 |
+
ssh_matchmaker_client = paramiko.SSHClient()
|
19 |
+
ssh_matchmaker_client.load_system_host_keys()
|
20 |
+
ssh_matchmaker_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
21 |
+
ssh_matchmaker_client.connect(server, port, user, password)
|
22 |
+
|
23 |
+
transport = ssh_matchmaker_client.get_transport()
|
24 |
+
transport.set_keepalive(60)
|
25 |
+
|
26 |
+
sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
|
27 |
+
|
28 |
+
|
29 |
+
def is_connected():
|
30 |
+
global ssh_matchmaker_client, sftp_matchmaker_client
|
31 |
+
if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
|
32 |
+
return False
|
33 |
+
if not ssh_matchmaker_client.get_transport().is_active():
|
34 |
+
return False
|
35 |
+
try:
|
36 |
+
sftp_matchmaker_client.listdir('.')
|
37 |
+
except Exception as e:
|
38 |
+
print(f"Error checking SFTP connection: {e}")
|
39 |
+
return False
|
40 |
+
return True
|
41 |
+
|
42 |
+
|
43 |
+
def ucb_score(trueskill_diff, t, n):
|
44 |
+
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
45 |
+
ucb = -trueskill_diff + 1.0 * exploration_term
|
46 |
+
return ucb
|
47 |
+
|
48 |
+
|
49 |
+
def update_trueskill(ratings, ranks):
|
50 |
+
new_ratings = trueskill_env.rate(ratings, ranks)
|
51 |
+
return new_ratings
|
52 |
+
|
53 |
+
|
54 |
+
def serialize_rating(rating):
|
55 |
+
return {'mu': rating.mu, 'sigma': rating.sigma}
|
56 |
+
|
57 |
+
|
58 |
+
def deserialize_rating(rating_dict):
|
59 |
+
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
60 |
+
|
61 |
+
|
62 |
+
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
63 |
+
global sftp_matchmaker_client
|
64 |
+
if not is_connected():
|
65 |
+
create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
66 |
+
data = {
|
67 |
+
'ratings': [serialize_rating(r) for r in ratings],
|
68 |
+
'comparison_counts': comparison_counts.tolist(),
|
69 |
+
'total_comparisons': total_comparisons
|
70 |
+
}
|
71 |
+
json_data = json.dumps(data)
|
72 |
+
with sftp_matchmaker_client.open(SSH_SKILL, 'w') as f:
|
73 |
+
f.write(json_data)
|
74 |
+
|
75 |
+
|
76 |
+
def load_json_via_sftp():
|
77 |
+
global sftp_matchmaker_client
|
78 |
+
if not is_connected():
|
79 |
+
create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
80 |
+
with sftp_matchmaker_client.open(SSH_SKILL, 'r') as f:
|
81 |
+
data = json.load(f)
|
82 |
+
ratings = [deserialize_rating(r) for r in data['ratings']]
|
83 |
+
comparison_counts = np.array(data['comparison_counts'])
|
84 |
+
total_comparisons = data['total_comparisons']
|
85 |
+
return ratings, comparison_counts, total_comparisons
|
86 |
+
|
87 |
+
|
88 |
+
class RunningPivot(object):
|
89 |
+
running_pivot = []
|
90 |
+
|
91 |
+
|
92 |
+
def matchmaker(num_players, k_group=4, not_run=[]):
|
93 |
+
trueskill_env = TrueSkill()
|
94 |
+
|
95 |
+
ratings, comparison_counts, total_comparisons = load_json_via_sftp()
|
96 |
+
|
97 |
+
ratings = ratings[:num_players]
|
98 |
+
comparison_counts = comparison_counts[:num_players, :num_players]
|
99 |
+
|
100 |
+
# Randomly select a player
|
101 |
+
# selected_player = np.random.randint(0, num_players)
|
102 |
+
comparison_counts[RunningPivot.running_pivot, :] = float('inf')
|
103 |
+
comparison_counts[not_run, :] = float('inf')
|
104 |
+
selected_player = np.argmin(comparison_counts.sum(axis=1))
|
105 |
+
|
106 |
+
RunningPivot.running_pivot.append(selected_player)
|
107 |
+
RunningPivot.running_pivot = RunningPivot.running_pivot[-5:]
|
108 |
+
print(RunningPivot.running_pivot)
|
109 |
+
|
110 |
+
selected_trueskill_score = trueskill_env.expose(ratings[selected_player])
|
111 |
+
trueskill_scores = np.array([trueskill_env.expose(p) for p in ratings])
|
112 |
+
trueskill_diff = np.abs(trueskill_scores - selected_trueskill_score)
|
113 |
+
n = comparison_counts[selected_player]
|
114 |
+
ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
|
115 |
+
|
116 |
+
# Exclude self, select opponent with highest UCB score
|
117 |
+
ucb_scores[selected_player] = -float('inf')
|
118 |
+
ucb_scores[not_run] = -float('inf')
|
119 |
+
opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist()
|
120 |
+
|
121 |
+
# Group players
|
122 |
+
model_ids = [selected_player] + opponents
|
123 |
+
|
124 |
+
random.shuffle(model_ids)
|
125 |
+
|
126 |
+
return model_ids
|
model/matchmaker_video.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import json
|
3 |
+
from trueskill import TrueSkill
|
4 |
+
import paramiko
|
5 |
+
import io, os
|
6 |
+
import sys
|
7 |
+
import random
|
8 |
+
|
9 |
+
sys.path.append('../')
|
10 |
+
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_VIDEO_SKILL
|
11 |
+
trueskill_env = TrueSkill()
|
12 |
+
|
13 |
+
ssh_matchmaker_client = None
|
14 |
+
sftp_matchmaker_client = None
|
15 |
+
|
16 |
+
|
17 |
+
def create_ssh_matchmaker_client(server, port, user, password):
|
18 |
+
global ssh_matchmaker_client, sftp_matchmaker_client
|
19 |
+
ssh_matchmaker_client = paramiko.SSHClient()
|
20 |
+
ssh_matchmaker_client.load_system_host_keys()
|
21 |
+
ssh_matchmaker_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
22 |
+
ssh_matchmaker_client.connect(server, port, user, password)
|
23 |
+
|
24 |
+
transport = ssh_matchmaker_client.get_transport()
|
25 |
+
transport.set_keepalive(60)
|
26 |
+
|
27 |
+
sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
|
28 |
+
|
29 |
+
|
30 |
+
def is_connected():
|
31 |
+
global ssh_matchmaker_client, sftp_matchmaker_client
|
32 |
+
if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
|
33 |
+
return False
|
34 |
+
if not ssh_matchmaker_client.get_transport().is_active():
|
35 |
+
return False
|
36 |
+
try:
|
37 |
+
sftp_matchmaker_client.listdir('.')
|
38 |
+
except Exception as e:
|
39 |
+
print(f"Error checking SFTP connection: {e}")
|
40 |
+
return False
|
41 |
+
return True
|
42 |
+
|
43 |
+
|
44 |
+
def ucb_score(trueskill_diff, t, n):
|
45 |
+
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
46 |
+
ucb = -trueskill_diff + 1.0 * exploration_term
|
47 |
+
return ucb
|
48 |
+
|
49 |
+
|
50 |
+
def update_trueskill(ratings, ranks):
|
51 |
+
new_ratings = trueskill_env.rate(ratings, ranks)
|
52 |
+
return new_ratings
|
53 |
+
|
54 |
+
|
55 |
+
def serialize_rating(rating):
|
56 |
+
return {'mu': rating.mu, 'sigma': rating.sigma}
|
57 |
+
|
58 |
+
|
59 |
+
def deserialize_rating(rating_dict):
|
60 |
+
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
61 |
+
|
62 |
+
|
63 |
+
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
64 |
+
global sftp_matchmaker_client
|
65 |
+
if not is_connected():
|
66 |
+
create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
67 |
+
data = {
|
68 |
+
'ratings': [serialize_rating(r) for r in ratings],
|
69 |
+
'comparison_counts': comparison_counts.tolist(),
|
70 |
+
'total_comparisons': total_comparisons
|
71 |
+
}
|
72 |
+
json_data = json.dumps(data)
|
73 |
+
with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'w') as f:
|
74 |
+
f.write(json_data)
|
75 |
+
|
76 |
+
|
77 |
+
def load_json_via_sftp():
|
78 |
+
global sftp_matchmaker_client
|
79 |
+
if not is_connected():
|
80 |
+
create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
81 |
+
with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'r') as f:
|
82 |
+
data = json.load(f)
|
83 |
+
ratings = [deserialize_rating(r) for r in data['ratings']]
|
84 |
+
comparison_counts = np.array(data['comparison_counts'])
|
85 |
+
total_comparisons = data['total_comparisons']
|
86 |
+
return ratings, comparison_counts, total_comparisons
|
87 |
+
|
88 |
+
|
89 |
+
def matchmaker_video(num_players, k_group=4):
|
90 |
+
trueskill_env = TrueSkill()
|
91 |
+
|
92 |
+
ratings, comparison_counts, total_comparisons = load_json_via_sftp()
|
93 |
+
|
94 |
+
ratings = ratings[:num_players]
|
95 |
+
comparison_counts = comparison_counts[:num_players, :num_players]
|
96 |
+
|
97 |
+
selected_player = np.argmin(comparison_counts.sum(axis=1))
|
98 |
+
|
99 |
+
selected_trueskill_score = trueskill_env.expose(ratings[selected_player])
|
100 |
+
trueskill_scores = np.array([trueskill_env.expose(p) for p in ratings])
|
101 |
+
trueskill_diff = np.abs(trueskill_scores - selected_trueskill_score)
|
102 |
+
n = comparison_counts[selected_player]
|
103 |
+
ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
|
104 |
+
|
105 |
+
# Exclude self, select opponent with highest UCB score
|
106 |
+
ucb_scores[selected_player] = -float('inf')
|
107 |
+
|
108 |
+
excluded_players_1 = [7, 10]
|
109 |
+
excluded_players_2 = [6, 8, 9]
|
110 |
+
excluded_players = excluded_players_1 + excluded_players_2
|
111 |
+
if selected_player in excluded_players_1:
|
112 |
+
for player in excluded_players:
|
113 |
+
ucb_scores[player] = -float('inf')
|
114 |
+
if selected_player in excluded_players_2:
|
115 |
+
for player in excluded_players_1:
|
116 |
+
ucb_scores[player] = -float('inf')
|
117 |
+
else:
|
118 |
+
excluded_ucb_scores = {player: ucb_scores[player] for player in excluded_players}
|
119 |
+
max_player = max(excluded_ucb_scores, key=excluded_ucb_scores.get)
|
120 |
+
if max_player in excluded_players_1:
|
121 |
+
for player in excluded_players:
|
122 |
+
if player != max_player:
|
123 |
+
ucb_scores[player] = -float('inf')
|
124 |
+
else:
|
125 |
+
for player in excluded_players_1:
|
126 |
+
ucb_scores[player] = -float('inf')
|
127 |
+
|
128 |
+
|
129 |
+
opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist()
|
130 |
+
|
131 |
+
# Group players
|
132 |
+
model_ids = [selected_player] + opponents
|
133 |
+
|
134 |
+
random.shuffle(model_ids)
|
135 |
+
|
136 |
+
return model_ids
|
model/model_manager.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import concurrent.futures
|
2 |
+
import random
|
3 |
+
import gradio as gr
|
4 |
+
import requests, os
|
5 |
+
import io, base64, json
|
6 |
+
import spaces
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
from openai import OpenAI
|
10 |
+
from .models import IMAGE_GENERATION_MODELS, VIDEO_GENERATION_MODELS, load_pipeline
|
11 |
+
from serve.upload import get_random_mscoco_prompt, get_random_video_prompt, get_ssh_random_video_prompt, get_ssh_random_image_prompt
|
12 |
+
from serve.constants import SSH_CACHE_OPENSOURCE, SSH_CACHE_ADVANCE, SSH_CACHE_PIKA, SSH_CACHE_SORA, SSH_CACHE_IMAGE
|
13 |
+
|
14 |
+
|
15 |
+
class ModelManager:
|
16 |
+
def __init__(self):
|
17 |
+
self.model_ig_list = IMAGE_GENERATION_MODELS
|
18 |
+
self.model_ie_list = [] #IMAGE_EDITION_MODELS
|
19 |
+
self.model_vg_list = VIDEO_GENERATION_MODELS
|
20 |
+
self.model_b2i_list = []
|
21 |
+
self.loaded_models = {}
|
22 |
+
|
23 |
+
def load_model_pipe(self, model_name):
|
24 |
+
if not model_name in self.loaded_models:
|
25 |
+
pipe = load_pipeline(model_name)
|
26 |
+
self.loaded_models[model_name] = pipe
|
27 |
+
else:
|
28 |
+
pipe = self.loaded_models[model_name]
|
29 |
+
return pipe
|
30 |
+
|
31 |
+
@spaces.GPU(duration=120)
|
32 |
+
def generate_image_ig(self, prompt, model_name):
|
33 |
+
pipe = self.load_model_pipe(model_name)
|
34 |
+
if 'Stable-cascade' not in model_name:
|
35 |
+
result = pipe(prompt=prompt).images[0]
|
36 |
+
else:
|
37 |
+
prior, decoder = pipe
|
38 |
+
prior.enable_model_cpu_offload()
|
39 |
+
prior_output = prior(
|
40 |
+
prompt=prompt,
|
41 |
+
height=512,
|
42 |
+
width=512,
|
43 |
+
negative_prompt='',
|
44 |
+
guidance_scale=4.0,
|
45 |
+
num_images_per_prompt=1,
|
46 |
+
num_inference_steps=20
|
47 |
+
)
|
48 |
+
decoder.enable_model_cpu_offload()
|
49 |
+
result = decoder(
|
50 |
+
image_embeddings=prior_output.image_embeddings.to(torch.float16),
|
51 |
+
prompt=prompt,
|
52 |
+
negative_prompt='',
|
53 |
+
guidance_scale=0.0,
|
54 |
+
output_type="pil",
|
55 |
+
num_inference_steps=10
|
56 |
+
).images[0]
|
57 |
+
return result
|
58 |
+
|
59 |
+
def generate_image_ig_api(self, prompt, model_name):
|
60 |
+
pipe = self.load_model_pipe(model_name)
|
61 |
+
result = pipe(prompt=prompt)
|
62 |
+
return result
|
63 |
+
|
64 |
+
def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D):
|
65 |
+
if model_A == "" and model_B == "" and model_C == "" and model_D == "":
|
66 |
+
from .matchmaker import matchmaker
|
67 |
+
not_run = [20,21,22, 25,26, 30] #12,13,14,15,16,17,18,19,20,21,22, #23,24,
|
68 |
+
model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
|
69 |
+
print(model_ids)
|
70 |
+
model_names = [self.model_ig_list[i] for i in model_ids]
|
71 |
+
print(model_names)
|
72 |
+
else:
|
73 |
+
model_names = [model_A, model_B, model_C, model_D]
|
74 |
+
|
75 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
76 |
+
futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
|
77 |
+
else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
|
78 |
+
results = [future.result() for future in futures]
|
79 |
+
|
80 |
+
return results[0], results[1], results[2], results[3], \
|
81 |
+
model_names[0], model_names[1], model_names[2], model_names[3]
|
82 |
+
|
83 |
+
def generate_image_ig_cache_anony(self, model_A, model_B, model_C, model_D):
|
84 |
+
if model_A == "" and model_B == "" and model_C == "" and model_D == "":
|
85 |
+
from .matchmaker import matchmaker
|
86 |
+
not_run = [20,21,22]
|
87 |
+
model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
|
88 |
+
print(model_ids)
|
89 |
+
model_names = [self.model_ig_list[i] for i in model_ids]
|
90 |
+
print(model_names)
|
91 |
+
else:
|
92 |
+
model_names = [model_A, model_B, model_C, model_D]
|
93 |
+
|
94 |
+
root_dir = SSH_CACHE_IMAGE
|
95 |
+
local_dir = "./cache_image"
|
96 |
+
if not os.path.exists(local_dir):
|
97 |
+
os.makedirs(local_dir)
|
98 |
+
prompt, results = get_ssh_random_image_prompt(root_dir, local_dir, model_names)
|
99 |
+
|
100 |
+
return results[0], results[1], results[2], results[3], \
|
101 |
+
model_names[0], model_names[1], model_names[2], model_names[3], prompt
|
102 |
+
|
103 |
+
def generate_video_vg_parallel_anony(self, model_A, model_B, model_C, model_D):
|
104 |
+
if model_A == "" and model_B == "" and model_C == "" and model_D == "":
|
105 |
+
# model_names = random.sample([model for model in self.model_vg_list], 4)
|
106 |
+
|
107 |
+
from .matchmaker_video import matchmaker_video
|
108 |
+
model_ids = matchmaker_video(num_players=len(self.model_vg_list))
|
109 |
+
print(model_ids)
|
110 |
+
model_names = [self.model_vg_list[i] for i in model_ids]
|
111 |
+
print(model_names)
|
112 |
+
else:
|
113 |
+
model_names = [model_A, model_B, model_C, model_D]
|
114 |
+
|
115 |
+
root_dir = SSH_CACHE_OPENSOURCE
|
116 |
+
for name in model_names:
|
117 |
+
if "Runway-Gen3" in name or "Runway-Gen2" in name or "Pika-v1.0" in name:
|
118 |
+
root_dir = SSH_CACHE_ADVANCE
|
119 |
+
elif "Pika-beta" in name:
|
120 |
+
root_dir = SSH_CACHE_PIKA
|
121 |
+
elif "Sora" in name and "OpenSora" not in name:
|
122 |
+
root_dir = SSH_CACHE_SORA
|
123 |
+
|
124 |
+
local_dir = "./cache_video"
|
125 |
+
if not os.path.exists(local_dir):
|
126 |
+
os.makedirs(local_dir)
|
127 |
+
prompt, results = get_ssh_random_video_prompt(root_dir, local_dir, model_names)
|
128 |
+
cache_dir = local_dir
|
129 |
+
|
130 |
+
return results[0], results[1], results[2], results[3], \
|
131 |
+
model_names[0], model_names[1], model_names[2], model_names[3], prompt, cache_dir
|
132 |
+
|
133 |
+
def generate_image_ig_museum_parallel_anony(self, model_A, model_B, model_C, model_D):
|
134 |
+
if model_A == "" and model_B == "" and model_C == "" and model_D == "":
|
135 |
+
# model_names = random.sample([model for model in self.model_ig_list], 4)
|
136 |
+
|
137 |
+
from .matchmaker import matchmaker
|
138 |
+
model_ids = matchmaker(num_players=len(self.model_ig_list))
|
139 |
+
print(model_ids)
|
140 |
+
model_names = [self.model_ig_list[i] for i in model_ids]
|
141 |
+
print(model_names)
|
142 |
+
else:
|
143 |
+
model_names = [model_A, model_B, model_C, model_D]
|
144 |
+
|
145 |
+
prompt = get_random_mscoco_prompt()
|
146 |
+
print(prompt)
|
147 |
+
|
148 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
149 |
+
futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
|
150 |
+
else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
|
151 |
+
results = [future.result() for future in futures]
|
152 |
+
|
153 |
+
return results[0], results[1], results[2], results[3], \
|
154 |
+
model_names[0], model_names[1], model_names[2], model_names[3], prompt
|
155 |
+
|
156 |
+
def generate_image_ig_parallel(self, prompt, model_A, model_B):
|
157 |
+
model_names = [model_A, model_B]
|
158 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
159 |
+
futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("imagenhub")
|
160 |
+
else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
|
161 |
+
results = [future.result() for future in futures]
|
162 |
+
return results[0], results[1]
|
163 |
+
|
164 |
+
@spaces.GPU(duration=200)
|
165 |
+
def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
|
166 |
+
pipe = self.load_model_pipe(model_name)
|
167 |
+
result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
|
168 |
+
return result
|
169 |
+
|
170 |
+
def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
|
171 |
+
model_names = [model_A, model_B]
|
172 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
173 |
+
futures = [
|
174 |
+
executor.submit(self.generate_image_ie, textbox_source, textbox_target, textbox_instruct, source_image,
|
175 |
+
model) for model in model_names]
|
176 |
+
results = [future.result() for future in futures]
|
177 |
+
return results[0], results[1]
|
178 |
+
|
179 |
+
def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
|
180 |
+
if model_A == "" and model_B == "":
|
181 |
+
model_names = random.sample([model for model in self.model_ie_list], 2)
|
182 |
+
else:
|
183 |
+
model_names = [model_A, model_B]
|
184 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
185 |
+
futures = [executor.submit(self.generate_image_ie, textbox_source, textbox_target, textbox_instruct, source_image, model) for model in model_names]
|
186 |
+
results = [future.result() for future in futures]
|
187 |
+
return results[0], results[1], model_names[0], model_names[1]
|
model/model_registry.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"])
|
5 |
+
model_info = {}
|
6 |
+
|
7 |
+
def register_model_info(
|
8 |
+
full_names: List[str], simple_name: str, link: str, description: str
|
9 |
+
):
|
10 |
+
info = ModelInfo(simple_name, link, description)
|
11 |
+
|
12 |
+
for full_name in full_names:
|
13 |
+
model_info[full_name] = info
|
14 |
+
|
15 |
+
def get_model_info(name: str) -> ModelInfo:
|
16 |
+
if name in model_info:
|
17 |
+
return model_info[name]
|
18 |
+
else:
|
19 |
+
# To fix this, please use `register_model_info` to register your model
|
20 |
+
return ModelInfo(
|
21 |
+
name, "", "Register the description at fastchat/model/model_registry.py"
|
22 |
+
)
|
23 |
+
|
24 |
+
def get_model_description_md(model_list):
|
25 |
+
model_description_md = """
|
26 |
+
| | | | | | | | | | | |
|
27 |
+
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
|
28 |
+
"""
|
29 |
+
ct = 0
|
30 |
+
visited = set()
|
31 |
+
for i, name in enumerate(model_list):
|
32 |
+
model_source, model_name, model_type = name.split("_")
|
33 |
+
minfo = get_model_info(model_name)
|
34 |
+
if minfo.simple_name in visited:
|
35 |
+
continue
|
36 |
+
visited.add(minfo.simple_name)
|
37 |
+
# one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
|
38 |
+
one_model_md = f"{minfo.simple_name}"
|
39 |
+
|
40 |
+
if ct % 11 == 0:
|
41 |
+
model_description_md += "|"
|
42 |
+
model_description_md += f" {one_model_md} |"
|
43 |
+
if ct % 11 == 10:
|
44 |
+
model_description_md += "\n"
|
45 |
+
ct += 1
|
46 |
+
return model_description_md
|
47 |
+
|
48 |
+
def get_video_model_description_md(model_list):
|
49 |
+
model_description_md = """
|
50 |
+
| | | | | | |
|
51 |
+
| ---- | ---- | ---- | ---- | ---- | ---- |
|
52 |
+
"""
|
53 |
+
ct = 0
|
54 |
+
visited = set()
|
55 |
+
for i, name in enumerate(model_list):
|
56 |
+
model_source, model_name, model_type = name.split("_")
|
57 |
+
minfo = get_model_info(model_name)
|
58 |
+
if minfo.simple_name in visited:
|
59 |
+
continue
|
60 |
+
visited.add(minfo.simple_name)
|
61 |
+
# one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
|
62 |
+
one_model_md = f"{minfo.simple_name}"
|
63 |
+
|
64 |
+
if ct % 7 == 0:
|
65 |
+
model_description_md += "|"
|
66 |
+
model_description_md += f" {one_model_md} |"
|
67 |
+
if ct % 7 == 6:
|
68 |
+
model_description_md += "\n"
|
69 |
+
ct += 1
|
70 |
+
return model_description_md
|
model/models/__init__.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .huggingface_models import load_huggingface_model
|
2 |
+
from .replicate_api_models import load_replicate_model
|
3 |
+
from .openai_api_models import load_openai_model
|
4 |
+
from .other_api_models import load_other_model
|
5 |
+
|
6 |
+
|
7 |
+
IMAGE_GENERATION_MODELS = [
|
8 |
+
'replicate_SDXL_text2image',
|
9 |
+
'replicate_SD-v3.0_text2image',
|
10 |
+
'replicate_SD-v2.1_text2image',
|
11 |
+
'replicate_SD-v1.5_text2image',
|
12 |
+
'replicate_SDXL-Lightning_text2image',
|
13 |
+
'replicate_Kandinsky-v2.0_text2image',
|
14 |
+
'replicate_Kandinsky-v2.2_text2image',
|
15 |
+
'replicate_Proteus-v0.2_text2image',
|
16 |
+
'replicate_Playground-v2.0_text2image',
|
17 |
+
'replicate_Playground-v2.5_text2image',
|
18 |
+
'replicate_Dreamshaper-xl-turbo_text2image',
|
19 |
+
'replicate_SDXL-Deepcache_text2image',
|
20 |
+
'replicate_Openjourney-v4_text2image',
|
21 |
+
'replicate_LCM-v1.5_text2image',
|
22 |
+
'replicate_Realvisxl-v3.0_text2image',
|
23 |
+
'replicate_Realvisxl-v2.0_text2image',
|
24 |
+
'replicate_Pixart-Sigma_text2image',
|
25 |
+
'replicate_SSD-1b_text2image',
|
26 |
+
'replicate_Open-Dalle-v1.1_text2image',
|
27 |
+
'replicate_Deepfloyd-IF_text2image',
|
28 |
+
'huggingface_SD-turbo_text2image',
|
29 |
+
'huggingface_SDXL-turbo_text2image',
|
30 |
+
'huggingface_Stable-cascade_text2image',
|
31 |
+
'openai_Dalle-2_text2image',
|
32 |
+
'openai_Dalle-3_text2image',
|
33 |
+
'other_Midjourney-v6.0_text2image',
|
34 |
+
'other_Midjourney-v5.0_text2image',
|
35 |
+
"replicate_FLUX.1-schnell_text2image",
|
36 |
+
"replicate_FLUX.1-pro_text2image",
|
37 |
+
"replicate_FLUX.1-dev_text2image",
|
38 |
+
'other_Meissonic_text2image',
|
39 |
+
"replicate_FLUX-1.1-pro_text2image",
|
40 |
+
'replicate_SD-v3.5-large_text2image',
|
41 |
+
'replicate_SD-v3.5-large-turbo_text2image',
|
42 |
+
]
|
43 |
+
|
44 |
+
VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video',
|
45 |
+
'replicate_Animate-Diff_text2video',
|
46 |
+
'replicate_OpenSora_text2video',
|
47 |
+
'replicate_LaVie_text2video',
|
48 |
+
'replicate_VideoCrafter2_text2video',
|
49 |
+
'replicate_Stable-Video-Diffusion_text2video',
|
50 |
+
'other_Runway-Gen3_text2video',
|
51 |
+
'other_Pika-beta_text2video',
|
52 |
+
'other_Pika-v1.0_text2video',
|
53 |
+
'other_Runway-Gen2_text2video',
|
54 |
+
'other_Sora_text2video',
|
55 |
+
'replicate_Cogvideox-5b_text2video',
|
56 |
+
'other_KLing-v1.0_text2video',
|
57 |
+
]
|
58 |
+
|
59 |
+
|
60 |
+
def load_pipeline(model_name):
|
61 |
+
"""
|
62 |
+
Load a model pipeline based on the model name
|
63 |
+
Args:
|
64 |
+
model_name (str): The name of the model to load, should be of the form {source}_{name}_{type}
|
65 |
+
"""
|
66 |
+
model_source, model_name, model_type = model_name.split("_")
|
67 |
+
|
68 |
+
if model_source == "replicate":
|
69 |
+
pipe = load_replicate_model(model_name, model_type)
|
70 |
+
elif model_source == "huggingface":
|
71 |
+
pipe = load_huggingface_model(model_name, model_type)
|
72 |
+
elif model_source == "openai":
|
73 |
+
pipe = load_openai_model(model_name, model_type)
|
74 |
+
elif model_source == "other":
|
75 |
+
pipe = load_other_model(model_name, model_type)
|
76 |
+
else:
|
77 |
+
raise ValueError(f"Model source {model_source} not supported")
|
78 |
+
return pipe
|
model/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (2.64 kB). View file
|
|
model/models/__pycache__/huggingface_models.cpython-310.pyc
ADDED
Binary file (1.48 kB). View file
|
|
model/models/__pycache__/openai_api_models.cpython-310.pyc
ADDED
Binary file (1.6 kB). View file
|
|
model/models/__pycache__/other_api_models.cpython-310.pyc
ADDED
Binary file (2.56 kB). View file
|
|
model/models/__pycache__/replicate_api_models.cpython-310.pyc
ADDED
Binary file (6.31 kB). View file
|
|
model/models/huggingface_models.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import DiffusionPipeline
|
2 |
+
from diffusers import AutoPipelineForText2Image
|
3 |
+
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def load_huggingface_model(model_name, model_type):
|
8 |
+
if model_name == "SD-turbo":
|
9 |
+
pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
|
10 |
+
pipe = pipe.to("cuda")
|
11 |
+
elif model_name == "SDXL-turbo":
|
12 |
+
pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
|
13 |
+
pipe = pipe.to("cuda")
|
14 |
+
elif model_name == "Stable-cascade":
|
15 |
+
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16)
|
16 |
+
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
|
17 |
+
pipe = [prior, decoder]
|
18 |
+
else:
|
19 |
+
raise NotImplementedError
|
20 |
+
# if model_name == "SD-turbo":
|
21 |
+
# pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo")
|
22 |
+
# elif model_name == "SDXL-turbo":
|
23 |
+
# pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
|
24 |
+
# else:
|
25 |
+
# raise NotImplementedError
|
26 |
+
# pipe = pipe.to("cpu")
|
27 |
+
return pipe
|
28 |
+
|
29 |
+
|
30 |
+
if __name__ == "__main__":
|
31 |
+
# for name in ["SD-turbo", "SDXL-turbo"]: #"SD-turbo", "SDXL-turbo"
|
32 |
+
# pipe = load_huggingface_model(name, "text2image")
|
33 |
+
|
34 |
+
# for name in ["IF-I-XL-v1.0"]:
|
35 |
+
# pipe = load_huggingface_model(name, 'text2image')
|
36 |
+
# pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
|
37 |
+
|
38 |
+
prompt = 'draw a tiger'
|
39 |
+
pipe = load_huggingface_model('Stable-cascade', "text2image")
|
40 |
+
prior, decoder = pipe
|
41 |
+
prior.enable_model_cpu_offload()
|
42 |
+
prior_output = prior(
|
43 |
+
prompt=prompt,
|
44 |
+
height=512,
|
45 |
+
width=512,
|
46 |
+
negative_prompt='',
|
47 |
+
guidance_scale=4.0,
|
48 |
+
num_images_per_prompt=1,
|
49 |
+
num_inference_steps=20
|
50 |
+
)
|
51 |
+
decoder.enable_model_cpu_offload()
|
52 |
+
result = decoder(
|
53 |
+
image_embeddings=prior_output.image_embeddings.to(torch.float16),
|
54 |
+
prompt=prompt,
|
55 |
+
negative_prompt='',
|
56 |
+
guidance_scale=0.0,
|
57 |
+
output_type="pil",
|
58 |
+
num_inference_steps=10
|
59 |
+
).images[0]
|
model/models/openai_api_models.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
from PIL import Image
|
3 |
+
import requests
|
4 |
+
import io
|
5 |
+
import os
|
6 |
+
import base64
|
7 |
+
|
8 |
+
|
9 |
+
class OpenaiModel():
|
10 |
+
def __init__(self, model_name, model_type):
|
11 |
+
self.model_name = model_name
|
12 |
+
self.model_type = model_type
|
13 |
+
|
14 |
+
def __call__(self, *args, **kwargs):
|
15 |
+
if self.model_type == "text2image":
|
16 |
+
assert "prompt" in kwargs, "prompt is required for text2image model"
|
17 |
+
|
18 |
+
client = OpenAI()
|
19 |
+
|
20 |
+
if 'Dalle-3' in self.model_name:
|
21 |
+
client = OpenAI()
|
22 |
+
response = client.images.generate(
|
23 |
+
model="dall-e-3",
|
24 |
+
prompt=kwargs["prompt"],
|
25 |
+
size="1024x1024",
|
26 |
+
quality="standard",
|
27 |
+
n=1,
|
28 |
+
)
|
29 |
+
elif 'Dalle-2' in self.model_name:
|
30 |
+
client = OpenAI()
|
31 |
+
response = client.images.generate(
|
32 |
+
model="dall-e-2",
|
33 |
+
prompt=kwargs["prompt"],
|
34 |
+
size="512x512",
|
35 |
+
quality="standard",
|
36 |
+
n=1,
|
37 |
+
)
|
38 |
+
else:
|
39 |
+
raise NotImplementedError
|
40 |
+
|
41 |
+
result_url = response.data[0].url
|
42 |
+
response = requests.get(result_url)
|
43 |
+
result = Image.open(io.BytesIO(response.content))
|
44 |
+
return result
|
45 |
+
else:
|
46 |
+
raise ValueError("model_type must be text2image or image2image")
|
47 |
+
|
48 |
+
|
49 |
+
def load_openai_model(model_name, model_type):
|
50 |
+
return OpenaiModel(model_name, model_type)
|
51 |
+
|
52 |
+
|
53 |
+
if __name__ == "__main__":
|
54 |
+
pipe = load_openai_model('Dalle-3', 'text2image')
|
55 |
+
result = pipe(prompt='draw a tiger')
|
56 |
+
print(result)
|
57 |
+
|
model/models/other_api_models.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from PIL import Image
|
5 |
+
import io, time
|
6 |
+
|
7 |
+
|
8 |
+
class OtherModel():
|
9 |
+
def __init__(self, model_name, model_type):
|
10 |
+
self.model_name = model_name
|
11 |
+
self.model_type = model_type
|
12 |
+
self.image_url = "https://www.xdai.online/mj/submit/imagine"
|
13 |
+
self.key = os.environ.get('MIDJOURNEY_KEY')
|
14 |
+
self.get_image_url = "https://www.xdai.online/mj/image/"
|
15 |
+
self.repeat_num = 5
|
16 |
+
|
17 |
+
def __call__(self, *args, **kwargs):
|
18 |
+
if self.model_type == "text2image":
|
19 |
+
assert "prompt" in kwargs, "prompt is required for text2image model"
|
20 |
+
if self.model_name == "Midjourney-v6.0":
|
21 |
+
data = {
|
22 |
+
"base64Array": [],
|
23 |
+
"notifyHook": "",
|
24 |
+
"prompt": "{} --v 6.0".format(kwargs["prompt"]),
|
25 |
+
"state": "",
|
26 |
+
"botType": "MID_JOURNEY",
|
27 |
+
}
|
28 |
+
elif self.model_name == "Midjourney-v5.0":
|
29 |
+
data = {
|
30 |
+
"base64Array": [],
|
31 |
+
"notifyHook": "",
|
32 |
+
"prompt": "{} --v 5.0".format(kwargs["prompt"]),
|
33 |
+
"state": "",
|
34 |
+
"botType": "MID_JOURNEY",
|
35 |
+
}
|
36 |
+
else:
|
37 |
+
raise NotImplementedError
|
38 |
+
|
39 |
+
headers = {
|
40 |
+
"Authorization": "Bearer {}".format(self.key),
|
41 |
+
"Content-Type": "application/json"
|
42 |
+
}
|
43 |
+
while 1:
|
44 |
+
response = requests.post(self.image_url, data=json.dumps(data), headers=headers)
|
45 |
+
if response.status_code == 200:
|
46 |
+
print("Submit success!")
|
47 |
+
response_json = json.loads(response.content.decode('utf-8'))
|
48 |
+
img_id = response_json["result"]
|
49 |
+
result_url = self.get_image_url + img_id
|
50 |
+
print(result_url)
|
51 |
+
self.repeat_num = 800
|
52 |
+
while 1:
|
53 |
+
time.sleep(1)
|
54 |
+
img_response = requests.get(result_url)
|
55 |
+
if img_response.status_code == 200:
|
56 |
+
result = Image.open(io.BytesIO(img_response.content))
|
57 |
+
width, height = result.size
|
58 |
+
new_width = width // 2
|
59 |
+
new_height = height // 2
|
60 |
+
result = result.crop((0, 0, new_width, new_height))
|
61 |
+
self.repeat_num = 5
|
62 |
+
return result
|
63 |
+
else:
|
64 |
+
self.repeat_num = self.repeat_num - 1
|
65 |
+
if self.repeat_num == 0:
|
66 |
+
raise ValueError("Image request failed.")
|
67 |
+
continue
|
68 |
+
|
69 |
+
else:
|
70 |
+
self.repeat_num = self.repeat_num - 1
|
71 |
+
if self.repeat_num == 0:
|
72 |
+
raise ValueError("API request failed.")
|
73 |
+
continue
|
74 |
+
if self.model_type == "text2video":
|
75 |
+
assert "prompt" in kwargs, "prompt is required for text2video model"
|
76 |
+
|
77 |
+
else:
|
78 |
+
raise ValueError("model_type must be text2image")
|
79 |
+
|
80 |
+
|
81 |
+
def load_other_model(model_name, model_type):
|
82 |
+
return OtherModel(model_name, model_type)
|
83 |
+
|
84 |
+
if __name__ == "__main__":
|
85 |
+
import http.client
|
86 |
+
import json
|
87 |
+
|
88 |
+
pipe = load_other_model("Midjourney-v5.0", "text2image")
|
89 |
+
result = pipe(prompt="An Impressionist illustration depicts a river winding through a meadow")
|
90 |
+
print(result)
|
91 |
+
exit()
|
model/models/replicate_api_models.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import replicate
|
2 |
+
from PIL import Image
|
3 |
+
import requests
|
4 |
+
import io
|
5 |
+
import os
|
6 |
+
import base64
|
7 |
+
|
8 |
+
Replicate_MODEl_NAME_MAP = {
|
9 |
+
"SDXL": "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
|
10 |
+
"SD-v3.0": "stability-ai/stable-diffusion-3",
|
11 |
+
"SD-v2.1": "stability-ai/stable-diffusion:ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4",
|
12 |
+
"SD-v1.5": "stability-ai/stable-diffusion:b3d14e1cd1f9470bbb0bb68cac48e5f483e5be309551992cc33dc30654a82bb7",
|
13 |
+
"SDXL-Lightning": "bytedance/sdxl-lightning-4step:5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
|
14 |
+
"Kandinsky-v2.0": "ai-forever/kandinsky-2:3c6374e7a9a17e01afe306a5218cc67de55b19ea536466d6ea2602cfecea40a9",
|
15 |
+
"Kandinsky-v2.2": "ai-forever/kandinsky-2.2:ad9d7879fbffa2874e1d909d1d37d9bc682889cc65b31f7bb00d2362619f194a",
|
16 |
+
"Proteus-v0.2": "lucataco/proteus-v0.2:06775cd262843edbde5abab958abdbb65a0a6b58ca301c9fd78fa55c775fc019",
|
17 |
+
"Playground-v2.0": "playgroundai/playground-v2-1024px-aesthetic:42fe626e41cc811eaf02c94b892774839268ce1994ea778eba97103fe1ef51b8",
|
18 |
+
"Playground-v2.5": "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
|
19 |
+
"Dreamshaper-xl-turbo": "lucataco/dreamshaper-xl-turbo:0a1710e0187b01a255302738ca0158ff02a22f4638679533e111082f9dd1b615",
|
20 |
+
"SDXL-Deepcache": "lucataco/sdxl-deepcache:eaf678fb34006669e9a3c6dd5971e2279bf20ee0adeced464d7b6d95de16dc93",
|
21 |
+
"Openjourney-v4": "prompthero/openjourney:ad59ca21177f9e217b9075e7300cf6e14f7e5b4505b87b9689dbd866e9768969",
|
22 |
+
"LCM-v1.5": "fofr/latent-consistency-model:683d19dc312f7a9f0428b04429a9ccefd28dbf7785fef083ad5cf991b65f406f",
|
23 |
+
"Realvisxl-v3.0": "fofr/realvisxl-v3:33279060bbbb8858700eb2146350a98d96ef334fcf817f37eb05915e1534aa1c",
|
24 |
+
|
25 |
+
"Realvisxl-v2.0": "lucataco/realvisxl-v2.0:7d6a2f9c4754477b12c14ed2a58f89bb85128edcdd581d24ce58b6926029de08",
|
26 |
+
"Pixart-Sigma": "cjwbw/pixart-sigma:5a54352c99d9fef467986bc8f3a20205e8712cbd3df1cbae4975d6254c902de1",
|
27 |
+
"SSD-1b": "lucataco/ssd-1b:b19e3639452c59ce8295b82aba70a231404cb062f2eb580ea894b31e8ce5bbb6",
|
28 |
+
"Open-Dalle-v1.1": "lucataco/open-dalle-v1.1:1c7d4c8dec39c7306df7794b28419078cb9d18b9213ab1c21fdc46a1deca0144",
|
29 |
+
"Deepfloyd-IF": "andreasjansson/deepfloyd-if:fb84d659df149f4515c351e394d22222a94144aa1403870c36025c8b28846c8d",
|
30 |
+
|
31 |
+
"Zeroscope-v2-xl": "anotherjesse/zeroscope-v2-xl:9f747673945c62801b13b84701c783929c0ee784e4748ec062204894dda1a351",
|
32 |
+
# "Damo-Text-to-Video": "cjwbw/damo-text-to-video:1e205ea73084bd17a0a3b43396e49ba0d6bc2e754e9283b2df49fad2dcf95755",
|
33 |
+
"Animate-Diff": "lucataco/animate-diff:beecf59c4aee8d81bf04f0381033dfa10dc16e845b4ae00d281e2fa377e48a9f",
|
34 |
+
"OpenSora": "camenduru/open-sora:8099e5722ba3d5f408cd3e696e6df058137056268939337a3fbe3912e86e72ad",
|
35 |
+
"LaVie": "cjwbw/lavie:0bca850c4928b6c30052541fa002f24cbb4b677259c461dd041d271ba9d3c517",
|
36 |
+
"VideoCrafter2": "lucataco/video-crafter:7757c5775e962c618053e7df4343052a21075676d6234e8ede5fa67c9e43bce0",
|
37 |
+
"Stable-Video-Diffusion": "sunfjun/stable-video-diffusion:d68b6e09eedbac7a49e3d8644999d93579c386a083768235cabca88796d70d82",
|
38 |
+
"FLUX.1-schnell": "black-forest-labs/flux-schnell",
|
39 |
+
"FLUX.1-pro": "black-forest-labs/flux-pro",
|
40 |
+
"FLUX.1-dev": "black-forest-labs/flux-dev",
|
41 |
+
"FLUX-1.1-pro": "black-forest-labs/flux-1.1-pro",
|
42 |
+
"SD-v3.5-large": "stability-ai/stable-diffusion-3.5-large",
|
43 |
+
"SD-v3.5-large-turbo": "stability-ai/stable-diffusion-3.5-large-turbo",
|
44 |
+
}
|
45 |
+
|
46 |
+
|
47 |
+
class ReplicateModel():
|
48 |
+
def __init__(self, model_name, model_type):
|
49 |
+
self.model_name = model_name
|
50 |
+
self.model_type = model_type
|
51 |
+
|
52 |
+
def __call__(self, *args, **kwargs):
|
53 |
+
if self.model_type == "text2image":
|
54 |
+
assert "prompt" in kwargs, "prompt is required for text2image model"
|
55 |
+
output = replicate.run(
|
56 |
+
f"{Replicate_MODEl_NAME_MAP[self.model_name]}",
|
57 |
+
input={
|
58 |
+
"width": 512,
|
59 |
+
"height": 512,
|
60 |
+
"prompt": kwargs["prompt"]
|
61 |
+
},
|
62 |
+
)
|
63 |
+
if 'Openjourney' in self.model_name:
|
64 |
+
for item in output:
|
65 |
+
result_url = item
|
66 |
+
break
|
67 |
+
elif isinstance(output, list):
|
68 |
+
result_url = output[0]
|
69 |
+
else:
|
70 |
+
result_url = output
|
71 |
+
print(self.model_name, result_url)
|
72 |
+
response = requests.get(result_url)
|
73 |
+
result = Image.open(io.BytesIO(response.content))
|
74 |
+
return result
|
75 |
+
|
76 |
+
elif self.model_type == "text2video":
|
77 |
+
assert "prompt" in kwargs, "prompt is required for text2image model"
|
78 |
+
if self.model_name == "Zeroscope-v2-xl":
|
79 |
+
input = {
|
80 |
+
"fps": 24,
|
81 |
+
"width": 512,
|
82 |
+
"height": 512,
|
83 |
+
"prompt": kwargs["prompt"],
|
84 |
+
"guidance_scale": 17.5,
|
85 |
+
# "negative_prompt": "very blue, dust, noisy, washed out, ugly, distorted, broken",
|
86 |
+
"num_frames": 48,
|
87 |
+
}
|
88 |
+
elif self.model_name == "Damo-Text-to-Video":
|
89 |
+
input={
|
90 |
+
"fps": 8,
|
91 |
+
"prompt": kwargs["prompt"],
|
92 |
+
"num_frames": 16,
|
93 |
+
"num_inference_steps": 50
|
94 |
+
}
|
95 |
+
elif self.model_name == "Animate-Diff":
|
96 |
+
input={
|
97 |
+
"path": "toonyou_beta3.safetensors",
|
98 |
+
"seed": 255224557,
|
99 |
+
"steps": 25,
|
100 |
+
"prompt": kwargs["prompt"],
|
101 |
+
"n_prompt": "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth",
|
102 |
+
"motion_module": "mm_sd_v14",
|
103 |
+
"guidance_scale": 7.5
|
104 |
+
}
|
105 |
+
elif self.model_name == "OpenSora":
|
106 |
+
input={
|
107 |
+
"seed": 1234,
|
108 |
+
"prompt": kwargs["prompt"],
|
109 |
+
}
|
110 |
+
elif self.model_name == "LaVie":
|
111 |
+
input={
|
112 |
+
"width": 512,
|
113 |
+
"height": 512,
|
114 |
+
"prompt": kwargs["prompt"],
|
115 |
+
"quality": 9,
|
116 |
+
"video_fps": 8,
|
117 |
+
"interpolation": False,
|
118 |
+
"sample_method": "ddpm",
|
119 |
+
"guidance_scale": 7,
|
120 |
+
"super_resolution": False,
|
121 |
+
"num_inference_steps": 50
|
122 |
+
}
|
123 |
+
elif self.model_name == "VideoCrafter2":
|
124 |
+
input={
|
125 |
+
"fps": 24,
|
126 |
+
"seed": 64045,
|
127 |
+
"steps": 40,
|
128 |
+
"width": 512,
|
129 |
+
"height": 512,
|
130 |
+
"prompt": kwargs["prompt"],
|
131 |
+
}
|
132 |
+
elif self.model_name == "Stable-Video-Diffusion":
|
133 |
+
text2image_name = "SD-v2.1"
|
134 |
+
output = replicate.run(
|
135 |
+
f"{Replicate_MODEl_NAME_MAP[text2image_name]}",
|
136 |
+
input={
|
137 |
+
"width": 512,
|
138 |
+
"height": 512,
|
139 |
+
"prompt": kwargs["prompt"]
|
140 |
+
},
|
141 |
+
)
|
142 |
+
if isinstance(output, list):
|
143 |
+
image_url = output[0]
|
144 |
+
else:
|
145 |
+
image_url = output
|
146 |
+
print(image_url)
|
147 |
+
|
148 |
+
input={
|
149 |
+
"cond_aug": 0.02,
|
150 |
+
"decoding_t": 14,
|
151 |
+
"input_image": "{}".format(image_url),
|
152 |
+
"video_length": "14_frames_with_svd",
|
153 |
+
"sizing_strategy": "maintain_aspect_ratio",
|
154 |
+
"motion_bucket_id": 127,
|
155 |
+
"frames_per_second": 6
|
156 |
+
}
|
157 |
+
|
158 |
+
output = replicate.run(
|
159 |
+
f"{Replicate_MODEl_NAME_MAP[self.model_name]}",
|
160 |
+
input=input,
|
161 |
+
)
|
162 |
+
if isinstance(output, list):
|
163 |
+
result_url = output[0]
|
164 |
+
else:
|
165 |
+
result_url = output
|
166 |
+
print(self.model_name)
|
167 |
+
print(result_url)
|
168 |
+
# response = requests.get(result_url)
|
169 |
+
# result = Image.open(io.BytesIO(response.content))
|
170 |
+
|
171 |
+
# for event in handler.iter_events(with_logs=True):
|
172 |
+
# if isinstance(event, fal_client.InProgress):
|
173 |
+
# print('Request in progress')
|
174 |
+
# print(event.logs)
|
175 |
+
|
176 |
+
# result = handler.get()
|
177 |
+
# print("result video: ====")
|
178 |
+
# print(result)
|
179 |
+
# result_url = result['video']['url']
|
180 |
+
# return result_url
|
181 |
+
return result_url
|
182 |
+
else:
|
183 |
+
raise ValueError("model_type must be text2image or image2image")
|
184 |
+
|
185 |
+
|
186 |
+
def load_replicate_model(model_name, model_type):
|
187 |
+
return ReplicateModel(model_name, model_type)
|
188 |
+
|
189 |
+
|
190 |
+
if __name__ == "__main__":
|
191 |
+
model_name = 'replicate_zeroscope-v2-xl_text2video'
|
192 |
+
model_source, model_name, model_type = model_name.split("_")
|
193 |
+
pipe = load_replicate_model(model_name, model_type)
|
194 |
+
prompt = "Clown fish swimming in a coral reef, beautiful, 8k, perfect, award winning, national geographic"
|
195 |
+
result = pipe(prompt=prompt)
|
serve/Arial.ttf
ADDED
Binary file (155 kB). View file
|
|
serve/Ksort.py
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image, ImageDraw, ImageFont, ImageOps
|
3 |
+
import os
|
4 |
+
from .constants import KSORT_IMAGE_DIR
|
5 |
+
from .constants import COLOR1, COLOR2, COLOR3, COLOR4
|
6 |
+
from .vote_utils import save_any_image
|
7 |
+
from .utils import disable_btn, enable_btn, invisible_btn
|
8 |
+
from .upload import create_remote_directory, upload_ssh_all, upload_ssh_data
|
9 |
+
import json
|
10 |
+
|
11 |
+
|
12 |
+
def reset_level(Top_btn):
|
13 |
+
if Top_btn == "Top 1":
|
14 |
+
level = 0
|
15 |
+
elif Top_btn == "Top 2":
|
16 |
+
level = 1
|
17 |
+
elif Top_btn == "Top 3":
|
18 |
+
level = 2
|
19 |
+
elif Top_btn == "Top 4":
|
20 |
+
level = 3
|
21 |
+
return level
|
22 |
+
|
23 |
+
|
24 |
+
def reset_rank(windows, rank, vote_level):
|
25 |
+
if windows == "Model A":
|
26 |
+
rank[0] = vote_level
|
27 |
+
elif windows == "Model B":
|
28 |
+
rank[1] = vote_level
|
29 |
+
elif windows == "Model C":
|
30 |
+
rank[2] = vote_level
|
31 |
+
elif windows == "Model D":
|
32 |
+
rank[3] = vote_level
|
33 |
+
return rank
|
34 |
+
|
35 |
+
|
36 |
+
def reset_btn_rank(windows, rank, btn, vote_level):
|
37 |
+
if windows == "Model A" and btn == "1":
|
38 |
+
rank[0] = 0
|
39 |
+
elif windows == "Model A" and btn == "2":
|
40 |
+
rank[0] = 1
|
41 |
+
elif windows == "Model A" and btn == "3":
|
42 |
+
rank[0] = 2
|
43 |
+
elif windows == "Model A" and btn == "4":
|
44 |
+
rank[0] = 3
|
45 |
+
elif windows == "Model B" and btn == "1":
|
46 |
+
rank[1] = 0
|
47 |
+
elif windows == "Model B" and btn == "2":
|
48 |
+
rank[1] = 1
|
49 |
+
elif windows == "Model B" and btn == "3":
|
50 |
+
rank[1] = 2
|
51 |
+
elif windows == "Model B" and btn == "4":
|
52 |
+
rank[1] = 3
|
53 |
+
elif windows == "Model C" and btn == "1":
|
54 |
+
rank[2] = 0
|
55 |
+
elif windows == "Model C" and btn == "2":
|
56 |
+
rank[2] = 1
|
57 |
+
elif windows == "Model C" and btn == "3":
|
58 |
+
rank[2] = 2
|
59 |
+
elif windows == "Model C" and btn == "4":
|
60 |
+
rank[2] = 3
|
61 |
+
elif windows == "Model D" and btn == "1":
|
62 |
+
rank[3] = 0
|
63 |
+
elif windows == "Model D" and btn == "2":
|
64 |
+
rank[3] = 1
|
65 |
+
elif windows == "Model D" and btn == "3":
|
66 |
+
rank[3] = 2
|
67 |
+
elif windows == "Model D" and btn == "4":
|
68 |
+
rank[3] = 3
|
69 |
+
if btn == "1":
|
70 |
+
vote_level = 0
|
71 |
+
elif btn == "2":
|
72 |
+
vote_level = 1
|
73 |
+
elif btn == "3":
|
74 |
+
vote_level = 2
|
75 |
+
elif btn == "4":
|
76 |
+
vote_level = 3
|
77 |
+
return (rank, vote_level)
|
78 |
+
|
79 |
+
|
80 |
+
def reset_vote_text(rank):
|
81 |
+
rank_str = ""
|
82 |
+
for i in range(len(rank)):
|
83 |
+
if rank[i] == None:
|
84 |
+
rank_str = rank_str + str(rank[i])
|
85 |
+
else:
|
86 |
+
rank_str = rank_str + str(rank[i]+1)
|
87 |
+
rank_str = rank_str + " "
|
88 |
+
return rank_str
|
89 |
+
|
90 |
+
|
91 |
+
def clear_rank(rank, vote_level):
|
92 |
+
for i in range(len(rank)):
|
93 |
+
rank[i] = None
|
94 |
+
vote_level = 0
|
95 |
+
return rank, vote_level
|
96 |
+
|
97 |
+
|
98 |
+
def revote_windows(generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level):
|
99 |
+
for i in range(len(rank)):
|
100 |
+
rank[i] = None
|
101 |
+
vote_level = 0
|
102 |
+
return generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level
|
103 |
+
|
104 |
+
|
105 |
+
def reset_submit(rank):
|
106 |
+
for i in range(len(rank)):
|
107 |
+
if rank[i] == None:
|
108 |
+
return disable_btn
|
109 |
+
return enable_btn
|
110 |
+
|
111 |
+
|
112 |
+
def reset_mode(mode):
|
113 |
+
|
114 |
+
if mode == "Best":
|
115 |
+
return (gr.update(visible=False, interactive=False),) * 5 + \
|
116 |
+
(gr.update(visible=True, interactive=True),) * 16 + \
|
117 |
+
(gr.update(visible=True, interactive=True),) * 3 + \
|
118 |
+
(gr.Textbox(value="Rank", visible=False, interactive=False),)
|
119 |
+
elif mode == "Rank":
|
120 |
+
return (gr.update(visible=True, interactive=True),) * 5 + \
|
121 |
+
(gr.update(visible=False, interactive=False),) * 16 + \
|
122 |
+
(gr.update(visible=True, interactive=False),) * 2 + \
|
123 |
+
(gr.update(visible=True, interactive=True),) + \
|
124 |
+
(gr.Textbox(value="Best", visible=False, interactive=False),)
|
125 |
+
else:
|
126 |
+
raise ValueError("Undefined mode")
|
127 |
+
|
128 |
+
|
129 |
+
def reset_chatbot(mode, generate_ig0, generate_ig1, generate_ig2, generate_ig3):
|
130 |
+
return generate_ig0, generate_ig1, generate_ig2, generate_ig3
|
131 |
+
|
132 |
+
|
133 |
+
def get_json_filename(conv_id):
|
134 |
+
output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/json/'
|
135 |
+
if not os.path.exists(output_dir):
|
136 |
+
os.makedirs(output_dir)
|
137 |
+
output_file = os.path.join(output_dir, "information.json")
|
138 |
+
# name = os.path.join(KSORT_IMAGE_DIR, f"{conv_id}/json/information.json")
|
139 |
+
print(output_file)
|
140 |
+
return output_file
|
141 |
+
|
142 |
+
|
143 |
+
def get_img_filename(conv_id, i):
|
144 |
+
output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/image/'
|
145 |
+
if not os.path.exists(output_dir):
|
146 |
+
os.makedirs(output_dir)
|
147 |
+
output_file = os.path.join(output_dir, f"{i}.jpg")
|
148 |
+
print(output_file)
|
149 |
+
return output_file
|
150 |
+
|
151 |
+
|
152 |
+
def vote_submit(states, textbox, rank, request: gr.Request):
|
153 |
+
conv_id = states[0].conv_id
|
154 |
+
|
155 |
+
for i in range(len(states)):
|
156 |
+
output_file = get_img_filename(conv_id, i)
|
157 |
+
save_any_image(states[i].output, output_file)
|
158 |
+
with open(get_json_filename(conv_id), "a") as fout:
|
159 |
+
data = {
|
160 |
+
"models_name": [x.model_name for x in states],
|
161 |
+
"img_rank": [x for x in rank],
|
162 |
+
"prompt": [textbox],
|
163 |
+
}
|
164 |
+
fout.write(json.dumps(data) + "\n")
|
165 |
+
|
166 |
+
|
167 |
+
def vote_ssh_submit(states, textbox, rank, user_name, user_institution):
|
168 |
+
conv_id = states[0].conv_id
|
169 |
+
output_dir = create_remote_directory(conv_id)
|
170 |
+
# upload_image(states, output_dir)
|
171 |
+
|
172 |
+
data = {
|
173 |
+
"models_name": [x.model_name for x in states],
|
174 |
+
"img_rank": [x for x in rank],
|
175 |
+
"prompt": [textbox],
|
176 |
+
"user_info": {"name": [user_name], "institution": [user_institution]},
|
177 |
+
}
|
178 |
+
output_file = os.path.join(output_dir, "result.json")
|
179 |
+
# upload_informance(data, output_file)
|
180 |
+
upload_ssh_all(states, output_dir, data, output_file)
|
181 |
+
|
182 |
+
from .update_skill import update_skill
|
183 |
+
update_skill(rank, [x.model_name for x in states])
|
184 |
+
|
185 |
+
|
186 |
+
def vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_institution):
|
187 |
+
conv_id = states[0].conv_id
|
188 |
+
output_dir = create_remote_directory(conv_id, video=True)
|
189 |
+
|
190 |
+
data = {
|
191 |
+
"models_name": [x.model_name for x in states],
|
192 |
+
"video_rank": [x for x in rank],
|
193 |
+
"prompt": [textbox],
|
194 |
+
"prompt_path": [prompt_path],
|
195 |
+
"video_path": [x.output for x in states],
|
196 |
+
"user_info": {"name": [user_name], "institution": [user_institution]},
|
197 |
+
}
|
198 |
+
output_file = os.path.join(output_dir, "result.json")
|
199 |
+
|
200 |
+
upload_ssh_data(data, output_file)
|
201 |
+
|
202 |
+
from .update_skill_video import update_skill_video
|
203 |
+
update_skill_video(rank, [x.model_name for x in states])
|
204 |
+
|
205 |
+
|
206 |
+
def submit_response_igm(
|
207 |
+
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, user_name, user_institution, request: gr.Request
|
208 |
+
):
|
209 |
+
# vote_submit([state0, state1, state2, state3], textbox, rank, request)
|
210 |
+
vote_ssh_submit([state0, state1, state2, state3], textbox, rank, user_name, user_institution)
|
211 |
+
if model_selector0 == "":
|
212 |
+
return (disable_btn,) * 6 + (
|
213 |
+
gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
|
214 |
+
gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
|
215 |
+
gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True),
|
216 |
+
gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True)
|
217 |
+
) + (disable_btn,)
|
218 |
+
else:
|
219 |
+
return (disable_btn,) * 6 + (
|
220 |
+
gr.Markdown(state0.model_name, visible=True),
|
221 |
+
gr.Markdown(state1.model_name, visible=True),
|
222 |
+
gr.Markdown(state2.model_name, visible=True),
|
223 |
+
gr.Markdown(state3.model_name, visible=True)
|
224 |
+
) + (disable_btn,)
|
225 |
+
|
226 |
+
|
227 |
+
def submit_response_vg(
|
228 |
+
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, user_name, user_institution, request: gr.Request
|
229 |
+
):
|
230 |
+
vote_video_ssh_submit([state0, state1, state2, state3], textbox, prompt_path, rank, user_name, user_institution)
|
231 |
+
if model_selector0 == "":
|
232 |
+
return (disable_btn,) * 6 + (
|
233 |
+
gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
|
234 |
+
gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
|
235 |
+
gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True),
|
236 |
+
gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True)
|
237 |
+
) + (disable_btn,)
|
238 |
+
else:
|
239 |
+
return (disable_btn,) * 6 + (
|
240 |
+
gr.Markdown(state0.model_name, visible=True),
|
241 |
+
gr.Markdown(state1.model_name, visible=True),
|
242 |
+
gr.Markdown(state2.model_name, visible=True),
|
243 |
+
gr.Markdown(state3.model_name, visible=True)
|
244 |
+
) + (disable_btn,)
|
245 |
+
|
246 |
+
|
247 |
+
def submit_response_rank_igm(
|
248 |
+
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, right_vote_text, user_name, user_institution, request: gr.Request
|
249 |
+
):
|
250 |
+
print(rank)
|
251 |
+
if right_vote_text == "right":
|
252 |
+
# vote_submit([state0, state1, state2, state3], textbox, rank, request)
|
253 |
+
vote_ssh_submit([state0, state1, state2, state3], textbox, rank, user_name, user_institution)
|
254 |
+
if model_selector0 == "":
|
255 |
+
return (disable_btn,) * 16 + (disable_btn,) * 3 + ("wrong",) + (
|
256 |
+
gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
|
257 |
+
gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
|
258 |
+
gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True),
|
259 |
+
gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True)
|
260 |
+
)
|
261 |
+
else:
|
262 |
+
return (disable_btn,) * 16 + (disable_btn,) * 3 + ("wrong",) + (
|
263 |
+
gr.Markdown(state0.model_name, visible=True),
|
264 |
+
gr.Markdown(state1.model_name, visible=True),
|
265 |
+
gr.Markdown(state2.model_name, visible=True),
|
266 |
+
gr.Markdown(state3.model_name, visible=True)
|
267 |
+
)
|
268 |
+
else:
|
269 |
+
return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
|
270 |
+
|
271 |
+
|
272 |
+
def submit_response_rank_vg(
|
273 |
+
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, right_vote_text, user_name, user_institution, request: gr.Request
|
274 |
+
):
|
275 |
+
print(rank)
|
276 |
+
if right_vote_text == "right":
|
277 |
+
vote_video_ssh_submit([state0, state1, state2, state3], textbox, prompt_path, rank, user_name, user_institution)
|
278 |
+
if model_selector0 == "":
|
279 |
+
return (disable_btn,) * 16 + (disable_btn,) * 3 + ("wrong",) + (
|
280 |
+
gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
|
281 |
+
gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
|
282 |
+
gr.Markdown(f"### Model C: {state2.model_name.split('_')[1]}", visible=True),
|
283 |
+
gr.Markdown(f"### Model D: {state3.model_name.split('_')[1]}", visible=True)
|
284 |
+
)
|
285 |
+
else:
|
286 |
+
return (disable_btn,) * 16 + (disable_btn,) * 3 + ("wrong",) + (
|
287 |
+
gr.Markdown(state0.model_name, visible=True),
|
288 |
+
gr.Markdown(state1.model_name, visible=True),
|
289 |
+
gr.Markdown(state2.model_name, visible=True),
|
290 |
+
gr.Markdown(state3.model_name, visible=True)
|
291 |
+
)
|
292 |
+
else:
|
293 |
+
return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
|
294 |
+
|
295 |
+
|
296 |
+
def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox):
|
297 |
+
rank_list = [char for char in vote_textbox if char.isdigit()]
|
298 |
+
generate_ig = [generate_ig0, generate_ig1, generate_ig2, generate_ig3]
|
299 |
+
chatbot = []
|
300 |
+
rank = [None, None, None, None]
|
301 |
+
if len(rank_list) != 4:
|
302 |
+
return generate_ig + ["error rank"] + ["wrong"] + [rank]
|
303 |
+
for num in range(len(rank_list)):
|
304 |
+
if rank_list[num] in ['1', '2', '3', '4']:
|
305 |
+
base_image = Image.fromarray(generate_ig[num]).convert("RGBA")
|
306 |
+
base_image = base_image.resize((512, 512), Image.ANTIALIAS)
|
307 |
+
if rank_list[num] == '1':
|
308 |
+
border_color = COLOR1
|
309 |
+
elif rank_list[num] == '2':
|
310 |
+
border_color = COLOR2
|
311 |
+
elif rank_list[num] == '3':
|
312 |
+
border_color = COLOR3
|
313 |
+
elif rank_list[num] == '4':
|
314 |
+
border_color = COLOR4
|
315 |
+
border_size = 10 # Size of the border
|
316 |
+
base_image = ImageOps.expand(base_image, border=border_size, fill=border_color)
|
317 |
+
|
318 |
+
draw = ImageDraw.Draw(base_image)
|
319 |
+
font = ImageFont.truetype("./serve/Arial.ttf", 66)
|
320 |
+
text_position = (180, 25)
|
321 |
+
if rank_list[num] == '1':
|
322 |
+
text_color = COLOR1
|
323 |
+
draw.text(text_position, Top1_text, font=font, fill=text_color)
|
324 |
+
elif rank_list[num] == '2':
|
325 |
+
text_color = COLOR2
|
326 |
+
draw.text(text_position, Top2_text, font=font, fill=text_color)
|
327 |
+
elif rank_list[num] == '3':
|
328 |
+
text_color = COLOR3
|
329 |
+
draw.text(text_position, Top3_text, font=font, fill=text_color)
|
330 |
+
elif rank_list[num] == '4':
|
331 |
+
text_color = COLOR4
|
332 |
+
draw.text(text_position, Top4_text, font=font, fill=text_color)
|
333 |
+
base_image = base_image.convert("RGB")
|
334 |
+
chatbot.append(base_image.copy())
|
335 |
+
else:
|
336 |
+
return generate_ig + ["error rank"] + ["wrong"] + [rank]
|
337 |
+
rank_str = ""
|
338 |
+
for str_num in rank_list:
|
339 |
+
rank_str = rank_str + str_num
|
340 |
+
rank_str = rank_str + " "
|
341 |
+
rank = [int(x) for x in rank_list]
|
342 |
+
|
343 |
+
return chatbot + [rank_str] + ["right"] + [rank]
|
344 |
+
|
345 |
+
|
346 |
+
def text_response_rank_vg(vote_textbox):
|
347 |
+
rank_list = [char for char in vote_textbox if char.isdigit()]
|
348 |
+
rank = [None, None, None, None]
|
349 |
+
if len(rank_list) != 4:
|
350 |
+
return ["error rank"] + ["wrong"] + [rank]
|
351 |
+
for num in range(len(rank_list)):
|
352 |
+
if rank_list[num] in ['1', '2', '3', '4']:
|
353 |
+
continue
|
354 |
+
else:
|
355 |
+
return ["error rank"] + ["wrong"] + [rank]
|
356 |
+
rank_str = ""
|
357 |
+
for str_num in rank_list:
|
358 |
+
rank_str = rank_str + str_num
|
359 |
+
rank_str = rank_str + " "
|
360 |
+
rank = [int(x) for x in rank_list]
|
361 |
+
|
362 |
+
return [rank_str] + ["right"] + [rank]
|
363 |
+
|
364 |
+
|
365 |
+
def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text):
|
366 |
+
base_image = Image.fromarray(image).convert("RGBA")
|
367 |
+
base_image = base_image.resize((512, 512), Image.ANTIALIAS)
|
368 |
+
if vote_level == 0:
|
369 |
+
border_color = COLOR1
|
370 |
+
elif vote_level == 1:
|
371 |
+
border_color = COLOR2
|
372 |
+
elif vote_level == 2:
|
373 |
+
border_color = COLOR3
|
374 |
+
elif vote_level == 3:
|
375 |
+
border_color = COLOR4
|
376 |
+
border_size = 10 # Size of the border
|
377 |
+
base_image = ImageOps.expand(base_image, border=border_size, fill=border_color)
|
378 |
+
|
379 |
+
draw = ImageDraw.Draw(base_image)
|
380 |
+
font = ImageFont.truetype("./serve/Arial.ttf", 66)
|
381 |
+
|
382 |
+
text_position = (180, 25)
|
383 |
+
if vote_level == 0:
|
384 |
+
text_color = COLOR1
|
385 |
+
draw.text(text_position, Top1_text, font=font, fill=text_color)
|
386 |
+
elif vote_level == 1:
|
387 |
+
text_color = COLOR2
|
388 |
+
draw.text(text_position, Top2_text, font=font, fill=text_color)
|
389 |
+
elif vote_level == 2:
|
390 |
+
text_color = COLOR3
|
391 |
+
draw.text(text_position, Top3_text, font=font, fill=text_color)
|
392 |
+
elif vote_level == 3:
|
393 |
+
text_color = COLOR4
|
394 |
+
draw.text(text_position, Top4_text, font=font, fill=text_color)
|
395 |
+
|
396 |
+
base_image = base_image.convert("RGB")
|
397 |
+
return base_image
|
398 |
+
|
399 |
+
|
400 |
+
def add_green_border(image):
|
401 |
+
border_color = (0, 255, 0) # RGB for green
|
402 |
+
border_size = 10 # Size of the border
|
403 |
+
img_with_border = ImageOps.expand(image, border=border_size, fill=border_color)
|
404 |
+
return img_with_border
|
405 |
+
|
406 |
+
|
407 |
+
def check_textbox(textbox):
|
408 |
+
if textbox=="":
|
409 |
+
return False
|
410 |
+
else:
|
411 |
+
return True
|
serve/__init__.py
ADDED
File without changes
|
serve/__pycache__/Ksort.cpython-310.pyc
ADDED
Binary file (11.5 kB). View file
|
|
serve/__pycache__/Ksort.cpython-39.pyc
ADDED
Binary file (11.7 kB). View file
|
|
serve/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (157 Bytes). View file
|
|
serve/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (146 Bytes). View file
|
|
serve/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (1.85 kB). View file
|
|
serve/__pycache__/constants.cpython-39.pyc
ADDED
Binary file (1.84 kB). View file
|
|
serve/__pycache__/gradio_web.cpython-310.pyc
ADDED
Binary file (14.2 kB). View file
|
|
serve/__pycache__/gradio_web.cpython-39.pyc
ADDED
Binary file (14 kB). View file
|
|
serve/__pycache__/gradio_web_bbox.cpython-310.pyc
ADDED
Binary file (14.5 kB). View file
|
|
serve/__pycache__/leaderboard.cpython-310.pyc
ADDED
Binary file (7.97 kB). View file
|
|
serve/__pycache__/log_utils.cpython-310.pyc
ADDED
Binary file (4.06 kB). View file
|
|
serve/__pycache__/log_utils.cpython-39.pyc
ADDED
Binary file (4.05 kB). View file
|
|
serve/__pycache__/update_skill.cpython-310.pyc
ADDED
Binary file (4 kB). View file
|
|
serve/__pycache__/upload.cpython-310.pyc
ADDED
Binary file (8.22 kB). View file
|
|
serve/__pycache__/upload.cpython-39.pyc
ADDED
Binary file (8.28 kB). View file
|
|
serve/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (8.54 kB). View file
|
|
serve/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (9.14 kB). View file
|
|
serve/__pycache__/vote_utils.cpython-310.pyc
ADDED
Binary file (30.4 kB). View file
|
|
serve/__pycache__/vote_utils.cpython-39.pyc
ADDED
Binary file (33.7 kB). View file
|
|
serve/button.css
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* style.css */
|
2 |
+
.custom-button {
|
3 |
+
background-color: greenyellow; /* 背景颜色 */
|
4 |
+
color: white; /* 文字颜色 */
|
5 |
+
border: none; /* 无边框 */
|
6 |
+
padding: 10px 20px; /* 内边距 */
|
7 |
+
text-align: center; /* 文本居中 */
|
8 |
+
text-decoration: none; /* 无下划线 */
|
9 |
+
display: inline-block; /* 行内块 */
|
10 |
+
font-size: 16px; /* 字体大小 */
|
11 |
+
margin: 4px 2px; /* 外边距 */
|
12 |
+
cursor: pointer; /* 鼠标指针 */
|
13 |
+
border-radius: 5px; /* 圆角边框 */
|
14 |
+
}
|
15 |
+
|
16 |
+
.custom-button:hover {
|
17 |
+
background-color: darkgreen; /* 悬停时的背景颜色 */
|
18 |
+
}
|
19 |
+
/* css = """
|
20 |
+
#warning {background: red;}
|
21 |
+
|
22 |
+
.feedback {font-size: 24px !important;}
|
23 |
+
.feedback textarea {font-size: 24px !important;}
|
24 |
+
""" */
|
serve/constants.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
LOGDIR = os.getenv("LOGDIR", "./ksort-logs/vote_log")
|
4 |
+
IMAGE_DIR = os.getenv("IMAGE_DIR", f"{LOGDIR}/images")
|
5 |
+
KSORT_IMAGE_DIR = os.getenv("KSORT_IMAGE_DIR", f"{LOGDIR}/ksort_images")
|
6 |
+
VIDEO_DIR = os.getenv("VIDEO_DIR", f"{LOGDIR}/videos")
|
7 |
+
|
8 |
+
SERVER_PORT = os.getenv("SERVER_PORT", 7860)
|
9 |
+
ROOT_PATH = os.getenv("ROOT_PATH", None)
|
10 |
+
ELO_RESULTS_DIR = os.getenv("ELO_RESULTS_DIR", "./arena_elo/results/latest")
|
11 |
+
|
12 |
+
LOG_SERVER = os.getenv("LOG_SERVER", "http://127.0.0.1:22005")
|
13 |
+
LOG_SERVER_SUBDOMAIN = os.getenv("LOG_SERVER_SUBDOMAIN", "/ksort-logs")
|
14 |
+
LOG_SERVER_ADDR = os.getenv("LOG_SERVER_ADDR", f"{LOG_SERVER}{LOG_SERVER_SUBDOMAIN}")
|
15 |
+
# LOG SERVER API ENDPOINTS
|
16 |
+
APPEND_JSON = "append_json"
|
17 |
+
SAVE_IMAGE = "save_image"
|
18 |
+
SAVE_VIDEO = "save_video"
|
19 |
+
SAVE_LOG = "save_log"
|
20 |
+
|
21 |
+
SSH_MSCOCO = os.getenv("SSH_MSCOCO", "/root/MSCOCO_prompt")
|
22 |
+
|
23 |
+
|
24 |
+
SSH_SERVER = os.getenv("SSH_SERVER", "default_value")
|
25 |
+
SSH_PASSWORD = os.getenv("SSH_PASSWORD", "default_value")
|
26 |
+
SSH_PORT = os.getenv("SSH_PORT", "default_value")
|
27 |
+
SSH_USER = os.getenv("SSH_USER", "default_value")
|
28 |
+
SSH_LOG = os.getenv("SSH_LOG", "/home/zhendongucb/ksort/ksort_log/log")
|
29 |
+
SSH_VIDEO_LOG = os.getenv("SSH_LOG", "/home/zhendongucb/ksort/ksort_log/video_log")
|
30 |
+
SSH_SKILL = os.getenv("SSH_SKILL", "/home/zhendongucb/ksort/TrueSkill/trueskill_data.json")
|
31 |
+
SSH_VIDEO_SKILL = os.getenv("SSH_VIDEO_SKILL", "/home/zhendongucb/ksort/TrueSkill/trueskill_video_data.json")
|
32 |
+
|
33 |
+
SSH_CACHE_OPENSOURCE = os.getenv("SSH_CACHE_OPENSOURCE", "/home/zhendongucb/ksort/ksort_video_cache/Kaiyuan/")
|
34 |
+
SSH_CACHE_ADVANCE = os.getenv("SSH_CACHE_ADVANCE", "/home/zhendongucb/ksort/ksort_video_cache/Advance/")
|
35 |
+
SSH_CACHE_PIKA = os.getenv("SSH_CACHE_PIKA", "/home/zhendongucb/ksort/ksort_video_cache/Pika-Beta/")
|
36 |
+
SSH_CACHE_SORA = os.getenv("SSH_CACHE_SORA", "/home/zhendongucb/ksort/ksort_video_cache/Sora/")
|
37 |
+
|
38 |
+
SSH_CACHE_IMAGE = os.getenv("SSH_CACHE_IMAGE", "/home/zhendongucb/ksort/ksort_image_cache/")
|
39 |
+
|
40 |
+
# COLOR1=(128, 214, 255)
|
41 |
+
# COLOR2=(237, 247, 152)
|
42 |
+
# COLOR3=(250, 181, 122)
|
43 |
+
# COLOR4=(240, 104, 104)
|
44 |
+
|
45 |
+
# COLOR1=(112, 161, 215)
|
46 |
+
# COLOR2=(161, 222, 147)
|
47 |
+
# COLOR3=(247, 244, 139)
|
48 |
+
# COLOR4=(244, 124, 124)
|
49 |
+
|
50 |
+
COLOR1=(168, 230, 207)
|
51 |
+
COLOR2=(253, 255, 171)
|
52 |
+
COLOR3=(255, 211, 182)
|
53 |
+
COLOR4=(255, 170, 165)
|
54 |
+
|
55 |
+
# COLOR1=(255, 212, 96)
|
56 |
+
# COLOR2=(240, 123, 63)
|
57 |
+
# COLOR3=(234, 84, 85)
|
58 |
+
# COLOR4=(45, 64, 89)
|
59 |
+
|
60 |
+
# COLOR1=(255, 189, 57)
|
61 |
+
# COLOR2=(230, 28, 93)
|
62 |
+
# COLOR3=(147, 0, 119)
|
63 |
+
# COLOR4=(58, 0, 136)
|
serve/gradio_web.py
ADDED
@@ -0,0 +1,789 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .utils import *
|
2 |
+
from .vote_utils import (
|
3 |
+
upvote_last_response_ig as upvote_last_response,
|
4 |
+
downvote_last_response_ig as downvote_last_response,
|
5 |
+
flag_last_response_ig as flag_last_response,
|
6 |
+
leftvote_last_response_igm as leftvote_last_response,
|
7 |
+
left1vote_last_response_igm as left1vote_last_response,
|
8 |
+
rightvote_last_response_igm as rightvote_last_response,
|
9 |
+
right1vote_last_response_igm as right1vote_last_response,
|
10 |
+
tievote_last_response_igm as tievote_last_response,
|
11 |
+
bothbad_vote_last_response_igm as bothbad_vote_last_response,
|
12 |
+
share_click_igm as share_click,
|
13 |
+
generate_ig,
|
14 |
+
generate_ig_museum,
|
15 |
+
generate_igm,
|
16 |
+
generate_igm_museum,
|
17 |
+
generate_igm_annoy,
|
18 |
+
generate_igm_annoy_museum,
|
19 |
+
generate_igm_cache_annoy,
|
20 |
+
share_js
|
21 |
+
)
|
22 |
+
from .Ksort import (
|
23 |
+
add_foreground,
|
24 |
+
reset_level,
|
25 |
+
reset_rank,
|
26 |
+
revote_windows,
|
27 |
+
submit_response_igm,
|
28 |
+
submit_response_rank_igm,
|
29 |
+
reset_submit,
|
30 |
+
clear_rank,
|
31 |
+
reset_mode,
|
32 |
+
reset_chatbot,
|
33 |
+
reset_btn_rank,
|
34 |
+
reset_vote_text,
|
35 |
+
text_response_rank_igm,
|
36 |
+
check_textbox,
|
37 |
+
)
|
38 |
+
|
39 |
+
from functools import partial
|
40 |
+
from .upload import get_random_mscoco_prompt
|
41 |
+
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD
|
42 |
+
from serve.upload import get_random_mscoco_prompt, create_ssh_client
|
43 |
+
from serve.update_skill import create_ssh_skill_client
|
44 |
+
from model.matchmaker import create_ssh_matchmaker_client
|
45 |
+
def set_ssh():
|
46 |
+
create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
47 |
+
create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
48 |
+
create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
49 |
+
|
50 |
+
def build_side_by_side_ui_anony(models):
|
51 |
+
notice_markdown = """
|
52 |
+
# ⚔️ K-Sort Arena (Text-to-Image Generation) ⚔️
|
53 |
+
## 📜 Rules
|
54 |
+
- Input a prompt for four anonymized models (e.g., SD, SDXL, OpenJourney for Text-guided Image Generation) and vote on their outputs.
|
55 |
+
- Two voting modes available: Rank Mode and Best Mode. Switch freely between modes. Please note that ties are always allowed. In ranking mode, users can input rankings like 1 3 3 1. Any invalid rankings, such as 1 4 4 1, will be automatically corrected during post-processing.
|
56 |
+
- Users are encouraged to make evaluations based on subjective preferences. Evaluation criteria: Alignment (50%) + Aesthetics (50%).
|
57 |
+
- Alignment includes: Entity Matching (30%) + Style Matching (20%);
|
58 |
+
- Aesthetics includes: Photorealism (30%) + Light and Shadow (10%) + Absence of Artifacts (10%).
|
59 |
+
|
60 |
+
## 👇 Generating now!
|
61 |
+
- Note: Due to the API's image safety checks, errors may occur. If this happens, please re-enter a different prompt.
|
62 |
+
- At times, high API concurrency can cause congestion, potentially resulting in a generation time of up to 1.5 minutes per image. Thank you for your patience.
|
63 |
+
"""
|
64 |
+
|
65 |
+
model_list = models.model_ig_list
|
66 |
+
|
67 |
+
state0 = gr.State()
|
68 |
+
state1 = gr.State()
|
69 |
+
state2 = gr.State()
|
70 |
+
state3 = gr.State()
|
71 |
+
|
72 |
+
gen_func = partial(generate_igm_annoy, models.generate_image_ig_parallel_anony)
|
73 |
+
gen_cache_func = partial(generate_igm_cache_annoy, models.generate_image_ig_cache_anony)
|
74 |
+
|
75 |
+
# gen_func_random = partial(generate_igm_annoy_museum, models.generate_image_ig_museum_parallel_anony)
|
76 |
+
|
77 |
+
gr.Markdown(notice_markdown, elem_id="notice_markdown")
|
78 |
+
|
79 |
+
with gr.Group(elem_id="share-region-anony"):
|
80 |
+
with gr.Accordion("🔍 Expand to see all Arena players", open=False):
|
81 |
+
model_description_md = get_model_description_md(model_list)
|
82 |
+
gr.Markdown(model_description_md, elem_id="model_description_markdown")
|
83 |
+
with gr.Row():
|
84 |
+
with gr.Column():
|
85 |
+
chatbot_left = gr.Image(width=512, label = "Model A")
|
86 |
+
with gr.Column():
|
87 |
+
chatbot_left1 = gr.Image(width=512, label = "Model B")
|
88 |
+
with gr.Column():
|
89 |
+
chatbot_right = gr.Image(width=512, label = "Model C")
|
90 |
+
with gr.Column():
|
91 |
+
chatbot_right1 = gr.Image(width=512, label = "Model D")
|
92 |
+
|
93 |
+
with gr.Row():
|
94 |
+
with gr.Column():
|
95 |
+
model_selector_left = gr.Markdown("", visible=False)
|
96 |
+
with gr.Column():
|
97 |
+
model_selector_left1 = gr.Markdown("", visible=False)
|
98 |
+
with gr.Column():
|
99 |
+
model_selector_right = gr.Markdown("", visible=False)
|
100 |
+
with gr.Column():
|
101 |
+
model_selector_right1 = gr.Markdown("", visible=False)
|
102 |
+
with gr.Row():
|
103 |
+
slow_warning = gr.Markdown("", elem_id="notice_markdown")
|
104 |
+
|
105 |
+
with gr.Row(elem_classes="row"):
|
106 |
+
with gr.Column(scale=1, min_width=10):
|
107 |
+
leftvote_btn = gr.Button(
|
108 |
+
value="A is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
|
109 |
+
)
|
110 |
+
with gr.Column(scale=1, min_width=10):
|
111 |
+
left1vote_btn = gr.Button(
|
112 |
+
value="B is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
|
113 |
+
)
|
114 |
+
with gr.Column(scale=1, min_width=10):
|
115 |
+
rightvote_btn = gr.Button(
|
116 |
+
value="C is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
|
117 |
+
)
|
118 |
+
with gr.Column(scale=1, min_width=10):
|
119 |
+
right1vote_btn = gr.Button(
|
120 |
+
value="D is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
|
121 |
+
)
|
122 |
+
with gr.Column(scale=1, min_width=10):
|
123 |
+
tie_btn = gr.Button(
|
124 |
+
value="🤝 Tie", visible=False, interactive=False, elem_id="btncolor2", elem_classes="best-button"
|
125 |
+
)
|
126 |
+
|
127 |
+
with gr.Row():
|
128 |
+
with gr.Blocks():
|
129 |
+
with gr.Row():
|
130 |
+
with gr.Column(scale=1, min_width=10):
|
131 |
+
A1_btn = gr.Button(
|
132 |
+
value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
|
133 |
+
)
|
134 |
+
with gr.Column(scale=1, min_width=10):
|
135 |
+
A2_btn = gr.Button(
|
136 |
+
value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
|
137 |
+
)
|
138 |
+
with gr.Column(scale=1, min_width=10):
|
139 |
+
A3_btn = gr.Button(
|
140 |
+
value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
|
141 |
+
)
|
142 |
+
with gr.Column(scale=1, min_width=10):
|
143 |
+
A4_btn = gr.Button(
|
144 |
+
value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
|
145 |
+
)
|
146 |
+
with gr.Blocks():
|
147 |
+
with gr.Row():
|
148 |
+
with gr.Column(scale=1, min_width=10):
|
149 |
+
B1_btn = gr.Button(
|
150 |
+
value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
|
151 |
+
)
|
152 |
+
with gr.Column(scale=1, min_width=10):
|
153 |
+
B2_btn = gr.Button(
|
154 |
+
value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
|
155 |
+
)
|
156 |
+
with gr.Column(scale=1, min_width=10):
|
157 |
+
B3_btn = gr.Button(
|
158 |
+
value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
|
159 |
+
)
|
160 |
+
with gr.Column(scale=1, min_width=10):
|
161 |
+
B4_btn = gr.Button(
|
162 |
+
value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
|
163 |
+
)
|
164 |
+
with gr.Blocks():
|
165 |
+
with gr.Row():
|
166 |
+
with gr.Column(scale=1, min_width=10):
|
167 |
+
C1_btn = gr.Button(
|
168 |
+
value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
|
169 |
+
)
|
170 |
+
with gr.Column(scale=1, min_width=10):
|
171 |
+
C2_btn = gr.Button(
|
172 |
+
value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
|
173 |
+
)
|
174 |
+
with gr.Column(scale=1, min_width=10):
|
175 |
+
C3_btn = gr.Button(
|
176 |
+
value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
|
177 |
+
)
|
178 |
+
with gr.Column(scale=1, min_width=10):
|
179 |
+
C4_btn = gr.Button(
|
180 |
+
value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
|
181 |
+
)
|
182 |
+
with gr.Blocks():
|
183 |
+
with gr.Row():
|
184 |
+
with gr.Column(scale=1, min_width=10):
|
185 |
+
D1_btn = gr.Button(
|
186 |
+
value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
|
187 |
+
)
|
188 |
+
with gr.Column(scale=1, min_width=10):
|
189 |
+
D2_btn = gr.Button(
|
190 |
+
value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
|
191 |
+
)
|
192 |
+
with gr.Column(scale=1, min_width=10):
|
193 |
+
D3_btn = gr.Button(
|
194 |
+
value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
|
195 |
+
)
|
196 |
+
with gr.Column(scale=1, min_width=10):
|
197 |
+
D4_btn = gr.Button(
|
198 |
+
value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
|
199 |
+
)
|
200 |
+
|
201 |
+
with gr.Row():
|
202 |
+
vote_textbox = gr.Textbox(
|
203 |
+
show_label=False,
|
204 |
+
placeholder="👉 Enter your rank (you can use buttons above, or directly type here, e.g. 1 2 3 4)",
|
205 |
+
container=True,
|
206 |
+
elem_id="input_box",
|
207 |
+
visible=False,
|
208 |
+
)
|
209 |
+
vote_submit_btn = gr.Button(value="Submit", visible=False, interactive=False, variant="primary", scale=0, elem_id="btnpink", elem_classes="submit-button")
|
210 |
+
vote_mode_btn = gr.Button(value="🔄 Mode", visible=False, interactive=False, variant="primary", scale=0, elem_id="btnpink", elem_classes="submit-button")
|
211 |
+
|
212 |
+
with gr.Row():
|
213 |
+
textbox = gr.Textbox(
|
214 |
+
show_label=False,
|
215 |
+
placeholder="👉 Enter your prompt and press ENTER",
|
216 |
+
container=True,
|
217 |
+
elem_id="input_box",
|
218 |
+
)
|
219 |
+
# send_btn = gr.Button(value="Send", variant="primary", scale=0, elem_id="btnblue", elem_classes="send-button")
|
220 |
+
# draw_btn = gr.Button(value="🎲 Random sample", variant="primary", scale=0, elem_id="btnblue", elem_classes="send-button")
|
221 |
+
send_btn = gr.Button(value="Send", variant="primary", scale=0, elem_id="btnblue")
|
222 |
+
draw_btn = gr.Button(value="🎲 Random Prompt", variant="primary", scale=0, elem_id="btnblue")
|
223 |
+
with gr.Row():
|
224 |
+
cache_btn = gr.Button(value="🎲 Random Sample", interactive=True)
|
225 |
+
with gr.Row():
|
226 |
+
clear_btn = gr.Button(value="🎲 New Round", interactive=False)
|
227 |
+
# regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
228 |
+
# share_btn = gr.Button(value="📷 Share")
|
229 |
+
with gr.Blocks():
|
230 |
+
with gr.Row(elem_id="centered-text"): #
|
231 |
+
user_info = gr.Markdown("User information (to appear on the contributor leaderboard)", visible=True, elem_id="centered-text") #, elem_id="centered-text"
|
232 |
+
# with gr.Blocks():
|
233 |
+
# name = gr.Markdown("Name", visible=True)
|
234 |
+
user_name = gr.Textbox(show_label=False,placeholder="👉 Enter your name (optional)", elem_classes="custom-width")
|
235 |
+
# with gr.Blocks():
|
236 |
+
# institution = gr.Markdown("Institution", visible=True)
|
237 |
+
user_institution = gr.Textbox(show_label=False,placeholder="👉 Enter your affiliation (optional)", elem_classes="custom-width")
|
238 |
+
#gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
|
239 |
+
dummy_img_output = gr.Image(width=512, visible=False)
|
240 |
+
gr.Examples(
|
241 |
+
examples=[["A train crossing a bridge that is going over a body of water.", os.path.join("./examples", "example1.jpg")],
|
242 |
+
["The man in the business suit wears a striped blue and white tie.", os.path.join("./examples", "example2.jpg")],
|
243 |
+
["A skier stands on a small ledge in the snow.",os.path.join("./examples", "example3.jpg")],
|
244 |
+
["The bathroom with green tile and a red shower curtain.", os.path.join("./examples", "example4.jpg")]],
|
245 |
+
inputs = [textbox, dummy_img_output])
|
246 |
+
|
247 |
+
order_btn_list = [textbox, send_btn, draw_btn, cache_btn, clear_btn]
|
248 |
+
vote_order_list = [leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
|
249 |
+
A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \
|
250 |
+
vote_textbox, vote_submit_btn, vote_mode_btn]
|
251 |
+
|
252 |
+
generate_ig0 = gr.Image(width=512, label = "generate A", visible=False, interactive=False)
|
253 |
+
generate_ig1 = gr.Image(width=512, label = "generate B", visible=False, interactive=False)
|
254 |
+
generate_ig2 = gr.Image(width=512, label = "generate C", visible=False, interactive=False)
|
255 |
+
generate_ig3 = gr.Image(width=512, label = "generate D", visible=False, interactive=False)
|
256 |
+
dummy_left_model = gr.State("")
|
257 |
+
dummy_left1_model = gr.State("")
|
258 |
+
dummy_right_model = gr.State("")
|
259 |
+
dummy_right1_model = gr.State("")
|
260 |
+
|
261 |
+
ig_rank = [None, None, None, None]
|
262 |
+
bastA_rank = [0, 3, 3, 3]
|
263 |
+
bastB_rank = [3, 0, 3, 3]
|
264 |
+
bastC_rank = [3, 3, 0, 3]
|
265 |
+
bastD_rank = [3, 3, 3, 0]
|
266 |
+
tie_rank = [0, 0, 0, 0]
|
267 |
+
bad_rank = [3, 3, 3, 3]
|
268 |
+
rank = gr.State(ig_rank)
|
269 |
+
rankA = gr.State(bastA_rank)
|
270 |
+
rankB = gr.State(bastB_rank)
|
271 |
+
rankC = gr.State(bastC_rank)
|
272 |
+
rankD = gr.State(bastD_rank)
|
273 |
+
rankTie = gr.State(tie_rank)
|
274 |
+
rankBad = gr.State(bad_rank)
|
275 |
+
Top1_text = gr.Textbox(value="Top 1", visible=False, interactive=False)
|
276 |
+
Top2_text = gr.Textbox(value="Top 2", visible=False, interactive=False)
|
277 |
+
Top3_text = gr.Textbox(value="Top 3", visible=False, interactive=False)
|
278 |
+
Top4_text = gr.Textbox(value="Top 4", visible=False, interactive=False)
|
279 |
+
window1_text = gr.Textbox(value="Model A", visible=False, interactive=False)
|
280 |
+
window2_text = gr.Textbox(value="Model B", visible=False, interactive=False)
|
281 |
+
window3_text = gr.Textbox(value="Model C", visible=False, interactive=False)
|
282 |
+
window4_text = gr.Textbox(value="Model D", visible=False, interactive=False)
|
283 |
+
vote_level = gr.Number(value=0, visible=False, interactive=False)
|
284 |
+
# Top1_btn.click(reset_level, inputs=[Top1_text], outputs=[vote_level])
|
285 |
+
# Top2_btn.click(reset_level, inputs=[Top2_text], outputs=[vote_level])
|
286 |
+
# Top3_btn.click(reset_level, inputs=[Top3_text], outputs=[vote_level])
|
287 |
+
# Top4_btn.click(reset_level, inputs=[Top4_text], outputs=[vote_level])
|
288 |
+
vote_mode = gr.Textbox(value="Rank", visible=False, interactive=False)
|
289 |
+
right_vote_text = gr.Textbox(value="wrong", visible=False, interactive=False)
|
290 |
+
cache_mode = gr.Textbox(value="True", visible=False, interactive=False)
|
291 |
+
|
292 |
+
textbox.submit(
|
293 |
+
disable_order_buttons,
|
294 |
+
inputs=[textbox],
|
295 |
+
outputs=order_btn_list
|
296 |
+
).then(
|
297 |
+
gen_func,
|
298 |
+
inputs=[state0, state1, state2, state3, textbox, model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
|
299 |
+
outputs=[state0, state1, state2, state3, generate_ig0, generate_ig1, generate_ig2, generate_ig3, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \
|
300 |
+
model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
|
301 |
+
api_name="submit_btn_annony"
|
302 |
+
).then(
|
303 |
+
enable_vote_mode_buttons,
|
304 |
+
inputs=[vote_mode, textbox],
|
305 |
+
outputs=vote_order_list
|
306 |
+
)
|
307 |
+
|
308 |
+
send_btn.click(
|
309 |
+
disable_order_buttons,
|
310 |
+
inputs=[textbox],
|
311 |
+
outputs=order_btn_list
|
312 |
+
).then(
|
313 |
+
gen_func,
|
314 |
+
inputs=[state0, state1, state2, state3, textbox, model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
|
315 |
+
outputs=[state0, state1, state2, state3, generate_ig0, generate_ig1, generate_ig2, generate_ig3, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \
|
316 |
+
model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
|
317 |
+
api_name="send_btn_annony"
|
318 |
+
).then(
|
319 |
+
enable_vote_mode_buttons,
|
320 |
+
inputs=[vote_mode, textbox],
|
321 |
+
outputs=vote_order_list
|
322 |
+
)
|
323 |
+
|
324 |
+
cache_btn.click(
|
325 |
+
disable_order_buttons,
|
326 |
+
inputs=[textbox, cache_mode],
|
327 |
+
outputs=order_btn_list
|
328 |
+
).then(
|
329 |
+
gen_cache_func,
|
330 |
+
inputs=[state0, state1, state2, state3, model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
|
331 |
+
outputs=[state0, state1, state2, state3, generate_ig0, generate_ig1, generate_ig2, generate_ig3, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \
|
332 |
+
model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, textbox],
|
333 |
+
api_name="send_btn_annony"
|
334 |
+
).then(
|
335 |
+
enable_vote_mode_buttons,
|
336 |
+
inputs=[vote_mode, textbox],
|
337 |
+
outputs=vote_order_list
|
338 |
+
)
|
339 |
+
|
340 |
+
draw_btn.click(
|
341 |
+
get_random_mscoco_prompt,
|
342 |
+
inputs=None,
|
343 |
+
outputs=[textbox],
|
344 |
+
api_name="draw_btn_annony"
|
345 |
+
)
|
346 |
+
|
347 |
+
clear_btn.click(
|
348 |
+
clear_history_side_by_side_anony,
|
349 |
+
inputs=None,
|
350 |
+
outputs=[state0, state1, state2, state3, textbox, vote_textbox, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \
|
351 |
+
model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
|
352 |
+
api_name="clear_btn_annony"
|
353 |
+
).then(
|
354 |
+
enable_order_buttons,
|
355 |
+
inputs=None,
|
356 |
+
outputs=order_btn_list
|
357 |
+
).then(
|
358 |
+
clear_rank,
|
359 |
+
inputs=[rank, vote_level],
|
360 |
+
outputs=[rank, vote_level]
|
361 |
+
).then(
|
362 |
+
disable_vote_mode_buttons,
|
363 |
+
inputs=None,
|
364 |
+
outputs=vote_order_list
|
365 |
+
)
|
366 |
+
|
367 |
+
# regenerate_btn.click(
|
368 |
+
# gen_func,
|
369 |
+
# inputs=[state0, state1, state2, state3, textbox, model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
|
370 |
+
# outputs=[state0, state1, state2, state3, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \
|
371 |
+
# model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
|
372 |
+
# api_name="regenerate_btn_annony"
|
373 |
+
# ).then(
|
374 |
+
# enable_best_buttons,
|
375 |
+
# inputs=None,
|
376 |
+
# outputs=btn_list
|
377 |
+
# )
|
378 |
+
vote_mode_btn.click(
|
379 |
+
reset_chatbot,
|
380 |
+
inputs=[vote_mode, generate_ig0, generate_ig1, generate_ig2, generate_ig3],
|
381 |
+
outputs=[chatbot_left, chatbot_left1, chatbot_right, chatbot_right1]
|
382 |
+
).then(
|
383 |
+
reset_mode,
|
384 |
+
inputs=[vote_mode],
|
385 |
+
outputs=[leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
|
386 |
+
A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \
|
387 |
+
vote_textbox, vote_submit_btn, vote_mode_btn, vote_mode]
|
388 |
+
)
|
389 |
+
|
390 |
+
vote_textbox.submit(
|
391 |
+
disable_vote,
|
392 |
+
inputs=None,
|
393 |
+
outputs=[vote_submit_btn, vote_mode_btn, \
|
394 |
+
A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn]
|
395 |
+
).then(
|
396 |
+
text_response_rank_igm,
|
397 |
+
inputs=[generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox],
|
398 |
+
outputs=[chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, vote_textbox, right_vote_text, rank]
|
399 |
+
).then(
|
400 |
+
submit_response_rank_igm,
|
401 |
+
inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rank, right_vote_text, user_name, user_institution],
|
402 |
+
outputs=[A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \
|
403 |
+
vote_textbox, vote_submit_btn, vote_mode_btn, right_vote_text, \
|
404 |
+
model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
|
405 |
+
api_name="submit_btn_annony"
|
406 |
+
)
|
407 |
+
vote_submit_btn.click(
|
408 |
+
disable_vote,
|
409 |
+
inputs=None,
|
410 |
+
outputs=[vote_submit_btn, vote_mode_btn, \
|
411 |
+
A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn]
|
412 |
+
).then(
|
413 |
+
text_response_rank_igm,
|
414 |
+
inputs=[generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox],
|
415 |
+
outputs=[chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, vote_textbox, right_vote_text, rank]
|
416 |
+
).then(
|
417 |
+
submit_response_rank_igm,
|
418 |
+
inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rank, right_vote_text, user_name, user_institution],
|
419 |
+
outputs=[A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \
|
420 |
+
vote_textbox, vote_submit_btn, vote_mode_btn, right_vote_text, \
|
421 |
+
model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
|
422 |
+
api_name="submit_btn_annony"
|
423 |
+
)
|
424 |
+
|
425 |
+
# Revote_btn.click(
|
426 |
+
# revote_windows,
|
427 |
+
# inputs=[generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level],
|
428 |
+
# outputs=[chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, rank, vote_level]
|
429 |
+
# ).then(
|
430 |
+
# reset_submit,
|
431 |
+
# inputs = [rank],
|
432 |
+
# outputs = [Submit_btn]
|
433 |
+
# )
|
434 |
+
# Submit_btn.click(
|
435 |
+
# submit_response_igm,
|
436 |
+
# inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, rank],
|
437 |
+
# outputs=[textbox, Top1_btn, Top2_btn, Top3_btn, Top4_btn, Revote_btn, Submit_btn, \
|
438 |
+
# model_selector_left, model_selector_left1, model_selector_right, model_selector_right1]
|
439 |
+
# )
|
440 |
+
|
441 |
+
# chatbot_left.select(add_foreground, inputs=[generate_ig0, vote_level, Top1_text, Top2_text, Top3_text, Top4_text], outputs=[chatbot_left]).then(
|
442 |
+
# reset_rank,
|
443 |
+
# inputs = [window1_text, rank, vote_level],
|
444 |
+
# outputs = [rank]
|
445 |
+
# ).then(
|
446 |
+
# reset_submit,
|
447 |
+
# inputs = [rank],
|
448 |
+
# outputs = [Submit_btn]
|
449 |
+
# )
|
450 |
+
# chatbot_left1.select(add_foreground, inputs=[generate_ig1, vote_level, Top1_text, Top2_text, Top3_text, Top4_text], outputs=[chatbot_left1]).then(
|
451 |
+
# reset_rank,
|
452 |
+
# inputs = [window2_text, rank, vote_level],
|
453 |
+
# outputs = [rank]
|
454 |
+
# ).then(
|
455 |
+
# reset_submit,
|
456 |
+
# inputs = [rank],
|
457 |
+
# outputs = [Submit_btn]
|
458 |
+
# )
|
459 |
+
# chatbot_right.select(add_foreground, inputs=[generate_ig2, vote_level, Top1_text, Top2_text, Top3_text, Top4_text], outputs=[chatbot_right]).then(
|
460 |
+
# reset_rank,
|
461 |
+
# inputs = [window3_text, rank, vote_level],
|
462 |
+
# outputs = [rank]
|
463 |
+
# ).then(
|
464 |
+
# reset_submit,
|
465 |
+
# inputs = [rank],
|
466 |
+
# outputs = [Submit_btn]
|
467 |
+
# )
|
468 |
+
# chatbot_right1.select(add_foreground, inputs=[generate_ig3, vote_level, Top1_text, Top2_text, Top3_text, Top4_text], outputs=[chatbot_right1]).then(
|
469 |
+
# reset_rank,
|
470 |
+
# inputs = [window4_text, rank, vote_level],
|
471 |
+
# outputs = [rank]
|
472 |
+
# ).then(
|
473 |
+
# reset_submit,
|
474 |
+
# inputs = [rank],
|
475 |
+
# outputs = [Submit_btn]
|
476 |
+
# )
|
477 |
+
|
478 |
+
|
479 |
+
leftvote_btn.click(
|
480 |
+
submit_response_igm,
|
481 |
+
inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rankA, user_name, user_institution],
|
482 |
+
outputs=[textbox, leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
|
483 |
+
model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, \
|
484 |
+
vote_mode_btn]
|
485 |
+
)
|
486 |
+
left1vote_btn.click(
|
487 |
+
submit_response_igm,
|
488 |
+
inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rankB, user_name, user_institution],
|
489 |
+
outputs=[textbox, leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
|
490 |
+
model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, \
|
491 |
+
vote_mode_btn]
|
492 |
+
)
|
493 |
+
rightvote_btn.click(
|
494 |
+
submit_response_igm,
|
495 |
+
inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rankC, user_name, user_institution],
|
496 |
+
outputs=[textbox, leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
|
497 |
+
model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, \
|
498 |
+
vote_mode_btn]
|
499 |
+
)
|
500 |
+
right1vote_btn.click(
|
501 |
+
submit_response_igm,
|
502 |
+
inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rankD, user_name, user_institution],
|
503 |
+
outputs=[textbox, leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
|
504 |
+
model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, \
|
505 |
+
vote_mode_btn]
|
506 |
+
)
|
507 |
+
tie_btn.click(
|
508 |
+
submit_response_igm,
|
509 |
+
inputs=[state0, state1, state2, state3, dummy_left_model, dummy_left1_model, dummy_right_model, dummy_right1_model, textbox, rankTie, user_name, user_institution],
|
510 |
+
outputs=[textbox, leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
|
511 |
+
model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, \
|
512 |
+
vote_mode_btn]
|
513 |
+
)
|
514 |
+
|
515 |
+
A1_btn.click(
|
516 |
+
reset_btn_rank,
|
517 |
+
inputs=[window1_text, rank, A1_btn, vote_level],
|
518 |
+
outputs=[rank, vote_level]
|
519 |
+
).then(
|
520 |
+
add_foreground,
|
521 |
+
inputs=[generate_ig0, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
522 |
+
outputs=[chatbot_left]
|
523 |
+
).then(
|
524 |
+
reset_submit,
|
525 |
+
inputs = [rank],
|
526 |
+
outputs = [vote_submit_btn]
|
527 |
+
).then(
|
528 |
+
reset_vote_text,
|
529 |
+
inputs = [rank],
|
530 |
+
outputs = [vote_textbox]
|
531 |
+
)
|
532 |
+
A2_btn.click(
|
533 |
+
reset_btn_rank,
|
534 |
+
inputs=[window1_text, rank, A2_btn, vote_level],
|
535 |
+
outputs=[rank, vote_level]
|
536 |
+
).then(
|
537 |
+
add_foreground,
|
538 |
+
inputs=[generate_ig0, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
539 |
+
outputs=[chatbot_left]
|
540 |
+
).then(
|
541 |
+
reset_submit,
|
542 |
+
inputs = [rank],
|
543 |
+
outputs = [vote_submit_btn]
|
544 |
+
).then(
|
545 |
+
reset_vote_text,
|
546 |
+
inputs = [rank],
|
547 |
+
outputs = [vote_textbox]
|
548 |
+
)
|
549 |
+
A3_btn.click(
|
550 |
+
reset_btn_rank,
|
551 |
+
inputs=[window1_text, rank, A3_btn, vote_level],
|
552 |
+
outputs=[rank, vote_level]
|
553 |
+
).then(
|
554 |
+
add_foreground,
|
555 |
+
inputs=[generate_ig0, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
556 |
+
outputs=[chatbot_left]
|
557 |
+
).then(
|
558 |
+
reset_submit,
|
559 |
+
inputs = [rank],
|
560 |
+
outputs = [vote_submit_btn]
|
561 |
+
).then(
|
562 |
+
reset_vote_text,
|
563 |
+
inputs = [rank],
|
564 |
+
outputs = [vote_textbox]
|
565 |
+
)
|
566 |
+
A4_btn.click(
|
567 |
+
reset_btn_rank,
|
568 |
+
inputs=[window1_text, rank, A4_btn, vote_level],
|
569 |
+
outputs=[rank, vote_level]
|
570 |
+
).then(
|
571 |
+
add_foreground,
|
572 |
+
inputs=[generate_ig0, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
573 |
+
outputs=[chatbot_left]
|
574 |
+
).then(
|
575 |
+
reset_submit,
|
576 |
+
inputs = [rank],
|
577 |
+
outputs = [vote_submit_btn]
|
578 |
+
).then(
|
579 |
+
reset_vote_text,
|
580 |
+
inputs = [rank],
|
581 |
+
outputs = [vote_textbox]
|
582 |
+
)
|
583 |
+
|
584 |
+
B1_btn.click(
|
585 |
+
reset_btn_rank,
|
586 |
+
inputs=[window2_text, rank, B1_btn, vote_level],
|
587 |
+
outputs=[rank, vote_level]
|
588 |
+
).then(
|
589 |
+
add_foreground,
|
590 |
+
inputs=[generate_ig1, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
591 |
+
outputs=[chatbot_left1]
|
592 |
+
).then(
|
593 |
+
reset_submit,
|
594 |
+
inputs = [rank],
|
595 |
+
outputs = [vote_submit_btn]
|
596 |
+
).then(
|
597 |
+
reset_vote_text,
|
598 |
+
inputs = [rank],
|
599 |
+
outputs = [vote_textbox]
|
600 |
+
)
|
601 |
+
B2_btn.click(
|
602 |
+
reset_btn_rank,
|
603 |
+
inputs=[window2_text, rank, B2_btn, vote_level],
|
604 |
+
outputs=[rank, vote_level]
|
605 |
+
).then(
|
606 |
+
add_foreground,
|
607 |
+
inputs=[generate_ig1, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
608 |
+
outputs=[chatbot_left1]
|
609 |
+
).then(
|
610 |
+
reset_submit,
|
611 |
+
inputs = [rank],
|
612 |
+
outputs = [vote_submit_btn]
|
613 |
+
).then(
|
614 |
+
reset_vote_text,
|
615 |
+
inputs = [rank],
|
616 |
+
outputs = [vote_textbox]
|
617 |
+
)
|
618 |
+
B3_btn.click(
|
619 |
+
reset_btn_rank,
|
620 |
+
inputs=[window2_text, rank, B3_btn, vote_level],
|
621 |
+
outputs=[rank, vote_level]
|
622 |
+
).then(
|
623 |
+
add_foreground,
|
624 |
+
inputs=[generate_ig1, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
625 |
+
outputs=[chatbot_left1]
|
626 |
+
).then(
|
627 |
+
reset_submit,
|
628 |
+
inputs = [rank],
|
629 |
+
outputs = [vote_submit_btn]
|
630 |
+
).then(
|
631 |
+
reset_vote_text,
|
632 |
+
inputs = [rank],
|
633 |
+
outputs = [vote_textbox]
|
634 |
+
)
|
635 |
+
B4_btn.click(
|
636 |
+
reset_btn_rank,
|
637 |
+
inputs=[window2_text, rank, B4_btn, vote_level],
|
638 |
+
outputs=[rank, vote_level]
|
639 |
+
).then(
|
640 |
+
add_foreground,
|
641 |
+
inputs=[generate_ig1, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
642 |
+
outputs=[chatbot_left1]
|
643 |
+
).then(
|
644 |
+
reset_submit,
|
645 |
+
inputs = [rank],
|
646 |
+
outputs = [vote_submit_btn]
|
647 |
+
).then(
|
648 |
+
reset_vote_text,
|
649 |
+
inputs = [rank],
|
650 |
+
outputs = [vote_textbox]
|
651 |
+
)
|
652 |
+
|
653 |
+
C1_btn.click(
|
654 |
+
reset_btn_rank,
|
655 |
+
inputs=[window3_text, rank, C1_btn, vote_level],
|
656 |
+
outputs=[rank, vote_level]
|
657 |
+
).then(
|
658 |
+
add_foreground,
|
659 |
+
inputs=[generate_ig2, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
660 |
+
outputs=[chatbot_right]
|
661 |
+
).then(
|
662 |
+
reset_submit,
|
663 |
+
inputs = [rank],
|
664 |
+
outputs = [vote_submit_btn]
|
665 |
+
).then(
|
666 |
+
reset_vote_text,
|
667 |
+
inputs = [rank],
|
668 |
+
outputs = [vote_textbox]
|
669 |
+
)
|
670 |
+
C2_btn.click(
|
671 |
+
reset_btn_rank,
|
672 |
+
inputs=[window3_text, rank, C2_btn, vote_level],
|
673 |
+
outputs=[rank, vote_level]
|
674 |
+
).then(
|
675 |
+
add_foreground,
|
676 |
+
inputs=[generate_ig2, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
677 |
+
outputs=[chatbot_right]
|
678 |
+
).then(
|
679 |
+
reset_submit,
|
680 |
+
inputs = [rank],
|
681 |
+
outputs = [vote_submit_btn]
|
682 |
+
).then(
|
683 |
+
reset_vote_text,
|
684 |
+
inputs = [rank],
|
685 |
+
outputs = [vote_textbox]
|
686 |
+
)
|
687 |
+
C3_btn.click(
|
688 |
+
reset_btn_rank,
|
689 |
+
inputs=[window3_text, rank, C3_btn, vote_level],
|
690 |
+
outputs=[rank, vote_level]
|
691 |
+
).then(
|
692 |
+
add_foreground,
|
693 |
+
inputs=[generate_ig2, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
694 |
+
outputs=[chatbot_right]
|
695 |
+
).then(
|
696 |
+
reset_submit,
|
697 |
+
inputs = [rank],
|
698 |
+
outputs = [vote_submit_btn]
|
699 |
+
).then(
|
700 |
+
reset_vote_text,
|
701 |
+
inputs = [rank],
|
702 |
+
outputs = [vote_textbox]
|
703 |
+
)
|
704 |
+
C4_btn.click(
|
705 |
+
reset_btn_rank,
|
706 |
+
inputs=[window3_text, rank, C4_btn, vote_level],
|
707 |
+
outputs=[rank, vote_level]
|
708 |
+
).then(
|
709 |
+
add_foreground,
|
710 |
+
inputs=[generate_ig2, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
711 |
+
outputs=[chatbot_right]
|
712 |
+
).then(
|
713 |
+
reset_submit,
|
714 |
+
inputs = [rank],
|
715 |
+
outputs = [vote_submit_btn]
|
716 |
+
).then(
|
717 |
+
reset_vote_text,
|
718 |
+
inputs = [rank],
|
719 |
+
outputs = [vote_textbox]
|
720 |
+
)
|
721 |
+
|
722 |
+
D1_btn.click(
|
723 |
+
reset_btn_rank,
|
724 |
+
inputs=[window4_text, rank, D1_btn, vote_level],
|
725 |
+
outputs=[rank, vote_level]
|
726 |
+
).then(
|
727 |
+
add_foreground,
|
728 |
+
inputs=[generate_ig3, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
729 |
+
outputs=[chatbot_right1]
|
730 |
+
).then(
|
731 |
+
reset_submit,
|
732 |
+
inputs = [rank],
|
733 |
+
outputs = [vote_submit_btn]
|
734 |
+
).then(
|
735 |
+
reset_vote_text,
|
736 |
+
inputs = [rank],
|
737 |
+
outputs = [vote_textbox]
|
738 |
+
)
|
739 |
+
D2_btn.click(
|
740 |
+
reset_btn_rank,
|
741 |
+
inputs=[window4_text, rank, D2_btn, vote_level],
|
742 |
+
outputs=[rank, vote_level]
|
743 |
+
).then(
|
744 |
+
add_foreground,
|
745 |
+
inputs=[generate_ig3, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
746 |
+
outputs=[chatbot_right1]
|
747 |
+
).then(
|
748 |
+
reset_submit,
|
749 |
+
inputs = [rank],
|
750 |
+
outputs = [vote_submit_btn]
|
751 |
+
).then(
|
752 |
+
reset_vote_text,
|
753 |
+
inputs = [rank],
|
754 |
+
outputs = [vote_textbox]
|
755 |
+
)
|
756 |
+
D3_btn.click(
|
757 |
+
reset_btn_rank,
|
758 |
+
inputs=[window4_text, rank, D3_btn, vote_level],
|
759 |
+
outputs=[rank, vote_level]
|
760 |
+
).then(
|
761 |
+
add_foreground,
|
762 |
+
inputs=[generate_ig3, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
763 |
+
outputs=[chatbot_right1]
|
764 |
+
).then(
|
765 |
+
reset_submit,
|
766 |
+
inputs = [rank],
|
767 |
+
outputs = [vote_submit_btn]
|
768 |
+
).then(
|
769 |
+
reset_vote_text,
|
770 |
+
inputs = [rank],
|
771 |
+
outputs = [vote_textbox]
|
772 |
+
)
|
773 |
+
D4_btn.click(
|
774 |
+
reset_btn_rank,
|
775 |
+
inputs=[window4_text, rank, D4_btn, vote_level],
|
776 |
+
outputs=[rank, vote_level]
|
777 |
+
).then(
|
778 |
+
add_foreground,
|
779 |
+
inputs=[generate_ig3, vote_level, Top1_text, Top2_text, Top3_text, Top4_text],
|
780 |
+
outputs=[chatbot_right1]
|
781 |
+
).then(
|
782 |
+
reset_submit,
|
783 |
+
inputs = [rank],
|
784 |
+
outputs = [vote_submit_btn]
|
785 |
+
).then(
|
786 |
+
reset_vote_text,
|
787 |
+
inputs = [rank],
|
788 |
+
outputs = [vote_textbox]
|
789 |
+
)
|
serve/gradio_web_bbox.py
ADDED
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import time
|
3 |
+
import numpy as np
|
4 |
+
from gradio import processing_utils
|
5 |
+
from PIL import Image, ImageDraw, ImageFont
|
6 |
+
|
7 |
+
from serve.utils import *
|
8 |
+
from serve.vote_utils import (
|
9 |
+
upvote_last_response_ig as upvote_last_response,
|
10 |
+
downvote_last_response_ig as downvote_last_response,
|
11 |
+
flag_last_response_ig as flag_last_response,
|
12 |
+
leftvote_last_response_igm as leftvote_last_response,
|
13 |
+
left1vote_last_response_igm as left1vote_last_response,
|
14 |
+
rightvote_last_response_igm as rightvote_last_response,
|
15 |
+
right1vote_last_response_igm as right1vote_last_response,
|
16 |
+
tievote_last_response_igm as tievote_last_response,
|
17 |
+
bothbad_vote_last_response_igm as bothbad_vote_last_response,
|
18 |
+
share_click_igm as share_click,
|
19 |
+
generate_ig,
|
20 |
+
generate_ig_museum,
|
21 |
+
generate_igm,
|
22 |
+
generate_igm_museum,
|
23 |
+
generate_igm_annoy,
|
24 |
+
generate_igm_annoy_museum,
|
25 |
+
generate_igm_cache_annoy,
|
26 |
+
share_js
|
27 |
+
)
|
28 |
+
from serve.Ksort import (
|
29 |
+
add_foreground,
|
30 |
+
reset_level,
|
31 |
+
reset_rank,
|
32 |
+
revote_windows,
|
33 |
+
submit_response_igm,
|
34 |
+
submit_response_rank_igm,
|
35 |
+
reset_submit,
|
36 |
+
clear_rank,
|
37 |
+
reset_mode,
|
38 |
+
reset_chatbot,
|
39 |
+
reset_btn_rank,
|
40 |
+
reset_vote_text,
|
41 |
+
text_response_rank_igm,
|
42 |
+
check_textbox,
|
43 |
+
)
|
44 |
+
|
45 |
+
from functools import partial
|
46 |
+
from serve.upload import get_random_mscoco_prompt
|
47 |
+
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD
|
48 |
+
from serve.upload import get_random_mscoco_prompt, create_ssh_client
|
49 |
+
from serve.update_skill import create_ssh_skill_client
|
50 |
+
from model.matchmaker import create_ssh_matchmaker_client
|
51 |
+
|
52 |
+
def set_ssh():
|
53 |
+
create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
54 |
+
create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
55 |
+
create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
56 |
+
|
57 |
+
def binarize(x):
|
58 |
+
return (x != 0).astype('uint8') * 255
|
59 |
+
|
60 |
+
def sized_center_crop(img, cropx, cropy):
|
61 |
+
y, x = img.shape[:2]
|
62 |
+
startx = x // 2 - (cropx // 2)
|
63 |
+
starty = y // 2 - (cropy // 2)
|
64 |
+
return img[starty:starty+cropy, startx:startx+cropx]
|
65 |
+
|
66 |
+
def sized_center_fill(img, fill, cropx, cropy):
|
67 |
+
y, x = img.shape[:2]
|
68 |
+
startx = x // 2 - (cropx // 2)
|
69 |
+
starty = y // 2 - (cropy // 2)
|
70 |
+
img[starty:starty+cropy, startx:startx+cropx] = fill
|
71 |
+
return img
|
72 |
+
|
73 |
+
def sized_center_mask(img, cropx, cropy):
|
74 |
+
y, x = img.shape[:2]
|
75 |
+
startx = x // 2 - (cropx // 2)
|
76 |
+
starty = y // 2 - (cropy // 2)
|
77 |
+
center_region = img[starty:starty+cropy, startx:startx+cropx].copy()
|
78 |
+
img = (img * 0.2).astype('uint8')
|
79 |
+
img[starty:starty+cropy, startx:startx+cropx] = center_region
|
80 |
+
return img
|
81 |
+
|
82 |
+
def center_crop(img, HW=None, tgt_size=(512, 512)):
|
83 |
+
if HW is None:
|
84 |
+
H, W = img.shape[:2]
|
85 |
+
HW = min(H, W)
|
86 |
+
img = sized_center_crop(img, HW, HW)
|
87 |
+
img = Image.fromarray(img)
|
88 |
+
img = img.resize(tgt_size)
|
89 |
+
return np.array(img)
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
def draw_box(boxes=[], labels=[], img=None):
|
94 |
+
if len(boxes) == 0 and img is None:
|
95 |
+
return None
|
96 |
+
|
97 |
+
# 确保输入的 img 是 PIL.Image 对象
|
98 |
+
if isinstance(img, np.ndarray):
|
99 |
+
img = Image.fromarray(img)
|
100 |
+
|
101 |
+
if img is None:
|
102 |
+
img = Image.new('RGB', (512, 512), (255, 255, 255))
|
103 |
+
|
104 |
+
colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
|
105 |
+
|
106 |
+
# 创建绘图对象
|
107 |
+
draw_obj = ImageDraw.Draw(img)
|
108 |
+
font = ImageFont.load_default()
|
109 |
+
# 获取字体大小
|
110 |
+
font_size = getattr(font, 'size_in_points', 10) # 如果没有 size_in_points 属性,使用默认值 10
|
111 |
+
|
112 |
+
for bid, box in enumerate(boxes):
|
113 |
+
draw_obj.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
|
114 |
+
anno_text = labels[bid]
|
115 |
+
draw_obj.rectangle(
|
116 |
+
[
|
117 |
+
box[0],
|
118 |
+
box[3] - int(font_size * 1.2),
|
119 |
+
box[0] + int((len(anno_text) + 0.8) * font_size * 0.6),
|
120 |
+
box[3]
|
121 |
+
],
|
122 |
+
outline=colors[bid % len(colors)],
|
123 |
+
fill=colors[bid % len(colors)],
|
124 |
+
width=4
|
125 |
+
)
|
126 |
+
draw_obj.text(
|
127 |
+
[box[0] + int(font_size * 0.2), box[3] - int(font_size * 1.2)],
|
128 |
+
anno_text,
|
129 |
+
font=font,
|
130 |
+
fill=(255,255,255)
|
131 |
+
)
|
132 |
+
return img
|
133 |
+
|
134 |
+
|
135 |
+
def draw(input, grounding_texts, new_image_trigger, state):
|
136 |
+
|
137 |
+
# 确保输入数据中有必要的键
|
138 |
+
if isinstance(input, dict):
|
139 |
+
background = input.get('background', None)
|
140 |
+
print("background.shape", background.shape)
|
141 |
+
layers = input.get('layers', 1)
|
142 |
+
print("len(layers)", len(layers))
|
143 |
+
composite = input.get('composite', None)
|
144 |
+
print("composite.shape", composite.shape)
|
145 |
+
else:
|
146 |
+
# 如果 input 不是字典,直接使用其作为图像
|
147 |
+
background = input
|
148 |
+
layers = 1
|
149 |
+
composite = None
|
150 |
+
|
151 |
+
# 检查 background 是否有效
|
152 |
+
if background is None:
|
153 |
+
print("background is None")
|
154 |
+
return None
|
155 |
+
# background = np.ones((512, 512, 3), dtype='uint8') * 255
|
156 |
+
|
157 |
+
# 默认使用 composite 作为最终图像,如果没有 composite 则使用 background
|
158 |
+
if composite is None:
|
159 |
+
print("composite is None")
|
160 |
+
image = background
|
161 |
+
else:
|
162 |
+
image = composite
|
163 |
+
|
164 |
+
mask = binarize(image)
|
165 |
+
|
166 |
+
if type(mask) != np.ndarray:
|
167 |
+
mask = np.array(mask)
|
168 |
+
|
169 |
+
if mask.sum() == 0:
|
170 |
+
state = {}
|
171 |
+
|
172 |
+
# 更新状态,如果没有 boxes 和 masks,则初始化它们
|
173 |
+
if 'boxes' not in state:
|
174 |
+
state['boxes'] = []
|
175 |
+
|
176 |
+
if 'masks' not in state or len(state['masks']) == 0:
|
177 |
+
state['masks'] = []
|
178 |
+
last_mask = np.zeros_like(mask)
|
179 |
+
else:
|
180 |
+
last_mask = state['masks'][-1]
|
181 |
+
|
182 |
+
if type(mask) == np.ndarray and mask.size > 1:
|
183 |
+
diff_mask = mask - last_mask
|
184 |
+
else:
|
185 |
+
diff_mask = np.zeros([])
|
186 |
+
|
187 |
+
# 根据 mask 的变化来计算 box 的位置
|
188 |
+
if diff_mask.sum() > 0:
|
189 |
+
x1x2 = np.where(diff_mask.max(0) != 0)[0]
|
190 |
+
y1y2 = np.where(diff_mask.max(1) != 0)[0]
|
191 |
+
y1, y2 = y1y2.min(), y1y2.max()
|
192 |
+
x1, x2 = x1x2.min(), x1x2.max()
|
193 |
+
|
194 |
+
if (x2 - x1 > 5) and (y2 - y1 > 5):
|
195 |
+
state['masks'].append(mask.copy())
|
196 |
+
state['boxes'].append((x1, y1, x2, y2))
|
197 |
+
|
198 |
+
# 处理 grounding_texts
|
199 |
+
grounding_texts = [x.strip() for x in grounding_texts.split(';')]
|
200 |
+
grounding_texts = [x for x in grounding_texts if len(x) > 0]
|
201 |
+
if len(grounding_texts) < len(state['boxes']):
|
202 |
+
grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(state['boxes']))]
|
203 |
+
|
204 |
+
# 绘制标注框
|
205 |
+
box_image = draw_box(state['boxes'], grounding_texts, background)
|
206 |
+
|
207 |
+
if box_image is not None and state.get('inpaint_hw', None):
|
208 |
+
inpaint_hw = state['inpaint_hw']
|
209 |
+
box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw)))
|
210 |
+
original_image = state['original_image'].copy()
|
211 |
+
box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw)
|
212 |
+
|
213 |
+
return [box_image, new_image_trigger, 1.0, state]
|
214 |
+
|
215 |
+
def build_side_by_side_bbox_ui_anony(models):
|
216 |
+
notice_markdown = """
|
217 |
+
# ⚔️ Control-Ability-Arena (Bbox-to-Image Generation) ⚔️
|
218 |
+
## 📜 Rules
|
219 |
+
- Input a prompt for four anonymized models and vote on their outputs.
|
220 |
+
- Two voting modes available: Rank Mode and Best Mode. Switch freely between modes. Please note that ties are always allowed. In ranking mode, users can input rankings like 1 3 3 1. Any invalid rankings, such as 1 4 4 1, will be automatically corrected during post-processing.
|
221 |
+
- Users are encouraged to make evaluations based on subjective preferences. Evaluation criteria: Alignment (50%) + Aesthetics (50%).
|
222 |
+
- Alignment includes: Entity Matching (30%) + Style Matching (20%);
|
223 |
+
- Aesthetics includes: Photorealism (30%) + Light and Shadow (10%) + Absence of Artifacts (10%).
|
224 |
+
|
225 |
+
## 👇 Generating now!
|
226 |
+
- Note: Due to the API's image safety checks, errors may occur. If this happens, please re-enter a different prompt.
|
227 |
+
- At times, high API concurrency can cause congestion, potentially resulting in a generation time of up to 1.5 minutes per image. Thank you for your patience.
|
228 |
+
"""
|
229 |
+
model_list = models.model_b2i_list
|
230 |
+
|
231 |
+
state = gr.State({})
|
232 |
+
state0 = gr.State()
|
233 |
+
state1 = gr.State()
|
234 |
+
state2 = gr.State()
|
235 |
+
state3 = gr.State()
|
236 |
+
|
237 |
+
# gen_func = partial(generate_igm_annoy, models.generate_image_ig_parallel_anony)
|
238 |
+
# gen_cache_func = partial(generate_igm_cache_annoy, models.generate_image_ig_cache_anony)
|
239 |
+
|
240 |
+
|
241 |
+
gr.Markdown(notice_markdown, elem_id="notice_markdown")
|
242 |
+
|
243 |
+
|
244 |
+
with gr.Row():
|
245 |
+
sketch_pad_trigger = gr.Number(value=0, visible=False)
|
246 |
+
sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
|
247 |
+
image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
|
248 |
+
|
249 |
+
with gr.Row():
|
250 |
+
sketch_pad = gr.ImageEditor(
|
251 |
+
label="Sketch Pad",
|
252 |
+
type="numpy",
|
253 |
+
crop_size="1:1",
|
254 |
+
width=512,
|
255 |
+
height=512
|
256 |
+
)
|
257 |
+
out_imagebox = gr.Image(
|
258 |
+
type="pil",
|
259 |
+
label="Parsed Sketch Pad",
|
260 |
+
width=512,
|
261 |
+
height=512
|
262 |
+
)
|
263 |
+
|
264 |
+
with gr.Row():
|
265 |
+
textbox = gr.Textbox(
|
266 |
+
show_label=False,
|
267 |
+
placeholder="👉 Enter your prompt and press ENTER",
|
268 |
+
container=True,
|
269 |
+
elem_id="input_box",
|
270 |
+
)
|
271 |
+
send_btn = gr.Button(value="Send", variant="primary", scale=0, elem_id="btnblue")
|
272 |
+
|
273 |
+
with gr.Row():
|
274 |
+
grounding_instruction = gr.Textbox(
|
275 |
+
label="Grounding instruction (Separated by semicolon)",
|
276 |
+
placeholder="👉 Enter your Grounding instruction (e.g. a cat; a dog; a bird; a fish)",
|
277 |
+
)
|
278 |
+
|
279 |
+
with gr.Group(elem_id="share-region-anony"):
|
280 |
+
with gr.Accordion("🔍 Expand to see all Arena players", open=False):
|
281 |
+
# model_description_md = get_model_description_md(model_list)
|
282 |
+
gr.Markdown("", elem_id="model_description_markdown")
|
283 |
+
|
284 |
+
|
285 |
+
with gr.Row():
|
286 |
+
with gr.Column():
|
287 |
+
chatbot_left = gr.Image(width=512, label = "Model A")
|
288 |
+
with gr.Column():
|
289 |
+
chatbot_left1 = gr.Image(width=512, label = "Model B")
|
290 |
+
with gr.Column():
|
291 |
+
chatbot_right = gr.Image(width=512, label = "Model C")
|
292 |
+
with gr.Column():
|
293 |
+
chatbot_right1 = gr.Image(width=512, label = "Model D")
|
294 |
+
|
295 |
+
with gr.Row():
|
296 |
+
with gr.Column():
|
297 |
+
model_selector_left = gr.Markdown("", visible=False)
|
298 |
+
with gr.Column():
|
299 |
+
model_selector_left1 = gr.Markdown("", visible=False)
|
300 |
+
with gr.Column():
|
301 |
+
model_selector_right = gr.Markdown("", visible=False)
|
302 |
+
with gr.Column():
|
303 |
+
model_selector_right1 = gr.Markdown("", visible=False)
|
304 |
+
with gr.Row():
|
305 |
+
slow_warning = gr.Markdown("", elem_id="notice_markdown")
|
306 |
+
|
307 |
+
with gr.Row(elem_classes="row"):
|
308 |
+
with gr.Column(scale=1, min_width=10):
|
309 |
+
leftvote_btn = gr.Button(
|
310 |
+
value="A is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
|
311 |
+
)
|
312 |
+
with gr.Column(scale=1, min_width=10):
|
313 |
+
left1vote_btn = gr.Button(
|
314 |
+
value="B is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
|
315 |
+
)
|
316 |
+
with gr.Column(scale=1, min_width=10):
|
317 |
+
rightvote_btn = gr.Button(
|
318 |
+
value="C is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
|
319 |
+
)
|
320 |
+
with gr.Column(scale=1, min_width=10):
|
321 |
+
right1vote_btn = gr.Button(
|
322 |
+
value="D is Best", visible=False, interactive=False, elem_id="btncolor1", elem_classes="best-button"
|
323 |
+
)
|
324 |
+
with gr.Column(scale=1, min_width=10):
|
325 |
+
tie_btn = gr.Button(
|
326 |
+
value="🤝 Tie", visible=False, interactive=False, elem_id="btncolor2", elem_classes="best-button"
|
327 |
+
)
|
328 |
+
|
329 |
+
|
330 |
+
with gr.Row():
|
331 |
+
with gr.Blocks():
|
332 |
+
with gr.Row():
|
333 |
+
with gr.Column(scale=1, min_width=10):
|
334 |
+
A1_btn = gr.Button(
|
335 |
+
value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
|
336 |
+
)
|
337 |
+
with gr.Column(scale=1, min_width=10):
|
338 |
+
A2_btn = gr.Button(
|
339 |
+
value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
|
340 |
+
)
|
341 |
+
with gr.Column(scale=1, min_width=10):
|
342 |
+
A3_btn = gr.Button(
|
343 |
+
value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
|
344 |
+
)
|
345 |
+
with gr.Column(scale=1, min_width=10):
|
346 |
+
A4_btn = gr.Button(
|
347 |
+
value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
|
348 |
+
)
|
349 |
+
with gr.Blocks():
|
350 |
+
with gr.Row():
|
351 |
+
with gr.Column(scale=1, min_width=10):
|
352 |
+
B1_btn = gr.Button(
|
353 |
+
value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
|
354 |
+
)
|
355 |
+
with gr.Column(scale=1, min_width=10):
|
356 |
+
B2_btn = gr.Button(
|
357 |
+
value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
|
358 |
+
)
|
359 |
+
with gr.Column(scale=1, min_width=10):
|
360 |
+
B3_btn = gr.Button(
|
361 |
+
value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
|
362 |
+
)
|
363 |
+
with gr.Column(scale=1, min_width=10):
|
364 |
+
B4_btn = gr.Button(
|
365 |
+
value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
|
366 |
+
)
|
367 |
+
with gr.Blocks():
|
368 |
+
with gr.Row():
|
369 |
+
with gr.Column(scale=1, min_width=10):
|
370 |
+
C1_btn = gr.Button(
|
371 |
+
value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
|
372 |
+
)
|
373 |
+
with gr.Column(scale=1, min_width=10):
|
374 |
+
C2_btn = gr.Button(
|
375 |
+
value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
|
376 |
+
)
|
377 |
+
with gr.Column(scale=1, min_width=10):
|
378 |
+
C3_btn = gr.Button(
|
379 |
+
value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
|
380 |
+
)
|
381 |
+
with gr.Column(scale=1, min_width=10):
|
382 |
+
C4_btn = gr.Button(
|
383 |
+
value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
|
384 |
+
)
|
385 |
+
with gr.Blocks():
|
386 |
+
with gr.Row():
|
387 |
+
with gr.Column(scale=1, min_width=10):
|
388 |
+
D1_btn = gr.Button(
|
389 |
+
value="1", visible=False, interactive=False, elem_id="btncolor1", elem_classes="custom-button"
|
390 |
+
)
|
391 |
+
with gr.Column(scale=1, min_width=10):
|
392 |
+
D2_btn = gr.Button(
|
393 |
+
value="2", visible=False, interactive=False, elem_id="btncolor2", elem_classes="custom-button"
|
394 |
+
)
|
395 |
+
with gr.Column(scale=1, min_width=10):
|
396 |
+
D3_btn = gr.Button(
|
397 |
+
value="3", visible=False, interactive=False, elem_id="btncolor3", elem_classes="custom-button"
|
398 |
+
)
|
399 |
+
with gr.Column(scale=1, min_width=10):
|
400 |
+
D4_btn = gr.Button(
|
401 |
+
value="4", visible=False, interactive=False, elem_id="btncolor4", elem_classes="custom-button"
|
402 |
+
)
|
403 |
+
with gr.Row():
|
404 |
+
vote_textbox = gr.Textbox(
|
405 |
+
show_label=False,
|
406 |
+
placeholder="👉 Enter your rank (you can use buttons above, or directly type here, e.g. 1 2 3 4)",
|
407 |
+
container=True,
|
408 |
+
elem_id="input_box",
|
409 |
+
visible=False,
|
410 |
+
)
|
411 |
+
vote_submit_btn = gr.Button(value="Submit", visible=False, interactive=False, variant="primary", scale=0, elem_id="btnpink", elem_classes="submit-button")
|
412 |
+
vote_mode_btn = gr.Button(value="🔄 Mode", visible=False, interactive=False, variant="primary", scale=0, elem_id="btnpink", elem_classes="submit-button")
|
413 |
+
|
414 |
+
with gr.Row():
|
415 |
+
clear_btn = gr.Button(value="🎲 New Round", interactive=False)
|
416 |
+
# regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
417 |
+
# share_btn = gr.Button(value="📷 Share")
|
418 |
+
with gr.Blocks():
|
419 |
+
with gr.Row(elem_id="centered-text"): #
|
420 |
+
user_info = gr.Markdown("User information (to appear on the contributor leaderboard)", visible=True, elem_id="centered-text") #, elem_id="centered-text"
|
421 |
+
# with gr.Blocks():
|
422 |
+
# name = gr.Markdown("Name", visible=True)
|
423 |
+
user_name = gr.Textbox(show_label=False,placeholder="👉 Enter your name (optional)", elem_classes="custom-width")
|
424 |
+
# with gr.Blocks():
|
425 |
+
# institution = gr.Markdown("Institution", visible=True)
|
426 |
+
user_institution = gr.Textbox(show_label=False,placeholder="👉 Enter your affiliation (optional)", elem_classes="custom-width")
|
427 |
+
|
428 |
+
sketch_pad.change(
|
429 |
+
draw,
|
430 |
+
inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
|
431 |
+
outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
|
432 |
+
queue=False,
|
433 |
+
)
|
434 |
+
grounding_instruction.change(
|
435 |
+
draw,
|
436 |
+
inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
|
437 |
+
outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
|
438 |
+
queue=False,
|
439 |
+
)
|
440 |
+
|
441 |
+
order_btn_list = [textbox, send_btn, clear_btn]
|
442 |
+
vote_order_list = [leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
|
443 |
+
A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \
|
444 |
+
vote_textbox, vote_submit_btn, vote_mode_btn]
|
445 |
+
|
446 |
+
generate_ig0 = gr.Image(width=512, label = "generate A", visible=False, interactive=False)
|
447 |
+
generate_ig1 = gr.Image(width=512, label = "generate B", visible=False, interactive=False)
|
448 |
+
generate_ig2 = gr.Image(width=512, label = "generate C", visible=False, interactive=False)
|
449 |
+
generate_ig3 = gr.Image(width=512, label = "generate D", visible=False, interactive=False)
|
450 |
+
dummy_left_model = gr.State("")
|
451 |
+
dummy_left1_model = gr.State("")
|
452 |
+
dummy_right_model = gr.State("")
|
453 |
+
dummy_right1_model = gr.State("")
|
454 |
+
|
455 |
+
ig_rank = [None, None, None, None]
|
456 |
+
bastA_rank = [0, 3, 3, 3]
|
457 |
+
bastB_rank = [3, 0, 3, 3]
|
458 |
+
bastC_rank = [3, 3, 0, 3]
|
459 |
+
bastD_rank = [3, 3, 3, 0]
|
460 |
+
tie_rank = [0, 0, 0, 0]
|
461 |
+
bad_rank = [3, 3, 3, 3]
|
462 |
+
rank = gr.State(ig_rank)
|
463 |
+
rankA = gr.State(bastA_rank)
|
464 |
+
rankB = gr.State(bastB_rank)
|
465 |
+
rankC = gr.State(bastC_rank)
|
466 |
+
rankD = gr.State(bastD_rank)
|
467 |
+
rankTie = gr.State(tie_rank)
|
468 |
+
rankBad = gr.State(bad_rank)
|
469 |
+
Top1_text = gr.Textbox(value="Top 1", visible=False, interactive=False)
|
470 |
+
Top2_text = gr.Textbox(value="Top 2", visible=False, interactive=False)
|
471 |
+
Top3_text = gr.Textbox(value="Top 3", visible=False, interactive=False)
|
472 |
+
Top4_text = gr.Textbox(value="Top 4", visible=False, interactive=False)
|
473 |
+
window1_text = gr.Textbox(value="Model A", visible=False, interactive=False)
|
474 |
+
window2_text = gr.Textbox(value="Model B", visible=False, interactive=False)
|
475 |
+
window3_text = gr.Textbox(value="Model C", visible=False, interactive=False)
|
476 |
+
window4_text = gr.Textbox(value="Model D", visible=False, interactive=False)
|
477 |
+
vote_level = gr.Number(value=0, visible=False, interactive=False)
|
478 |
+
# Top1_btn.click(reset_level, inputs=[Top1_text], outputs=[vote_level])
|
479 |
+
# Top2_btn.click(reset_level, inputs=[Top2_text], outputs=[vote_level])
|
480 |
+
# Top3_btn.click(reset_level, inputs=[Top3_text], outputs=[vote_level])
|
481 |
+
# Top4_btn.click(reset_level, inputs=[Top4_text], outputs=[vote_level])
|
482 |
+
vote_mode = gr.Textbox(value="Rank", visible=False, interactive=False)
|
483 |
+
right_vote_text = gr.Textbox(value="wrong", visible=False, interactive=False)
|
484 |
+
cache_mode = gr.Textbox(value="True", visible=False, interactive=False)
|
485 |
+
|
486 |
+
|
487 |
+
|
488 |
+
|
489 |
+
if __name__ == "__main__":
|
490 |
+
with gr.Blocks() as demo:
|
491 |
+
build_side_by_side_bbox_ui_anony()
|
492 |
+
demo.launch()
|
serve/leaderboard.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Live monitor of the website statistics and leaderboard.
|
3 |
+
|
4 |
+
Dependency:
|
5 |
+
sudo apt install pkg-config libicu-dev
|
6 |
+
pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate
|
7 |
+
"""
|
8 |
+
|
9 |
+
import argparse
|
10 |
+
import ast
|
11 |
+
import pickle
|
12 |
+
import os
|
13 |
+
import threading
|
14 |
+
import time
|
15 |
+
|
16 |
+
import gradio as gr
|
17 |
+
import numpy as np
|
18 |
+
import pandas as pd
|
19 |
+
import json
|
20 |
+
from datetime import datetime
|
21 |
+
|
22 |
+
|
23 |
+
# def make_leaderboard_md(elo_results):
|
24 |
+
# leaderboard_md = f"""
|
25 |
+
# # 🏆 Chatbot Arena Leaderboard
|
26 |
+
# | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
|
27 |
+
|
28 |
+
# This leaderboard is based on the following three benchmarks.
|
29 |
+
# - [Chatbot Arena](https://lmsys.org/blog/2023-05-03-arena/) - a crowdsourced, randomized battle platform. We use 100K+ user votes to compute Elo ratings.
|
30 |
+
# - [MT-Bench](https://arxiv.org/abs/2306.05685) - a set of challenging multi-turn questions. We use GPT-4 to grade the model responses.
|
31 |
+
# - [MMLU](https://arxiv.org/abs/2009.03300) (5-shot) - a test to measure a model's multitask accuracy on 57 tasks.
|
32 |
+
|
33 |
+
# 💻 Code: The Arena Elo ratings are computed by this [notebook]({notebook_url}). The MT-bench scores (single-answer grading on a scale of 10) are computed by [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). The MMLU scores are mostly computed by [InstructEval](https://github.com/declare-lab/instruct-eval). Higher values are better for all benchmarks. Empty cells mean not available. Last updated: November, 2023.
|
34 |
+
# """
|
35 |
+
# return leaderboard_md
|
36 |
+
|
37 |
+
def make_leaderboard_md():
|
38 |
+
leaderboard_md = f"""
|
39 |
+
# 🏆 K-Sort Arena Leaderboard (Text-to-Image Generation)
|
40 |
+
"""
|
41 |
+
return leaderboard_md
|
42 |
+
|
43 |
+
|
44 |
+
def make_leaderboard_video_md():
|
45 |
+
leaderboard_md = f"""
|
46 |
+
# 🏆 K-Sort Arena Leaderboard (Text-to-Video Generation)
|
47 |
+
"""
|
48 |
+
return leaderboard_md
|
49 |
+
|
50 |
+
|
51 |
+
def model_hyperlink(model_name, link):
|
52 |
+
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
53 |
+
|
54 |
+
|
55 |
+
def make_arena_leaderboard_md(total_models, total_votes, last_updated):
|
56 |
+
# last_updated = datetime.now()
|
57 |
+
# last_updated = last_updated.strftime("%Y-%m-%d")
|
58 |
+
|
59 |
+
leaderboard_md = f"""
|
60 |
+
Total models: **{total_models}** (anonymized), Total votes: **{total_votes}** (equivalent to **{total_votes*6}** pairwise comparisons)
|
61 |
+
\n Last updated: {last_updated}
|
62 |
+
"""
|
63 |
+
|
64 |
+
return leaderboard_md
|
65 |
+
|
66 |
+
|
67 |
+
def make_disclaimer_md():
|
68 |
+
disclaimer_md = '''
|
69 |
+
<div id="modal" style="display:none; position:fixed; top:50%; left:50%; transform:translate(-50%, -50%); padding:20px; background:white; box-shadow:0 0 10px rgba(0,0,0,0.5); z-index:1000;">
|
70 |
+
<p style="font-size:24px;"><strong>Disclaimer</strong></p>
|
71 |
+
<p style="font-size:18px;"><b>Purpose and Scope</b></b></p>
|
72 |
+
<p><b>This platform is designed for academic use, providing a space for evaluating and comparing Visual Generation Models. The information and services provided are intended for research and educational purposes only.</b></p>
|
73 |
+
|
74 |
+
<p style="font-size:18px;"><b>Privacy and Data Protection</b></p>
|
75 |
+
<p><b>While users may voluntarily submit their names and institutional affiliations, this information is not required and is collected solely for the purpose of academic recognition. Personal information submitted to this platform will be handled with care and used solely for the intended academic purposes. We are committed to protecting your privacy, and we will not share personal data with third parties without explicit consent.</b></p>
|
76 |
+
|
77 |
+
<p style="font-size:18px;"><b>Source of Models</b></p>
|
78 |
+
<p><b>All models evaluated and displayed on this platform are obtained from official sources, including but not limited to official repositories and Replicate.</b></p>
|
79 |
+
|
80 |
+
<p style="font-size:18px;"><b>Limitations of Liability</b></p>
|
81 |
+
<p><b>The platform and its administrators do not assume any legal liability for the use or interpretation of the information provided. The evaluations and comparisons are for academic purposes. Users should verify the information independently and must not use the platform for any illegal, harmful, violent, racist, or sexual purposes.</b></p>
|
82 |
+
|
83 |
+
<p style="font-size:18px;"><b>Modification of Terms</b></p>
|
84 |
+
<p><b>We reserve the right to modify these terms at any time. Users will be notified of significant changes through updates on the platform.</b></p>
|
85 |
+
|
86 |
+
<p style="font-size:18px;"><b>Contact Information</b></p>
|
87 |
+
<p><b>For any questions or to report issues, please contact us at [email protected].</b></p>
|
88 |
+
</div>
|
89 |
+
<div id="overlay" style="display:none; position:fixed; top:0; left:0; width:100%; height:100%; background:rgba(0,0,0,0.5); z-index:999;" onclick="document.getElementById('modal').style.display='none'; document.getElementById('overlay').style.display='none'"></div>
|
90 |
+
<p> This platform is designed for academic usage, for details please refer to <a href="#" id="open_link" onclick="document.getElementById('modal').style.display='block'; document.getElementById('overlay').style.display='block'">disclaimer</a>.</p>
|
91 |
+
'''
|
92 |
+
return disclaimer_md
|
93 |
+
|
94 |
+
|
95 |
+
def make_arena_leaderboard_data(results):
|
96 |
+
import pandas as pd
|
97 |
+
df = pd.DataFrame(results)
|
98 |
+
return df
|
99 |
+
|
100 |
+
|
101 |
+
def build_leaderboard_tab(score_result_file = 'sorted_score_list.json'):
|
102 |
+
with open(score_result_file, "r") as json_file:
|
103 |
+
data = json.load(json_file)
|
104 |
+
score_results = data["sorted_score_list"]
|
105 |
+
total_models = data["total_models"]
|
106 |
+
total_votes = data["total_votes"]
|
107 |
+
last_updated = data["last_updated"]
|
108 |
+
|
109 |
+
md = make_leaderboard_md()
|
110 |
+
md_1 = gr.Markdown(md, elem_id="leaderboard_markdown")
|
111 |
+
|
112 |
+
# with gr.Tab("Arena Score", id=0):
|
113 |
+
md = make_arena_leaderboard_md(total_models, total_votes, last_updated)
|
114 |
+
gr.Markdown(md, elem_id="leaderboard_markdown")
|
115 |
+
md = make_arena_leaderboard_data(score_results)
|
116 |
+
gr.Dataframe(md)
|
117 |
+
|
118 |
+
gr.Markdown(
|
119 |
+
"""
|
120 |
+
- Note: When σ is large (we use the '*' labeling), it indicates that the model did not receive enough votes and its ranking is in the process of being updated.
|
121 |
+
""",
|
122 |
+
elem_id="sigma_note_markdown",
|
123 |
+
)
|
124 |
+
|
125 |
+
gr.Markdown(
|
126 |
+
""" ### The leaderboard is regularly updated and continuously incorporates new models.
|
127 |
+
""",
|
128 |
+
elem_id="leaderboard_markdown",
|
129 |
+
)
|
130 |
+
with gr.Blocks():
|
131 |
+
gr.HTML(make_disclaimer_md)
|
132 |
+
from .utils import acknowledgment_md, html_code
|
133 |
+
with gr.Blocks():
|
134 |
+
gr.Markdown(acknowledgment_md)
|
135 |
+
|
136 |
+
|
137 |
+
def build_leaderboard_video_tab(score_result_file = 'sorted_score_list_video.json'):
|
138 |
+
with open(score_result_file, "r") as json_file:
|
139 |
+
data = json.load(json_file)
|
140 |
+
score_results = data["sorted_score_list"]
|
141 |
+
total_models = data["total_models"]
|
142 |
+
total_votes = data["total_votes"]
|
143 |
+
last_updated = data["last_updated"]
|
144 |
+
|
145 |
+
md = make_leaderboard_video_md()
|
146 |
+
md_1 = gr.Markdown(md, elem_id="leaderboard_markdown")
|
147 |
+
# with gr.Blocks():
|
148 |
+
# gr.HTML(make_disclaimer_md)
|
149 |
+
|
150 |
+
# with gr.Tab("Arena Score", id=0):
|
151 |
+
md = make_arena_leaderboard_md(total_models, total_votes, last_updated)
|
152 |
+
gr.Markdown(md, elem_id="leaderboard_markdown")
|
153 |
+
md = make_arena_leaderboard_data(score_results)
|
154 |
+
gr.Dataframe(md)
|
155 |
+
|
156 |
+
notice_markdown_sora = """
|
157 |
+
- Note: When σ is large (we use the '*' labeling), it indicates that the model did not receive enough votes and its ranking is in the process of being updated.
|
158 |
+
- Note: As Sora's video generation function is not publicly available, we used sample videos from their official website. This may lead to a biased assessment of Sora's capabilities, as these samples likely represent Sora's best outputs. Therefore, Sora's position on our leaderboard should be considered as its upper bound. We are working on methods to conduct more comprehensive and fair comparisons in the future.
|
159 |
+
"""
|
160 |
+
|
161 |
+
gr.Markdown(notice_markdown_sora, elem_id="notice_markdown_sora")
|
162 |
+
|
163 |
+
gr.Markdown(
|
164 |
+
""" ### The leaderboard is regularly updated and continuously incorporates new models.
|
165 |
+
""",
|
166 |
+
elem_id="leaderboard_markdown",
|
167 |
+
)
|
168 |
+
from .utils import acknowledgment_md, html_code
|
169 |
+
with gr.Blocks():
|
170 |
+
gr.Markdown(acknowledgment_md)
|
171 |
+
|
172 |
+
|
173 |
+
def build_leaderboard_contributor(file = 'contributor.json'):
|
174 |
+
|
175 |
+
with open(file, "r") as json_file:
|
176 |
+
data = json.load(json_file)
|
177 |
+
score_results = data["contributor"]
|
178 |
+
last_updated = data["last_updated"]
|
179 |
+
|
180 |
+
md = f"""
|
181 |
+
# 🏆 Contributor Leaderboard
|
182 |
+
The submission of user information is entirely optional. This information is used solely for contribution statistics. We respect and safeguard users' privacy choices.
|
183 |
+
To maintain a clean and concise leaderboard, please ensure consistency in submitted names and affiliations. For example, use 'Berkeley' consistently rather than alternating with 'UC Berkeley'.
|
184 |
+
- Votes*: Each image vote counts as one Vote*, while each video vote counts as two Votes* due to the increased effort involved.
|
185 |
+
\n Last updated: {last_updated}
|
186 |
+
"""
|
187 |
+
|
188 |
+
md_1 = gr.Markdown(md, elem_id="leaderboard_markdown")
|
189 |
+
|
190 |
+
# md = make_arena_leaderboard_md(total_models, total_votes, last_updated)
|
191 |
+
# gr.Markdown(md, elem_id="leaderboard_markdown")
|
192 |
+
|
193 |
+
md = make_arena_leaderboard_data(score_results)
|
194 |
+
gr.Dataframe(md)
|
195 |
+
|
196 |
+
gr.Markdown(
|
197 |
+
""" ### The leaderboard is regularly updated.
|
198 |
+
""",
|
199 |
+
elem_id="leaderboard_markdown",
|
200 |
+
)
|
serve/log_server.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, File, UploadFile, Form, APIRouter
|
2 |
+
from typing import Optional
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import aiofiles
|
6 |
+
from .log_utils import build_logger
|
7 |
+
from .constants import LOG_SERVER_SUBDOMAIN, APPEND_JSON, SAVE_IMAGE, SAVE_VIDEO, SAVE_LOG
|
8 |
+
|
9 |
+
logger = build_logger("log_server", "log_server.log", add_remote_handler=False)
|
10 |
+
|
11 |
+
app = APIRouter(prefix=LOG_SERVER_SUBDOMAIN)
|
12 |
+
|
13 |
+
@app.post(f"/{APPEND_JSON}")
|
14 |
+
async def append_json(json_str: str = Form(...), file_name: str = Form(...)):
|
15 |
+
"""
|
16 |
+
Appends a JSON string to a specified file.
|
17 |
+
"""
|
18 |
+
# Convert the string back to a JSON object (dict)
|
19 |
+
data = json.loads(json_str)
|
20 |
+
# Append the data to the specified file
|
21 |
+
if os.path.dirname(file_name):
|
22 |
+
os.makedirs(os.path.dirname(file_name), exist_ok=True)
|
23 |
+
async with aiofiles.open(file_name, mode='a') as f:
|
24 |
+
await f.write(json.dumps(data) + "\n")
|
25 |
+
|
26 |
+
logger.info(f"Appended 1 JSON object to {file_name}")
|
27 |
+
return {"message": "JSON data appended successfully"}
|
28 |
+
|
29 |
+
@app.post(f"/{SAVE_IMAGE}")
|
30 |
+
async def save_image(image: UploadFile = File(...), image_path: str = Form(...)):
|
31 |
+
"""
|
32 |
+
Saves an uploaded image to the specified path.
|
33 |
+
"""
|
34 |
+
# Note: 'image_path' should include the file name and extension for the image to be saved.
|
35 |
+
if os.path.dirname(image_path):
|
36 |
+
os.makedirs(os.path.dirname(image_path), exist_ok=True)
|
37 |
+
async with aiofiles.open(image_path, mode='wb') as f:
|
38 |
+
content = await image.read() # Read the content of the uploaded image
|
39 |
+
await f.write(content) # Write the image content to a file
|
40 |
+
logger.info(f"Image saved successfully at {image_path}")
|
41 |
+
return {"message": f"Image saved successfully at {image_path}"}
|
42 |
+
|
43 |
+
@app.post(f"/{SAVE_VIDEO}")
|
44 |
+
async def save_video(video: UploadFile = File(...), video_path: str = Form(...)):
|
45 |
+
"""
|
46 |
+
Saves an uploaded video to the specified path.
|
47 |
+
"""
|
48 |
+
# Note: 'video_path' should include the file name and extension for the video to be saved.
|
49 |
+
if os.path.dirname(video_path):
|
50 |
+
os.makedirs(os.path.dirname(video_path), exist_ok=True)
|
51 |
+
async with aiofiles.open(video_path, mode='wb') as f:
|
52 |
+
content = await video.read() # Read the content of the uploaded video
|
53 |
+
await f.write(content) # Write the video content to a file
|
54 |
+
logger.info(f"Video saved successfully at {video_path}")
|
55 |
+
return {"message": f"Image saved successfully at {video_path}"}
|
56 |
+
|
57 |
+
@app.post(f"/{SAVE_LOG}")
|
58 |
+
async def save_log(message: str = Form(...), log_path: str = Form(...)):
|
59 |
+
"""
|
60 |
+
Save a log message to a specified log file on the server.
|
61 |
+
"""
|
62 |
+
# Ensure the directory for the log file exists
|
63 |
+
if os.path.dirname(log_path):
|
64 |
+
os.makedirs(os.path.dirname(log_path), exist_ok=True)
|
65 |
+
|
66 |
+
# Append the log message to the specified log file
|
67 |
+
async with aiofiles.open(log_path, mode='a') as f:
|
68 |
+
await f.write(f"{message}\n")
|
69 |
+
|
70 |
+
logger.info(f"Romote log message saved to {log_path}")
|
71 |
+
return {"message": f"Log message saved successfully to {log_path}"}
|
72 |
+
|
73 |
+
|
74 |
+
@app.get(f"/read_file")
|
75 |
+
async def read_file(file_name: str):
|
76 |
+
"""
|
77 |
+
Reads the content of a specified file and returns it.
|
78 |
+
"""
|
79 |
+
if not os.path.exists(file_name):
|
80 |
+
return {"message": f"File {file_name} does not exist."}
|
81 |
+
|
82 |
+
async with aiofiles.open(file_name, mode='r') as f:
|
83 |
+
content = await f.read()
|
84 |
+
|
85 |
+
logger.info(f"Read file {file_name}")
|
86 |
+
return {"file_name": file_name, "content": content}
|
serve/log_utils.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Common utilities.
|
3 |
+
"""
|
4 |
+
from asyncio import AbstractEventLoop
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
import logging.handlers
|
8 |
+
import os
|
9 |
+
import platform
|
10 |
+
import sys
|
11 |
+
from typing import AsyncGenerator, Generator
|
12 |
+
import warnings
|
13 |
+
from pathlib import Path
|
14 |
+
|
15 |
+
import requests
|
16 |
+
|
17 |
+
from .constants import LOGDIR, LOG_SERVER_ADDR, SAVE_LOG
|
18 |
+
from .utils import save_log_str_on_log_server
|
19 |
+
|
20 |
+
|
21 |
+
handler = None
|
22 |
+
visited_loggers = set()
|
23 |
+
|
24 |
+
|
25 |
+
# Assuming LOGDIR and other necessary imports and global variables are defined
|
26 |
+
|
27 |
+
class APIHandler(logging.Handler):
|
28 |
+
"""Custom logging handler that sends logs to an API."""
|
29 |
+
|
30 |
+
def __init__(self, apiUrl, log_path, *args, **kwargs):
|
31 |
+
super(APIHandler, self).__init__(*args, **kwargs)
|
32 |
+
self.apiUrl = apiUrl
|
33 |
+
self.log_path = log_path
|
34 |
+
|
35 |
+
def emit(self, record):
|
36 |
+
log_entry = self.format(record)
|
37 |
+
try:
|
38 |
+
save_log_str_on_log_server(log_entry, self.log_path)
|
39 |
+
except requests.RequestException as e:
|
40 |
+
print(f"Error sending log to API: {e}", file=sys.stderr)
|
41 |
+
|
42 |
+
def build_logger(logger_name, logger_filename, add_remote_handler=False):
|
43 |
+
global handler
|
44 |
+
|
45 |
+
formatter = logging.Formatter(
|
46 |
+
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
47 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
48 |
+
)
|
49 |
+
|
50 |
+
# Set the format of root handlers
|
51 |
+
if not logging.getLogger().handlers:
|
52 |
+
if sys.version_info[1] >= 9:
|
53 |
+
# This is for windows
|
54 |
+
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
55 |
+
else:
|
56 |
+
if platform.system() == "Windows":
|
57 |
+
warnings.warn(
|
58 |
+
"If you are running on Windows, "
|
59 |
+
"we recommend you use Python >= 3.9 for UTF-8 encoding."
|
60 |
+
)
|
61 |
+
logging.basicConfig(level=logging.INFO)
|
62 |
+
logging.getLogger().handlers[0].setFormatter(formatter)
|
63 |
+
|
64 |
+
# Redirect stdout and stderr to loggers
|
65 |
+
stdout_logger = logging.getLogger("stdout")
|
66 |
+
stdout_logger.setLevel(logging.INFO)
|
67 |
+
sl = StreamToLogger(stdout_logger, logging.INFO)
|
68 |
+
sys.stdout = sl
|
69 |
+
|
70 |
+
stderr_logger = logging.getLogger("stderr")
|
71 |
+
stderr_logger.setLevel(logging.ERROR)
|
72 |
+
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
73 |
+
sys.stderr = sl
|
74 |
+
|
75 |
+
# Get logger
|
76 |
+
logger = logging.getLogger(logger_name)
|
77 |
+
logger.setLevel(logging.INFO)
|
78 |
+
|
79 |
+
if add_remote_handler:
|
80 |
+
# Add APIHandler to send logs to your API
|
81 |
+
api_url = f"{LOG_SERVER_ADDR}/{SAVE_LOG}"
|
82 |
+
|
83 |
+
remote_logger_filename = str(Path(logger_filename).stem + "_remote.log")
|
84 |
+
api_handler = APIHandler(apiUrl=api_url, log_path=f"{LOGDIR}/{remote_logger_filename}")
|
85 |
+
api_handler.setFormatter(formatter)
|
86 |
+
logger.addHandler(api_handler)
|
87 |
+
|
88 |
+
stdout_logger.addHandler(api_handler)
|
89 |
+
stderr_logger.addHandler(api_handler)
|
90 |
+
|
91 |
+
# if LOGDIR is empty, then don't try output log to local file
|
92 |
+
if LOGDIR != "":
|
93 |
+
os.makedirs(LOGDIR, exist_ok=True)
|
94 |
+
filename = os.path.join(LOGDIR, logger_filename)
|
95 |
+
handler = logging.handlers.TimedRotatingFileHandler(
|
96 |
+
filename, when="D", utc=True, encoding="utf-8"
|
97 |
+
)
|
98 |
+
handler.setFormatter(formatter)
|
99 |
+
|
100 |
+
for l in [stdout_logger, stderr_logger, logger]:
|
101 |
+
if l in visited_loggers:
|
102 |
+
continue
|
103 |
+
visited_loggers.add(l)
|
104 |
+
l.addHandler(handler)
|
105 |
+
|
106 |
+
return logger
|
107 |
+
|
108 |
+
|
109 |
+
class StreamToLogger(object):
|
110 |
+
"""
|
111 |
+
Fake file-like stream object that redirects writes to a logger instance.
|
112 |
+
"""
|
113 |
+
|
114 |
+
def __init__(self, logger, log_level=logging.INFO):
|
115 |
+
self.terminal = sys.stdout
|
116 |
+
self.logger = logger
|
117 |
+
self.log_level = log_level
|
118 |
+
self.linebuf = ""
|
119 |
+
|
120 |
+
def __getattr__(self, attr):
|
121 |
+
return getattr(self.terminal, attr)
|
122 |
+
|
123 |
+
def write(self, buf):
|
124 |
+
temp_linebuf = self.linebuf + buf
|
125 |
+
self.linebuf = ""
|
126 |
+
for line in temp_linebuf.splitlines(True):
|
127 |
+
# From the io.TextIOWrapper docs:
|
128 |
+
# On output, if newline is None, any '\n' characters written
|
129 |
+
# are translated to the system default line separator.
|
130 |
+
# By default sys.stdout.write() expects '\n' newlines and then
|
131 |
+
# translates them so this is still cross platform.
|
132 |
+
if line[-1] == "\n":
|
133 |
+
encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
|
134 |
+
self.logger.log(self.log_level, encoded_message.rstrip())
|
135 |
+
else:
|
136 |
+
self.linebuf += line
|
137 |
+
|
138 |
+
def flush(self):
|
139 |
+
if self.linebuf != "":
|
140 |
+
encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
|
141 |
+
self.logger.log(self.log_level, encoded_message.rstrip())
|
142 |
+
self.linebuf = ""
|