import os
import torch
import librosa
import gradio as gr
from scipy.io.wavfile import write
from huggingface_hub import hf_hub_download, snapshot_download
import utils
from models import SynthesizerTrn
from mel_processing import mel_spectrogram_torch
from speaker_encoder.voice_encoder import SpeakerEncoder
import logging
from transformers import Wav2Vec2FeatureExtractor, HubertModel

# Configure logging level and format
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Model configuration
MODEL_CONFIG = {
    "freevc": {
        "repo_id": "guetLzy/Chinese-FreeVC-Model",
        "files": ["G_17000.pth", "G_35000.pth"]
    },
    "hubert": {
        "repo_id": "guetLzy/chinese-hubert-large-fariseq-ckpt",
    }
}

# Device selection
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Available model options (display name -> local checkpoint path)
MODEL_OPTIONS = {
    "Model_17000": "model/G_17000.pth",
    "Model_35000": "model/G_35000.pth",
}
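
# Note (inferred from the paths referenced further down; not verified here): besides the
# checkpoints fetched by download_models(), the script expects a local "configs/freevc.json",
# plus a speaker-encoder checkpoint at "speaker_encoder/ckpt/pretrained_bak_5805000.pt"
# when hps.model.use_spk is enabled.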

def download_models():
    """下载所有需要的模型文件"""
    os.makedirs("model", exist_ok=True)
    os.makedirs("hubert/chinese-hubert-large-fairseq-ckpt", exist_ok=True)

    freevc_paths = {}
    for model_name, model_path in MODEL_OPTIONS.items():
        path = hf_hub_download(
            repo_id=MODEL_CONFIG["freevc"]["repo_id"],
            filename=os.path.basename(model_path),
            local_dir="model",
            resume_download=True
        )
        freevc_paths[model_name] = path

    hubert_dir = "hubert/chinese-hubert-large-fairseq-ckpt"
    snapshot_download(
        repo_id=MODEL_CONFIG["hubert"]["repo_id"],
        local_dir=hubert_dir,
        repo_type="model",
        resume_download=True
    )
    hubert_paths = {"snapshot": hubert_dir}

    return {
        "freevc": freevc_paths,
        "hubert": hubert_paths
    }

def load_hubert(hubert_dir, status_list):
    """加载HuBERT模型(使用fairseq格式的检查点)"""
    status_list.append("正在加载 HuBERT 模型...")
    logger.info("正在加载 HuBERT 模型...")
    model = HubertModel.from_pretrained(hubert_dir)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_dir)
    return model.to(device).float().eval(), feature_extractor

def load_freevc(model_path, status_list):
    """加载FreeVC模型(使用本地配置文件)"""
    status_list.append(f"正在从 {model_path} 加载 FreeVC 模型...")
    logger.info(f"正在从 {model_path} 加载 FreeVC 模型...")
    hps = utils.get_hparams_from_file("configs/freevc.json")

    net_g = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model
    ).to(device)

    utils.load_checkpoint(model_path, net_g, None, True)
    net_g.eval()

    smodel = SpeakerEncoder("speaker_encoder/ckpt/pretrained_bak_5805000.pt") if hps.model.use_spk else None
    return net_g, smodel, hps

# Preload models at startup
status_list = ["Downloading models..."]
logger.info("Downloading models...")
model_paths = download_models()
status_list.append(f"Model paths: {model_paths}")
logger.info(f"Model paths: {model_paths}")
status_list.append("Initializing HuBERT...")
logger.info("Initializing HuBERT...")
hubert_dir = model_paths["hubert"]["snapshot"]  # same directory used by download_models()
hubert_model, hubert_feature_extractor = load_hubert(hubert_dir, status_list)

def voice_conversion(src_audio, tgt_audio, output_name, model_selection):
    """执行语音转换"""
    status_list = ["开始语音转换..."]
    
    try:
        # Load the selected FreeVC model
        freevc_model, speaker_model, hps = load_freevc(MODEL_OPTIONS[model_selection], status_list)

        with torch.no_grad():
            # Process the target (reference) audio
            status_list.append("Processing target audio...")
            wav_tgt, _ = librosa.load(tgt_audio, sr=hps.data.sampling_rate)
            wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
            
            if hps.model.use_spk:
                status_list.append("提取目标音色特征(使用说话人编码器)...")
                g_tgt = speaker_model.embed_utterance(wav_tgt)
                g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device)
            else:
                status_list.append("生成目标音频 Mel 频谱图...")
                wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(device)
                mel_tgt = mel_spectrogram_torch(
                    wav_tgt,
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.hop_length,
                    hps.data.win_length,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax
                )

            # Process the source audio
            status_list.append("Processing source audio (resampling to 16 kHz)...")
            wav_src, _ = librosa.load(src_audio, sr=16_000)
            inputs = hubert_feature_extractor(
                wav_src,
                return_tensors="pt",
                sampling_rate=16_000
            ).input_values.to(device)
            
            status_list.append("提取源音频特征...")
            c = hubert_model(inputs.float()).last_hidden_state.transpose(1, 2)

            # Run the conversion
            status_list.append("Running voice conversion...")
            audio = freevc_model.infer(c, g=g_tgt) if hps.model.use_spk else freevc_model.infer(c, mel=mel_tgt)
            
            # Save the result
            status_list.append("Saving conversion result...")
            os.makedirs("output", exist_ok=True)
            output_path = f"output/{output_name}.wav"
            write(output_path, hps.data.sampling_rate, audio[0][0].data.cpu().float().numpy())
            
            status_list.append("转换完成")
            return output_path, "\n".join(status_list)

    except Exception as e:
        logger.error(f"转换错误: {str(e)}")
        status_list.append(f"转换失败: {str(e)}")
        return None, "\n".join(status_list)
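
# Note: voice_conversion() returns status_list only after the conversion finishes, so the
# "Status" textbox shows the full log at once rather than streaming intermediate progress.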

# Gradio UI
with gr.Blocks(title="Chinese-FreeVC Voice Conversion", theme='NoCrypt/miku') as app:
    gr.Markdown("## Chinese-FreeVC Voice Conversion System")

    with gr.Row():
        with gr.Column():
            src_input = gr.Audio(label="Source audio", type="filepath")
            tgt_input = gr.Audio(label="Target voice (timbre reference)", type="filepath")
            with gr.Row():  # output filename and model selection on the same row
                model_dropdown = gr.Dropdown(
                    choices=list(MODEL_OPTIONS.keys()),
                    label="Model",
                    value="Model_17000"
                )
                output_name = gr.Textbox(label="Output filename", value="converted")
            convert_btn = gr.Button("Convert", variant="primary")
        
        with gr.Column():
            output_audio = gr.Audio(label="Converted audio", interactive=False)
            status = gr.Textbox(label="Status", value="Idle", interactive=False)

    convert_btn.click(
        fn=voice_conversion,
        inputs=[src_input, tgt_input, output_name, model_dropdown],
        outputs=[output_audio, status],
        api_name="convert"
    )

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)
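
# Minimal usage sketch (assumptions: the app is running locally on port 7860 and the
# separate gradio_client package is installed; "source.wav" / "target.wav" are placeholder
# paths). Positional arguments follow the voice_conversion() signature; depending on the
# gradio_client version, plain file paths may be expected instead of handle_file().
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   output_path, status_log = client.predict(
#       handle_file("source.wav"),   # src_audio
#       handle_file("target.wav"),   # tgt_audio
#       "converted",                 # output_name
#       "Model_17000",               # model_selection
#       api_name="/convert",
#   )
#   print(output_path, status_log, sep="\n")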