Spaces:
Sleeping
Sleeping
from audiocaptioner import AudioCaptioner | |
from data_module import AudiostockDataset | |
from utils import * | |
def _read_file_list(path):
    """Return the whitespace-separated entries of the text file at ``path``."""
    # A ``with`` block closes the handle deterministically; the previous inline
    # ``open(path, 'r').read()`` left the file open until garbage collection.
    with open(path, 'r') as fh:
        return fh.read().split()


def _load_captioner(device, checkpoint):
    """Build an AudioCaptioner on ``device``, load ``checkpoint`` into it and switch it to eval mode."""
    # Architecture hyper-parameters. Presumably these must match the values the
    # checkpoint was trained with, otherwise load_state_dict (strict by default)
    # will reject the weights.
    prefix_dim = 512
    prefix_length = 10
    prefix_length_clip = 10
    num_layers = 8

    # Checkpoints are fetched through CheckpointManager (GCS, per the original comment).
    gcs = CheckpointManager()
    model = AudioCaptioner(
        prefix_length,
        clip_length=prefix_length_clip,
        prefix_size=prefix_dim,
        num_layers=num_layers,
    ).to(device)
    model.load_state_dict(gcs.get_checkpoint(checkpoint))
    print(f'Loaded from {checkpoint}')
    model.eval()
    return model


def infer(input_filename, *, checkpoint='checkpoints/ZRIUE-BEST.pt', beam_size=5):
    """Generate a caption for one audio file, print it and return it.

    The caption is produced by beam search over a prompt made of the audio
    prefix embedding followed by the token embeddings of the file's retrieved
    neighbor text (neighbors are searched in the training split).

    Args:
        input_filename: Path to the audio file to caption.
        checkpoint: Checkpoint key passed to ``CheckpointManager.get_checkpoint``.
        beam_size: Beam width passed to ``generate_beam``.

    Returns:
        The post-processed caption string. (Earlier versions returned ``None``;
        callers that ignore the result are unaffected.)
    """
    device = get_device(0)
    model = _load_captioner(device, checkpoint)
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=False)

    # Candidate pool for neighbor retrieval: the files listed in the training split.
    dataset_path = ''
    train_split = 'audiostock-train-240k.txt'
    candidate_dataset = AudiostockDataset(
        dataset_path=dataset_path,
        train=False,
        split=train_split,
        factor=1.0,
        verbose=False,
        file_list=_read_file_list(train_split),
    )

    print('Reading in file', input_filename)
    # Single-file dataset used only to compute the neighbors of the input.
    dataset = AudiostockDataset(
        dataset_path=dataset_path,
        train=False,
        split=None,
        factor=1.0,
        verbose=False,
        file_list=[input_filename],  # manually override file list
    )
    dataset.precompute_neighbors(model, candidate_set=candidate_dataset)
    waveform = dataset.read_wav(input_filename).unsqueeze(0).to(device, dtype=torch.float32)

    # NOTE(review): this lookup assumes AudiostockDataset keys id2neighbor by the
    # basename cut at the FIRST '.' (e.g. 'a.b.mp3' -> 'a'); keep in sync with it.
    neighbor_key = os.path.basename(input_filename).split('.')[0]
    # Cap on neighbor-text tokens appended to the prompt (presumably to bound prompt length).
    max_neighbor_tokens = 150

    with torch.no_grad():
        prefix_embed = model.create_prefix(waveform, 1)
        neighbor_tokens = torch.tensor(
            preproc(dataset.id2neighbor[neighbor_key], tokenizer, stop=False),
            dtype=torch.int64,
        )[:max_neighbor_tokens].to(device)
        neighbor_embed = model.gpt.transformer.wte(neighbor_tokens)
        # Prompt = audio prefix followed by the neighbor-text embeddings.
        prefix_embed = torch.cat([prefix_embed, neighbor_embed.unsqueeze(0)], dim=1)
        candidates = generate_beam(model, tokenizer, embed=prefix_embed, beam_size=beam_size)

    # Keep the first beam candidate returned by generate_beam.
    generated_text = postproc(candidates[0])
    print('=======================================')
    print(generated_text)
    return generated_text
if __name__ == '__main__':
    # Local import: only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(description='Caption an audio file.')
    # Optional positional argument; with no argument the script behaves exactly
    # as before and captions the bundled sample file.
    parser.add_argument(
        'input_filename',
        nargs='?',
        default='../MusicCaptioning/sample_inputs/sisters.mp3',
        help='audio file to caption (default: %(default)s)',
    )
    args = parser.parse_args()
    infer(args.input_filename)