hungchiayu
commited on
Create tangoflux
Browse files
tangoflux
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import AutoencoderOobleck
|
2 |
+
import torch
|
3 |
+
from transformers import T5EncoderModel,T5TokenizerFast
|
4 |
+
from diffusers import FluxTransformer2DModel
|
5 |
+
from torch import nn
|
6 |
+
from typing import List
|
7 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
8 |
+
from diffusers.training_utils import compute_density_for_timestep_sampling
|
9 |
+
import copy
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import numpy as np
|
12 |
+
from src.model import TangoFlux
|
13 |
+
from huggingface_hub import snapshot_download
|
14 |
+
from tqdm import tqdm
|
15 |
+
from typing import Optional,Union,List
|
16 |
+
from datasets import load_dataset, Audio
|
17 |
+
from math import pi
|
18 |
+
import json
|
19 |
+
import inspect
|
20 |
+
import yaml
|
21 |
+
from safetensors.torch import load_file
|
22 |
+
|
23 |
+
|
24 |
+
class TangoFluxInference:
|
25 |
+
|
26 |
+
def __init__(self,name='declare-lab/TangoFlux',device="cuda"):
|
27 |
+
|
28 |
+
|
29 |
+
self.vae = AutoencoderOobleck()
|
30 |
+
|
31 |
+
paths = snapshot_download(repo_id=name)
|
32 |
+
vae_weights = load_file("{}/vae.safetensors".format(paths))
|
33 |
+
self.vae.load_state_dict(vae_weights)
|
34 |
+
weights = load_file("{}/tangoflux.safetensors".format(paths))
|
35 |
+
|
36 |
+
with open('{}/config.json'.format(paths),'r') as f:
|
37 |
+
config = json.load(f)
|
38 |
+
self.model = TangoFlux(config)
|
39 |
+
self.model.load_state_dict(weights,strict=False)
|
40 |
+
# _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
|
41 |
+
self.vae.to(device)
|
42 |
+
self.model.to(device)
|
43 |
+
|
44 |
+
def generate(self,prompt,steps=25,duration=10,guidance_scale=4.5):
|
45 |
+
|
46 |
+
with torch.no_grad():
|
47 |
+
latents = self.model.inference_flow(prompt,
|
48 |
+
duration=duration,
|
49 |
+
num_inference_steps=steps,
|
50 |
+
guidance_scale=guidance_scale)
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
wave = self.vae.decode(latents.transpose(2,1)).sample.cpu()[0]
|
55 |
+
waveform_end = int(duration * self.vae.config.sampling_rate)
|
56 |
+
wave = wave[:, :waveform_end]
|
57 |
+
return wave
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
|