inLine-XJY committed
Commit d276afe
1 Parent(s): e787d71

Create app.py

Files changed (1)
app.py +148 -0
app.py ADDED
@@ -0,0 +1,148 @@
import os
import sys
import pathlib

# Make the repo root importable so that ldm/ and vocoder/ resolve.
directory = pathlib.Path(os.getcwd())
print(directory)
sys.path.append(str(directory))

import torch
import numpy as np
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.scheduling_lcm import LCMSampler
import soundfile
from vocoder.bigvgan.models import VocoderBigVGAN
# from pytorch_memlab import LineProfiler, profile
import gradio

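# Expected assets, based on the paths used later in this file:
#   configs/audiolcm.yaml       - model/sampler config
#   ./model/000184.ckpt         - AudioLCM checkpoint
#   ./vocoder/bigvnat16k93.5w   - BigVGAN vocoder weights
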
def load_model_from_config(config, ckpt=None, verbose=True):
    """Instantiate the model from an OmegaConf config and optionally load a checkpoint."""
    model = instantiate_from_config(config.model)
    if ckpt:
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]

        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
    else:
        print("Note that no ckpt is loaded!")

    if torch.cuda.is_available():  # guard so CPU-only machines don't crash here
        model.cuda()
    model.eval()
    return model


class GenSamples:
    def __init__(self, sampler, model, outpath, vocoder=None, save_mel=True,
                 save_wav=True, original_inference_steps=None) -> None:
        self.sampler = sampler
        self.model = model
        self.outpath = outpath
        if save_wav:
            assert vocoder is not None
        self.vocoder = vocoder
        self.save_mel = save_mel
        self.save_wav = save_wav
        self.channel_dim = self.model.channels
        self.original_inference_steps = original_inference_steps

    def gen_test_sample(self, prompt, mel_name=None, wav_name=None):
        # prompt is {'ori_caption': 'xxx', 'struct_caption': 'xxx'}
        record_dicts = []
        # if os.path.exists(os.path.join(self.outpath, mel_name + '_0.npy')):
        #     return record_dicts

        # Unconditional (empty-caption) embedding; currently unused by the sampler call below.
        emptycap = {'ori_caption': [""], 'struct_caption': [""]}
        uc = self.model.get_learned_conditioning(emptycap)

        for n in range(1):  # trange(self.opt.n_iter, desc="Sampling"):
            for k, v in prompt.items():
                prompt[k] = [v]
            # shape [1, 77, 1280]: still per-token embeddings,
            # not yet pooled into a sentence embedding
            c = self.model.get_learned_conditioning(prompt)
            if self.channel_dim > 0:
                shape = [self.channel_dim, 20, 312]  # (z_dim, 80 // 2^x, 848 // 2^x)
            else:
                shape = [20, 312]
            samples_ddim, _ = self.sampler.sample(
                S=2,
                conditioning=c,
                batch_size=1,
                shape=shape,
                verbose=False,
                guidance_scale=5,
                original_inference_steps=self.original_inference_steps,
            )
            x_samples_ddim = self.model.decode_first_stage(samples_ddim)
            for idx, spec in enumerate(x_samples_ddim):
                spec = spec.squeeze(0).cpu().numpy()
                record_dict = {'caption': prompt['ori_caption'][0]}
                if self.save_mel:
                    mel_path = os.path.join(self.outpath, mel_name + f'_{idx}.npy')
                    np.save(mel_path, spec)
                    record_dict['mel_path'] = mel_path
                if self.save_wav:
                    wav = self.vocoder.vocode(spec)
                    wav_path = os.path.join(self.outpath, wav_name + f'_{idx}.wav')
                    soundfile.write(wav_path, wav, 16000)
                    record_dict['audio_path'] = wav_path
                record_dicts.append(record_dict)
        return record_dicts


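# Minimal usage sketch (the names mirror what infer() below constructs):
#   generator = GenSamples(sampler, model, "results/test", vocoder, save_mel=False)
#   records = generator.gen_test_sample(
#       {'ori_caption': 'a dog barking', 'struct_caption': '<a dog barking& all>'},
#       wav_name='a-dog-barking')
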
def infer(ori_prompt):
    prompt = dict(ori_caption=ori_prompt, struct_caption=f'<{ori_prompt}& all>')

    config = OmegaConf.load("configs/audiolcm.yaml")

    # print("-------quick debug no load ckpt---------")
    # model = instantiate_from_config(config['model'])  # for quick debug
    model = load_model_from_config(config, "./model/000184.ckpt")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    sampler = LCMSampler(model)

    os.makedirs("results/test", exist_ok=True)

    vocoder = VocoderBigVGAN("./vocoder/bigvnat16k93.5w", device)

    generator = GenSamples(sampler, model, "results/test", vocoder,
                           save_mel=False, save_wav=True,
                           original_inference_steps=config.model.params.num_ddim_timesteps)

    with torch.no_grad():
        with model.ema_scope():
            # Derive the file stem from the raw prompt text (prompt itself is a dict).
            wav_name = ori_prompt.strip().replace(" ", "-")
            generator.gen_test_sample(prompt, wav_name=wav_name)

    print("Your samples are ready and waiting for you here:\nresults/test\nEnjoy.")
    return "results/test/" + wav_name + "_0.wav"


def my_inference_function(prompt_ori):
    file_path = infer(prompt_ori)
    return file_path


gradio_interface = gradio.Interface(
    fn=my_inference_function,
    inputs="text",
    outputs="audio",
)
gradio_interface.launch()
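
Once the app is running, a separate process can query it programmatically. A minimal sketch using gradio_client, assuming the default local URL and the auto-generated /predict endpoint (both are assumptions; check the address gradio prints on launch):

from gradio_client import Client

client = Client("http://127.0.0.1:7860/")   # assumed default local address
wav_path = client.predict(
    "a dog barking",      # text prompt
    api_name="/predict",  # default endpoint name for a single gradio.Interface
)
print(wav_path)           # path to the generated waveform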