File size: 3,947 Bytes
b6dc501
 
 
 
 
 
 
 
 
 
 
a4b32da
b6dc501
 
 
a4b32da
b6dc501
 
 
 
 
 
 
 
 
 
 
a4b32da
 
b6dc501
 
 
 
 
 
 
a4b32da
b6dc501
 
 
 
a4b32da
 
b6dc501
 
 
 
 
a4b32da
b6dc501
 
 
 
a4b32da
b6dc501
 
 
a4b32da
b6dc501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4b32da
b6dc501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np
import json
from trueskill import TrueSkill
import paramiko
import io, os
import sys
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_VIDEO_SKILL
# Module-wide TrueSkill environment, built with the library's default
# parameters; the rating helpers in this module all go through it.
trueskill_env = TrueSkill()
# The parent directory is appended to the import path *before* the
# `model.models` import below, so the order of these two lines matters.
# NOTE(review): presumably needed so the `model` package resolves when the
# process is started from this directory -- confirm against the launcher.
sys.path.append('../')
from model.models import VIDEO_GENERATION_MODELS


# Shared paramiko SSH client and SFTP session. Both stay None until
# create_ssh_skill_client() has been called.
ssh_skill_client = None
sftp_skill_client = None


def create_ssh_skill_client(server, port, user, password):
    """Open an SSH connection plus SFTP session and publish them as module globals.

    The new connection is built on local variables first. The globals
    ``ssh_skill_client`` and ``sftp_skill_client`` are replaced only after
    both the SSH connection and the SFTP session were opened successfully,
    so a failed attempt never leaves a half-initialised client behind.
    A previously published client is closed once it has been replaced.

    Args:
        server: Host name or address handed to ``SSHClient.connect``.
        port: Port of the SSH server.
        user: Login user name.
        password: Login password.

    Raises:
        Exception: Any error raised by paramiko while connecting or opening
            the SFTP session is re-raised after the partially opened client
            has been closed.
    """
    global ssh_skill_client, sftp_skill_client

    new_client = paramiko.SSHClient()
    new_client.load_system_host_keys()
    # NOTE(review): AutoAddPolicy accepts unknown host keys without
    # verification; kept for compatibility, but worth revisiting.
    new_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        new_client.connect(server, port, user, password)
        # Send a keepalive packet every 60 seconds.
        new_client.get_transport().set_keepalive(60)
        new_sftp = new_client.open_sftp()
    except Exception:
        # Do not leak the socket of a connection that could not be completed.
        new_client.close()
        raise

    stale_client = ssh_skill_client
    ssh_skill_client, sftp_skill_client = new_client, new_sftp

    # Release the connection that was just replaced (if any).
    if stale_client is not None:
        try:
            stale_client.close()
        except Exception as e:
            print(f"Error closing previous SSH connection: {e}")


def is_connected():
    """Report whether the shared SSH/SFTP clients are usable.

    Returns:
        bool: ``False`` when either module-level client is ``None``, when the
        SSH client has no transport or its transport is inactive, or when a
        probe ``listdir('.')`` on the SFTP session raises; ``True`` otherwise.
    """
    if ssh_skill_client is None or sftp_skill_client is None:
        return False

    # get_transport() yields None for a client that never connected or that
    # has been closed; treat that as "not connected" instead of crashing.
    transport = ssh_skill_client.get_transport()
    if transport is None or not transport.is_active():
        return False

    try:
        # Cheap round trip to make sure the SFTP channel itself still works.
        sftp_skill_client.listdir('.')
    except Exception as e:
        print(f"Error checking SFTP connection: {e}")
        return False
    return True


def ucb_score(trueskill_diff, t, n, exploration_weight=1.0):
    """Compute an upper-confidence-bound score from a skill gap and visit counts.

    The score is ``-trueskill_diff + exploration_weight * sqrt(2 * ln(t) / n)``,
    with ``1e-5`` added to ``t`` and ``n`` to avoid ``log(0)`` and division
    by zero.

    Args:
        trueskill_diff: Skill difference; it enters the score negated.
        t: Total count used inside the logarithm.
        n: Count used as the denominator of the exploration term.
        exploration_weight: Multiplier of the exploration term. The default
            ``1.0`` reproduces the previously hard-coded weight.

    Returns:
        The UCB score (a NumPy scalar, or an array if the inputs are arrays).
    """
    # log(t + 1e-5) is negative for t < 1, which would make the square root
    # NaN; clamp it at zero so the exploration term is simply 0 in that case.
    log_term = np.maximum(np.log(t + 1e-5), 0.0)
    exploration_term = np.sqrt((2 * log_term) / (n + 1e-5))
    return -trueskill_diff + exploration_weight * exploration_term


