import torch
from torch import nn
from transformers import BartTokenizer

from modeling_bart import BartForMultimodalGeneration
from music_encoder import CNNSA
|
class CommentGenerator_fusion(nn.Module):

    def __init__(self):
        super(CommentGenerator_fusion, self).__init__()
        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

        # Load the pretrained music encoder and freeze its weights so that
        # only the BART fusion model is updated during training.
        model_path = "best_model.pth"
        self.music_encoder = CNNSA().cuda()
        self.music_encoder.load_state_dict(torch.load(model_path))
        for params in self.music_encoder.parameters():
            params.requires_grad = False

        # BART with cross-modal fusion in encoder layers 4 and 5.
        self.bart = BartForMultimodalGeneration.from_pretrained(
            "facebook/bart-base",
            fusion_layers=[4, 5],
            use_forget_gate=False,
            dim_common=768,
            n_attn_heads=1,
        ).cuda()
|
    def forward(self, input_sentence_list, music_ids, labels=None):
        # Tokenize the textual prompt(s).
        encoded_input = self.tokenizer(
            input_sentence_list,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        )
        # Tokenize the target comments only when labels are provided.
        if labels is not None:
            labels = self.tokenizer(
                labels,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt',
            )
        # Encode the audio into feature vectors for cross-modal fusion.
        music_features = self.music_encoder(music_ids)
        output = self.bart(
            input_ids=encoded_input['input_ids'].cuda(),
            attention_mask=encoded_input['attention_mask'].cuda(),
            labels=labels['input_ids'].cuda() if labels is not None else None,
            music_features=music_features,
        )
        return output
|
    def generate(self, input_sentence_list, music_ids, is_cuda=True):
        # Tokenize the prompt(s), extract music features, and decode with
        # beam search on the fusion model.
        encoded_input = self.tokenizer(
            input_sentence_list,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        music_features = self.music_encoder(music_ids)
        output_ids = self.bart.generate(
            encoded_input['input_ids'].cuda(),
            num_beams=5,
            max_length=512,
            early_stopping=True,
            do_sample=True,
            music_features=music_features,
        )
        return [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                for g in output_ids]
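

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Everything below is
# an illustrative assumption: the prompt strings, the tensor shape passed to
# CNNSA, and the expectation that the fusion model returns a standard Hugging
# Face Seq2SeqLMOutput with a .loss field when labels are given.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = CommentGenerator_fusion()

    # Hypothetical batch: two prompts with matching audio inputs.
    prompts = ["upbeat electronic track", "slow piano ballad"]
    music_ids = torch.randn(2, 128, 512).cuda()  # assumed CNNSA input shape

    # Training-style call: labels are tokenized inside forward() and the
    # model is expected to return an output object carrying the loss.
    out = model(prompts, music_ids, labels=["great drop!", "so calming"])
    print(out.loss)

    # Inference: beam-search decoding of generated comments.
    comments = model.generate(prompts, music_ids)
    print(comments)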
|