Text-to-Audio
Inference Endpoints
hungchiayu commited on
Commit
2b70cbe
·
verified ·
1 Parent(s): 881a418

Create tangoflux

Browse files
Files changed (1) hide show
  1. tangoflux +61 -0
tangoflux ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoencoderOobleck
2
+ import torch
3
+ from transformers import T5EncoderModel,T5TokenizerFast
4
+ from diffusers import FluxTransformer2DModel
5
+ from torch import nn
6
+ from typing import List
7
+ from diffusers import FlowMatchEulerDiscreteScheduler
8
+ from diffusers.training_utils import compute_density_for_timestep_sampling
9
+ import copy
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ from src.model import TangoFlux
13
+ from huggingface_hub import snapshot_download
14
+ from tqdm import tqdm
15
+ from typing import Optional,Union,List
16
+ from datasets import load_dataset, Audio
17
+ from math import pi
18
+ import json
19
+ import inspect
20
+ import yaml
21
+ from safetensors.torch import load_file
22
+
23
+
24
+ class TangoFluxInference:
25
+
26
+ def __init__(self,name='declare-lab/TangoFlux',device="cuda"):
27
+
28
+
29
+ self.vae = AutoencoderOobleck()
30
+
31
+ paths = snapshot_download(repo_id=name)
32
+ vae_weights = load_file("{}/vae.safetensors".format(paths))
33
+ self.vae.load_state_dict(vae_weights)
34
+ weights = load_file("{}/tangoflux.safetensors".format(paths))
35
+
36
+ with open('{}/config.json'.format(paths),'r') as f:
37
+ config = json.load(f)
38
+ self.model = TangoFlux(config)
39
+ self.model.load_state_dict(weights,strict=False)
40
+ # _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
41
+ self.vae.to(device)
42
+ self.model.to(device)
43
+
44
+ def generate(self,prompt,steps=25,duration=10,guidance_scale=4.5):
45
+
46
+ with torch.no_grad():
47
+ latents = self.model.inference_flow(prompt,
48
+ duration=duration,
49
+ num_inference_steps=steps,
50
+ guidance_scale=guidance_scale)
51
+
52
+
53
+
54
+ wave = self.vae.decode(latents.transpose(2,1)).sample.cpu()[0]
55
+ waveform_end = int(duration * self.vae.config.sampling_rate)
56
+ wave = wave[:, :waveform_end]
57
+ return wave
58
+
59
+
60
+
61
+