def update_trueskill(ratings, ranks):
    """Re-rate ``ratings`` for the given ``ranks`` with the module's TrueSkill environment.

    Both arguments are forwarded unchanged to ``trueskill_env.rate`` and its
    result is returned as-is.
    """
    return trueskill_env.rate(ratings, ranks)


def serialize_rating(rating):
    """Turn a rating into a plain dict holding its ``mu`` and ``sigma`` attributes."""
    mu_value, sigma_value = rating.mu, rating.sigma
    return dict(mu=mu_value, sigma=sigma_value)


def deserialize_rating(rating_dict):
    """Build a rating from a dict with ``'mu'`` and ``'sigma'`` entries.

    The rating is created through ``trueskill_env.Rating``; a missing key
    raises ``KeyError``.
    """
    mu_value = rating_dict['mu']
    sigma_value = rating_dict['sigma']
    return trueskill_env.Rating(mu=mu_value, sigma=sigma_value)


def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
    """Serialize the rating state to JSON and write it to the remote skill file.

    Args:
        ratings: Iterable of rating objects; each is passed through
            ``serialize_rating``.
        comparison_counts: Object exposing ``tolist()`` (the loader in this
            module produces a NumPy array).
        total_comparisons: Value stored under ``'total_comparisons'``.

    The file at ``SSH_VIDEO_SKILL`` is opened in ``'w'`` mode through the
    shared SFTP session, which is (re)created first when ``is_connected()``
    reports it is not usable.
    """
    if not is_connected():
        create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)

    serialized_ratings = [serialize_rating(entry) for entry in ratings]
    counts_as_lists = comparison_counts.tolist()
    payload = json.dumps({
        'ratings': serialized_ratings,
        'comparison_counts': counts_as_lists,
        'total_comparisons': total_comparisons,
    })

    with sftp_skill_client.open(SSH_VIDEO_SKILL, 'w') as remote_file:
        remote_file.write(payload)


def load_json_via_sftp():
    """Read the rating state back from the remote skill file.

    The shared SFTP session is (re)created first when ``is_connected()``
    reports it is not usable, then ``SSH_VIDEO_SKILL`` is parsed as JSON.

    Returns:
        tuple: ``(ratings, comparison_counts, total_comparisons)`` where
        ``ratings`` is a list built with ``deserialize_rating``,
        ``comparison_counts`` is a NumPy array and ``total_comparisons`` is
        the stored value.
    """
    if not is_connected():
        create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)

    with sftp_skill_client.open(SSH_VIDEO_SKILL, 'r') as remote_file:
        payload = json.load(remote_file)

    return (
        [deserialize_rating(entry) for entry in payload['ratings']],
        np.array(payload['comparison_counts']),
        payload['total_comparisons'],
    )


def update_skill_video(rank, model_names, k_group=4):
    """Apply one ranked comparison of video models to the stored TrueSkill state.

    The state is loaded over SFTP, every pair of models in the group is
    re-rated as a 1-vs-1 match, and the result is written back.

    Args:
        rank: Sequence indexed like ``model_names``; for each pair, the entry
            with the smaller value is handed to TrueSkill as rank 0 and the
            other as rank 1.
        model_names: Names looked up in ``VIDEO_GENERATION_MODELS``; an
            unknown name raises ``ValueError`` from ``list.index``.
        k_group: Not used by this function body.

    Pairs with equal rank get no rating update, but their comparison counts
    are still incremented. ``total_comparisons`` grows by one per call.
    """
    ratings, comparison_counts, total_comparisons = load_json_via_sftp()

    # Translate model names into their positions in VIDEO_GENERATION_MODELS.
    group = [VIDEO_GENERATION_MODELS.index(name) for name in model_names]
    print(group)

    group_size = len(group)
    for first in range(group_size):
        for second in range(first + 1, group_size):
            if rank[first] < rank[second]:
                outcome = [0, 1]
            elif rank[first] > rank[second]:
                outcome = [1, 0]
            else:
                # Tie: leave both ratings as they are.
                outcome = None

            model_a, model_b = group[first], group[second]
            if outcome is not None:
                rated = update_trueskill([[ratings[model_a]], [ratings[model_b]]], outcome)
                ratings[model_a], ratings[model_b] = rated[0][0], rated[1][0]

            # The count matrix is kept symmetric.
            comparison_counts[model_a, model_b] += 1
            comparison_counts[model_b, model_a] += 1

    total_comparisons += 1

    save_json_via_sftp(ratings, comparison_counts, total_comparisons)