File size: 4,498 Bytes
7c1eee1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import concurrent.futures 
import random
import gradio as gr
import requests
import io, base64, json
# import spaces
from PIL import Image

from .model_config import model_config
from .model_worker import BaseModelWorker

class ModelManager:
    """Registry of configured models plus helpers to run them singly or in pairs.

    One ``BaseModelWorker`` is built per entry of ``model_config`` (keyed by
    ``cfg.model_name``). Inference and rendering are dispatched to the worker
    looked up by model name.
    """

    def __init__(self):
        self.models_config = model_config
        # Model name -> worker. This is a dict keyed by ``cfg.model_name``
        # (see ``build_model_workers``), not a list.
        self.models_worker: dict[str, BaseModelWorker] = {}

        self.build_model_workers()

    def build_model_workers(self):
        """Create one ``BaseModelWorker`` per config entry, keyed by ``cfg.model_name``."""
        for cfg in self.models_config.values():
            worker = BaseModelWorker(cfg.model_name, cfg.i2s_model, cfg.online_model, cfg.model_path)
            self.models_worker[cfg.model_name] = worker

    def get_all_models(self):
        """Return the keys of ``models_config`` as a list, in iteration order."""
        return list(self.models_config.keys())

    def get_t2s_models(self):
        """Return names of models whose config has a falsy ``i2s_model`` flag."""
        return [cfg.model_name for cfg in self.models_config.values() if not cfg.i2s_model]

    def get_i2s_models(self):
        """Return names of models whose config has a truthy ``i2s_model`` flag."""
        return [cfg.model_name for cfg in self.models_config.values() if cfg.i2s_model]

    def get_online_models(self):
        """Return names of models whose config has a truthy ``online_model`` flag."""
        return [cfg.model_name for cfg in self.models_config.values() if cfg.online_model]

    def get_models(self, i2s_model: bool, online_model: bool):
        """Return names of models whose flags equal both arguments.

        Args:
            i2s_model: Required value of ``cfg.i2s_model`` (compared with ``==``).
            online_model: Required value of ``cfg.online_model`` (compared with ``==``).

        Returns:
            List of ``cfg.model_name`` values, in config iteration order.
        """
        return [
            cfg.model_name
            for cfg in self.models_config.values()
            if cfg.i2s_model == i2s_model and cfg.online_model == online_model
        ]

    def check_online(self, name):
        """Look up the worker registered under ``name``.

        NOTE(review): this is currently a stub. It returns ``None`` whether or
        not the worker is online. Its only observable effect is raising
        ``KeyError`` for an unknown ``name``. TODO: decide what it should
        return or raise for offline models.
        """
        worker = self.models_worker[name]
        if not worker.online_model:
            return

    # @spaces.GPU(duration=120)
    def inference(self, prompt, model_name):
        """Run the worker for ``model_name`` on ``prompt`` and return its result."""
        worker = self.models_worker[model_name]
        return worker.inference(prompt=prompt)

    def render(self, prompt, model_name):
        """Call ``render`` on the worker for ``model_name`` and return its result."""
        worker = self.models_worker[model_name]
        return worker.render(prompt=prompt)

    def _run_pair(self, fn, prompt, model_A, model_B):
        """Run ``fn(prompt, model)`` for both models concurrently.

        Results are returned as ``(result_A, result_B)`` in argument order,
        regardless of which call finishes first. An exception raised by either
        call propagates to the caller.
        """
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            future_A = executor.submit(fn, prompt, model_A)
            future_B = executor.submit(fn, prompt, model_B)
            # Read the futures in submission order. ``as_completed`` would yield
            # them in completion order and could swap A's and B's results.
            return future_A.result(), future_B.result()

    def inference_parallel(self, prompt, model_A, model_B):
        """Run ``inference`` for two models concurrently.

        Returns:
            ``(result_for_model_A, result_for_model_B)``.
        """
        return self._run_pair(self.inference, prompt, model_A, model_B)

    def inference_parallel_anony(self, prompt, model_A, model_B, i2s_model):
        """Run ``inference`` for two models concurrently, optionally picking them at random.

        If both ``model_A`` and ``model_B`` are ``""``, two distinct models are
        sampled with ``random.sample`` from the online models whose
        ``i2s_model`` flag equals ``i2s_model``. ``random.sample`` raises
        ``ValueError`` if fewer than two such models exist.

        NOTE(review): the randomly chosen model names are not returned, so the
        caller cannot tell which models produced the results.

        Returns:
            ``(result_for_model_A, result_for_model_B)``.
        """
        if model_A == model_B == "":
            model_A, model_B = random.sample(self.get_models(i2s_model=i2s_model, online_model=True), 2)
        return self._run_pair(self.inference, prompt, model_A, model_B)

    def render_parallel(self, prompt, model_A, model_B):
        """Run ``render`` for two models concurrently.

        Returns:
            ``(result_for_model_A, result_for_model_B)``.
        """
        return self._run_pair(self.render, prompt, model_A, model_B)

    # def i2s_inference_parallel(self, image, model_A, model_B):
    #     results = []
    #     model_names = [model_A, model_B]
    #     with concurrent.futures.ThreadPoolExecutor() as executor:
    #         future_to_result = {executor.submit(self.inference, image, model): model 
    #                             for model in model_names}
    #         for future in concurrent.futures.as_completed(future_to_result):
    #             result = future.result()
    #             results.append(result)
    #     return results[0], results[1]