lysandre's picture
lysandre HF staff
commit files to HF hub
83b7d03
raw
history blame
1.63 kB
import torch
from datasets import load_dataset
from transformers import Pipeline, SpeechT5Processor, SpeechT5HifiGan
class TTSPipeline(Pipeline):
    """Text-to-speech pipeline pairing a speech model with a vocoder.

    ``self.model`` (set up by ``Pipeline.__init__``) must expose
    ``generate_speech``; the vocoder passed here is handed to that call so the
    result is a waveform rather than a spectrogram.
    """

    def __init__(self, *args, vocoder=None, processor=None, **kwargs):
        """Build the pipeline.

        Args:
            vocoder: A ``SpeechT5HifiGan`` instance, or a hub id / local path
                that ``SpeechT5HifiGan.from_pretrained`` can load.
            processor: A ``SpeechT5Processor`` instance, or a hub id / local
                path that ``SpeechT5Processor.from_pretrained`` can load.
            *args, **kwargs: Forwarded unchanged to ``Pipeline.__init__``.

        Raises:
            ValueError: If ``vocoder`` or ``processor`` is not provided.
        """
        super().__init__(*args, **kwargs)
        if vocoder is None:
            raise ValueError("Must pass a vocoder to the TTSPipeline.")
        if processor is None:
            raise ValueError("Must pass a processor to the TTSPipeline.")
        if isinstance(vocoder, str):
            vocoder = SpeechT5HifiGan.from_pretrained(vocoder)
        if isinstance(processor, str):
            processor = SpeechT5Processor.from_pretrained(processor)
        self.processor = processor
        # generate_speech feeds the model's output straight into the vocoder,
        # so the vocoder must sit on the same device as the model (otherwise a
        # GPU pipeline fails with a device mismatch).
        if hasattr(vocoder, "to"):
            vocoder = vocoder.to(self.model.device)
        self.vocoder = vocoder
        # Fallback speaker embedding, loaded lazily on first use and reused so
        # the dataset is not reloaded for every utterance.
        self._cached_speaker_embeddings = None

    def _default_speaker_embeddings(self):
        """Return the fallback x-vector as a ``(1, dim)`` tensor, loading it once."""
        cached = getattr(self, "_cached_speaker_embeddings", None)
        if cached is None:
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            cached = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
            self._cached_speaker_embeddings = cached
        return cached

    def preprocess(self, text, speaker_embeddings=None):
        """Tokenize ``text`` and attach a speaker embedding.

        Args:
            text: Input text passed to the processor.
            speaker_embeddings: Optional embedding (tensor, array or list). A
                1-D value is given a leading batch axis so it matches the
                ``(1, dim)`` shape of the default embedding. When omitted,
                entry 7305 of the ``Matthijs/cmu-arctic-xvectors`` validation
                split is used.

        Returns:
            dict with ``'inputs'`` (processor output, PyTorch tensors) and
            ``'speaker_embeddings'``.
        """
        inputs = self.processor(text=text, return_tensors='pt')
        if speaker_embeddings is None:
            speaker_embeddings = self._default_speaker_embeddings()
        else:
            speaker_embeddings = torch.as_tensor(speaker_embeddings)
            if speaker_embeddings.dim() == 1:
                speaker_embeddings = speaker_embeddings.unsqueeze(0)
        return {'inputs': inputs, 'speaker_embeddings': speaker_embeddings}

    def _forward(self, model_inputs):
        """Run ``self.model.generate_speech`` with the vocoder and return its output."""
        inputs = model_inputs['inputs']
        speaker_embeddings = model_inputs['speaker_embeddings']
        with torch.no_grad():
            speech = self.model.generate_speech(inputs['input_ids'], speaker_embeddings, vocoder=self.vocoder)
        return speech

    def _sanitize_parameters(self, speaker_embeddings=None, **pipeline_parameters):
        """Split call-time kwargs into (preprocess, forward, postprocess) params.

        ``speaker_embeddings`` is routed to ``preprocess``; previously it was
        dropped, so ``pipe(text, speaker_embeddings=...)`` had no effect.
        Other keyword arguments are ignored, as before.
        """
        preprocess_params = {}
        if speaker_embeddings is not None:
            preprocess_params["speaker_embeddings"] = speaker_embeddings
        return preprocess_params, {}, {}

    def postprocess(self, speech):
        """Return the generated speech unchanged."""
        return speech