# RGMC / infer.py
# Author: NikitaSrivatsan
# Commit 01eda25: "Added local version of CLAP checkpoint" (2.24 kB)
from audiocaptioner import AudioCaptioner
from data_module import AudiostockDataset
from utils import *
def _read_file_list(path):
    """Return the whitespace-separated entries of the text file at *path*."""
    # `with` guarantees the handle is closed (the previous inline open() leaked it).
    with open(path, 'r') as fh:
        return fh.read().split()


def _load_captioner(checkpoint, device):
    """Build the GPT-2 tokenizer and an ``AudioCaptioner`` loaded from *checkpoint*.

    The model is moved to *device* and put in eval mode.

    Returns:
        ``(tokenizer, model)``
    """
    # Checkpoint manager (originally commented as "connect to GCS"); the default
    # checkpoint is a relative 'checkpoints/...' path, so presumably it can also
    # resolve local files — TODO confirm in CheckpointManager.
    gcs = CheckpointManager()
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=False)

    # Architecture hyper-parameters; these must match the ones the checkpoint
    # was trained with or load_state_dict will fail.
    prefix_dim = 512
    prefix_length = 10
    prefix_length_clip = 10
    num_layers = 8

    model = AudioCaptioner(
        prefix_length,
        clip_length=prefix_length_clip,
        prefix_size=prefix_dim,
        num_layers=num_layers,
    ).to(device)
    model.load_state_dict(gcs.get_checkpoint(checkpoint))
    print(f'Loaded from {checkpoint}')
    model.eval()
    return tokenizer, model


def infer(input_filename, *, checkpoint='checkpoints/ZRIUE-BEST.pt',
          train_list='audiostock-train-240k.txt', beam_size=5):
    """Generate a text caption for one audio file, print it, and return it.

    Pipeline, as implemented below:
      1. Load the tokenizer and an ``AudioCaptioner`` from ``checkpoint``.
      2. Build a candidate dataset from the files listed in ``train_list`` and
         precompute retrieval neighbors for ``input_filename`` against it.
      3. Build an audio prefix embedding, append the embeddings of the
         neighbor's tokens, and decode with ``generate_beam``.

    Args:
        input_filename: Path of the audio file to caption.
        checkpoint: Checkpoint identifier passed to
            ``CheckpointManager.get_checkpoint``.
        train_list: Text file of whitespace-separated entries used as the
            retrieval candidate pool; also passed as the dataset ``split``.
        beam_size: Beam width passed to ``generate_beam``.

    Returns:
        The ``postproc``-processed first beam candidate (also printed).
    """
    # Cap on the number of neighbor tokens appended after the audio prefix.
    max_neighbor_tokens = 150

    device = get_device(0)
    tokenizer, model = _load_captioner(checkpoint, device)

    # Candidate pool for neighbor retrieval. dataset_path is empty, presumably
    # because file_list entries are already usable paths — TODO confirm.
    dataset_path = ''
    train_dataset = AudiostockDataset(
        dataset_path=dataset_path,
        train=False,
        split=train_list,
        factor=1.0,
        verbose=False,
        file_list=_read_file_list(train_list),
    )

    print('Reading in file', input_filename)
    # Dataset containing only the input file, so neighbors are computed for it alone.
    dataset = AudiostockDataset(
        dataset_path=dataset_path,
        train=False,
        split=None,
        factor=1.0,
        verbose=False,
        file_list=[input_filename],  # manually override file list
    )
    dataset.precompute_neighbors(model, candidate_set=train_dataset)
    # unsqueeze(0) adds a leading dimension of size 1 (a batch of one).
    waveform = dataset.read_wav(input_filename).unsqueeze(0).to(device, dtype=torch.float32)

    # Key into id2neighbor: base name cut at the FIRST '.', exactly as before
    # (e.g. 'a.b.mp3' -> 'a'). Presumably matches how the dataset builds its
    # keys — verify before changing to os.path.splitext.
    file_id = os.path.basename(input_filename).split('.')[0]

    with torch.no_grad():
        prefix_embed = model.create_prefix(waveform, 1)
        neighbor_tokens = preproc(dataset.id2neighbor[file_id], tokenizer, stop=False)
        # Slice before the device copy; same result as slicing afterwards.
        tweet_tokens = torch.tensor(neighbor_tokens, dtype=torch.int64)[:max_neighbor_tokens].to(device)
        # Embed the token ids with the GPT-2 input embedding table (wte).
        tweet_embed = model.gpt.transformer.wte(tweet_tokens)
        # Append the neighbor embeddings after the audio prefix along dim 1.
        prefix_embed = torch.cat([prefix_embed, tweet_embed.unsqueeze(0)], dim=1)
        candidates = generate_beam(model, tokenizer, embed=prefix_embed, beam_size=beam_size)

    # First candidate — presumably the best-scoring beam; confirm in generate_beam.
    generated_text = postproc(candidates[0])
    print('=======================================')
    print(generated_text)
    return generated_text
if __name__ == '__main__':
    import sys

    # Caption the file given as the first CLI argument; fall back to the
    # bundled sample when no argument is supplied (the previous behavior).
    infer(sys.argv[1] if len(sys.argv) > 1 else '../MusicCaptioning/sample_inputs/sisters.mp3')