|
|
|
|
|
|
|
|
|
|
|
from typing import List |
|
|
|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
from kgs_binding.relation_mapper_builder import RelationsMapperBuilder |
|
from kgs_binding.kg_qa_binding_utils import load_kg_handler |
|
from data.relation_utils import clean_relations |
|
from model_utils import create_layers_head_mask |
|
|
|
from transformers import ( |
|
BartForConditionalGeneration, |
|
BartTokenizer, |
|
BartConfig, |
|
DisjunctiveConstraint, |
|
) |
|
|
|
from utils import get_jump_chunks |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from custom_tokenizer import BartCustomTokenizerFast |
|
from custom_bart import BartCustomConfig, BartCustomForConditionalGeneration |
|
from utils import get_device, KGType, Model_Type |
|
|
|
from kgs_binding.kg_base_wrapper import KGBaseHandler |
|
from kgs_binding.swow_handler import SwowHandler |
|
from kgs_binding.conceptnet_handler import ConceptNetHandler |
|
|
|
class Inference:
    """Generate text with a plain fine-tuned BART model and expose the
    encoder self-attentions for visualisation."""

    def __init__(self, model_path: str, max_length=32):
        """Load tokenizer and model onto the available device.

        :param model_path: path or hub id of a fine-tuned BART checkpoint.
        :param max_length: token budget for both the input and the generation.
        """
        self.device = get_device()
        self.tokenizer = self.prepare_tokenizer()
        self.model = self.prepare_model(model_path)
        self.max_length = max_length

    def prepare_tokenizer(self):
        """Return the stock ``facebook/bart-large`` tokenizer."""
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        return tokenizer

    def prepare_model(self, model_path):
        """Load the checkpoint in eval mode on ``self.device``."""
        config = BartConfig.from_pretrained(model_path)
        model = BartForConditionalGeneration.from_pretrained(model_path, config=config).to(self.device)
        model.eval()
        return model

    def pre_process_context(self, context):
        """Lower-case and tokenize ``context``, padded/truncated to ``max_length``.

        :returns: the tokenizer's encoding (``input_ids``, ``attention_mask``, ...)
            as PyTorch tensors.
        """
        context = context.lower()
        context_tokenized = self.tokenizer(
            context,
            padding='max_length',
            truncation='longest_first',
            max_length=self.max_length,
            return_tensors="pt",
        )
        return context_tokenized

    def generate_based_on_context(self, context):
        """Run sampled beam-search generation for a single ``context``.

        :returns: tuple of (decoded responses, encoder attentions, tokenized input).
        """
        model_input = self.pre_process_context(context)
        generated_answers_encoded = self.model.generate(
            input_ids=model_input["input_ids"].to(self.device),
            attention_mask=model_input["attention_mask"].to(self.device),
            min_length=1,
            max_length=self.max_length,
            do_sample=True,
            early_stopping=True,
            num_beams=4,
            temperature=1.0,
            top_k=None,
            top_p=None,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
            return_dict_in_generate=True,
            output_attentions=True,
            output_scores=True,
        )
        response = self.tokenizer.batch_decode(
            generated_answers_encoded['sequences'],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        encoder_attentions = generated_answers_encoded['encoder_attentions']
        return response, encoder_attentions, model_input

    def prepare_context_for_visualization(self, context):
        """Generate from ``context`` and package, per example, the display
        tokens plus the full attention tensor for a visualisation front-end.

        :returns: tuple of (decoded responses, list of ``{'words', 'attentions'}`` dicts).
        """
        examples = []
        response, encoder_outputs, model_input = self.generate_based_on_context(context)
        # encoder_outputs is a tuple with one tensor per layer; stacking yields
        # shape (n_layers, batch_size, n_heads, src, tgt).
        encoder_outputs = torch.stack(encoder_outputs)
        n_layers, batch_size, n_heads, src, tgt = encoder_outputs.size()
        print(encoder_outputs.size())
        # BUGFIX: the original used .view(batch_size, n_layers, ...), which only
        # reinterprets storage and scrambles the layer/batch axes whenever
        # batch_size > 1. permute performs the intended axis swap and is
        # equivalent for batch_size == 1 (the case exercised here).
        encoder_attentions = encoder_outputs.permute(1, 0, 2, 3, 4)
        for i, ex in enumerate(encoder_attentions):
            d = {}
            indices = model_input['input_ids'][i].detach().cpu()
            all_tokens = self.tokenizer.convert_ids_to_tokens(indices)
            # Drop padding positions, then strip the BPE space marker for display.
            useful_indeces = indices != self.tokenizer.pad_token_id
            all_tokens = np.array(all_tokens)[useful_indeces]
            all_tokens = [tok.replace('Ġ', '') for tok in all_tokens]
            d['words'] = all_tokens
            d['attentions'] = ex.detach().cpu().numpy()
            examples.append(d)
            print(d['words'])
        return response, examples
|
|
|
class RelationsInference:
    """Generation with a commonsense-relation-aware BART variant.

    Binds a knowledge-graph handler to a custom BART model/tokenizer pair so
    that relations found in the input context can be injected as extra model
    inputs, and KG neighbours can be used as disjunctive generation constraints.
    """

    def __init__(self, model_path:str, kg_type: KGType, model_type:Model_Type, max_length=32):
        """Load the KG handler, the relation-aware tokenizer and the custom model.

        :param model_path: path or hub id of the fine-tuned checkpoint.
        :param kg_type: which knowledge graph to bind (e.g. SWOW, ConceptNet).
        :param model_type: variant controlling relation/mask behaviour.
        :param max_length: token budget for both the input and the generation.
        """
        self.device = get_device()
        kg_handler: KGBaseHandler = load_kg_handler(kg_type)
        self.kg_handler = kg_handler
        relation_names = kg_handler.get_relation_types()
        self.tokenizer = self.prepare_tokenizer(relation_names, model_type)
        # Plain BART tokenizer, used where relation-aware tokenization is unnecessary.
        self.simple_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        self.model, self.config = self.prepare_model(relation_names, model_path, model_type)
        self.relation_mapper_builder = RelationsMapperBuilder(knowledge=kg_handler)
        self.max_length = max_length

    def prepare_tokenizer(self, relation_names: List[str], model_type:Model_Type):
        """Build the relation-aware tokenizer configured with the KG's relation names."""
        tokenizer = BartCustomTokenizerFast.from_pretrained('facebook/bart-large')
        tokenizer.set_known_relation_names(relation_names)
        tokenizer.set_operation_mode(there_is_difference_between_relations=model_type.there_is_difference_between_relations())
        return tokenizer

    def prepare_model(self, relation_names: List[str], model_path, model_type:Model_Type):
        """Load the custom config and model, filling relation-mask fields when absent.

        :returns: tuple of (model in eval mode on ``self.device``, its config).
        """
        config = BartCustomConfig.from_pretrained(model_path, revision='master')
        print('config.heads_mask:', config.heads_mask)
        # Older checkpoints may lack these fields; derive them from the KG / model type.
        if config.num_relation_kinds is None:
            config.num_relation_kinds = len(relation_names)
        if config.is_simple_mask_commonsense is None:
            config.is_simple_mask_commonsense = model_type.is_simple_mask_commonsense()
        if config.heads_mask is None:
            config.heads_mask = create_layers_head_mask(config)
        model = BartCustomForConditionalGeneration.from_pretrained(model_path, config=config, revision='master').to(self.device)
        model.eval()
        return model, config

    def pre_process_context(self, context):
        """Lower-case ``context``, extract its commonsense relations and tokenize
        with the relations passed as extra tokenizer input."""
        context = context.lower()
        commonsense_relations = self.relation_mapper_builder.get_relations_mapping_complex(context=[context], clear_common_wds=True)
        # One input context was passed, so take the single cleaned mapping.
        commonsense_relation = clean_relations(commonsense_relations)[0]
        print(commonsense_relation)
        context_tokenized = self.tokenizer(context, padding='max_length',
                                           truncation='longest_first', max_length=self.max_length,
                                           return_tensors="pt", return_offsets_mapping=True,
                                           input_commonsense_relations=commonsense_relation,
                                           )
        return context_tokenized

    def get_relations_information(self, phrase_generated):
        """Summarise (and print) which concepts of ``phrase_generated`` have KG relations.

        :returns: tuple of (words, all concepts, concepts with relations,
            concepts without relations).
        """
        all_concepts = self.relation_mapper_builder.get_kg_concepts_from_context([phrase_generated], clear_common_wds=True)[0]
        words = phrase_generated.strip().split(' ')
        concepts_with_relations = self.relation_mapper_builder.get_concepts_from_context(phrase_generated, clear_common_wds=True)
        concepts_with_no_relations = list(set(all_concepts).difference(concepts_with_relations))
        print("====== RELATIONS SUMMARY ======")
        print('phrase_generated:', phrase_generated)
        print('words:', words)
        print('all_concepts:', all_concepts)
        print('concepts_with_relations:', concepts_with_relations)
        print('without_relations:', concepts_with_no_relations)
        print("\n== STATS:")
        print('n_words:', len(words))
        print('n_concepts:', len(all_concepts))
        print('n_concepts_with_relations:', len(concepts_with_relations))
        print('n_c_without_relations:', len(concepts_with_no_relations))
        print("====== ================= ======")
        return words, all_concepts, concepts_with_relations, concepts_with_no_relations

    def remove_subsets(self, l):
        """Return ``l`` without any element that is a strict subset of another.

        O(n^2) pairwise comparison; fine for the short id lists used here.
        """
        l2 = l[:]
        for m in l:
            for n in l:
                if set(m).issubset(set(n)) and m != n:
                    l2.remove(m)
                    break
        return l2

    def generate_based_on_context(self, context, use_kg=False):
        """Generate for one ``context``; optionally constrain beams to KG neighbours.

        :param use_kg: when True, build a DisjunctiveConstraint per context
            concept out of its KG-neighbour token ids.
        :returns: tuple of (decoded responses, encoder attentions, tokenized input).
        """
        model_input = self.pre_process_context(context)

        gen_kwargs = {}
        if "input_commonsense_relations" in model_input:
            gen_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations").to(self.device)

        constraints = None
        if use_kg:
            constraints = []
            concepts_from_context = self.relation_mapper_builder.get_concepts_from_context(context=context, clear_common_wds=True)
            useful_concepts = [self.relation_mapper_builder.knowledge.get_related_concepts(concept) for concept in concepts_from_context]
            # BUGFIX: the original tested `if not useful_concepts:`, which is
            # False whenever any concept was found (it is a non-empty list of
            # possibly-empty lists), so this fallback never ran. Test the inner
            # lists instead.
            if not any(useful_concepts):
                useful_concepts = [self.kg_handler.get_related_concepts(concept) for concept in concepts_from_context]
            useful_concepts = [[f'{phrase}' for phrase in concepts] for concepts in useful_concepts]

            if concepts_from_context:
                for context_concept, neighbour_concepts in zip(concepts_from_context, useful_concepts):
                    print('neighbour:', neighbour_concepts[:20])
                    # Only neighbours not already contained in the concept are alternatives.
                    flexible_words = [word for word in neighbour_concepts if word not in context_concept]
                    flexible_words_ids: List[List[int]] = self.simple_tokenizer(flexible_words, add_prefix_space=True,add_special_tokens=False).input_ids
                    flexible_words_ids = self.remove_subsets(flexible_words_ids)
                    # Cap the disjunctive set so constrained beam search stays tractable.
                    flexible_words_ids = flexible_words_ids[:10]
                    print('flexible_words_ids:', flexible_words_ids[:3])
                    constraint = DisjunctiveConstraint(flexible_words_ids)
                    constraints.append(constraint)
            else:
                constraints = None

        generated_answers_encoded = self.model.generate(input_ids=model_input["input_ids"].to(self.device),
                                                        attention_mask=model_input["attention_mask"].to(self.device),
                                                        constraints=constraints,
                                                        min_length=1,
                                                        max_length=self.max_length,
                                                        do_sample=False,
                                                        early_stopping=True,
                                                        num_beams=8,
                                                        temperature=1.0,
                                                        top_k=None,
                                                        top_p=None,
                                                        no_repeat_ngram_size=2,
                                                        num_return_sequences=1,
                                                        return_dict_in_generate=True,
                                                        output_attentions=True,
                                                        output_scores=True,
                                                        **gen_kwargs,
                                                        )

        response = self.tokenizer.batch_decode(generated_answers_encoded['sequences'], skip_special_tokens=True,
                                               clean_up_tokenization_spaces=True)
        encoder_attentions = generated_answers_encoded['encoder_attentions']
        return response, encoder_attentions, model_input

    def get_related_concepts_list(self, knowledge, list_concepts):
        """Flatten the KG neighbours of every concept in ``list_concepts``."""
        other_concepts = []
        for concept in list_concepts:
            other_near_concepts = knowledge.get_related_concepts(concept)
            other_concepts.extend(other_near_concepts)
        return other_concepts

    def generate_contrained_based_on_context(self, contexts, use_kg=True, max_concepts=1):
        """Constrained generation for a batch of contexts, one generate call each.

        For every context, its KG neighbours are chunked into at most
        ``max_concepts`` groups; each group's first sub-token ids form one
        DisjunctiveConstraint for that context's generation.

        :returns: tuple of (decoded texts, per-example encoder attentions,
            tokenized model inputs).
        """
        model_inputs = [self.pre_process_context(context) for context in contexts]
        all_constraints = None
        if use_kg:
            concepts_from_contexts = [self.relation_mapper_builder.get_concepts_from_context(context=context, clear_common_wds=True) for context in contexts]
            neighbours_contexts = [self.get_related_concepts_list(self.kg_handler, context) for context in concepts_from_contexts]
            all_constraints = []
            for context_neighbours in neighbours_contexts:
                # Keep only phrases long enough to be meaningful; the leading
                # space makes them whole-word continuations for BPE.
                context_neighbours = [f' {concept}' for concept in context_neighbours if len(concept) > 3]
                n_size_chuncks = len(context_neighbours) // max_concepts
                n_size_chuncks = n_size_chuncks if n_size_chuncks > 0 else 1
                sub_concepts_collection = list(get_jump_chunks(context_neighbours, jump=n_size_chuncks))
                constraints = []
                for sub_concepts in sub_concepts_collection[:max_concepts]:
                    flexible_words_ids: List[List[int]] = self.tokenizer(sub_concepts,
                                                                         add_special_tokens=False).input_ids
                    # Keep only the first sub-token of each phrase, then dedupe.
                    flexible_words_ids = [[word_ids[0]] for word_ids in flexible_words_ids]
                    disjunctive_set = list(map(list, set(map(frozenset, flexible_words_ids))))
                    if not any(disjunctive_set):
                        continue
                    constraint = DisjunctiveConstraint(disjunctive_set)
                    constraints.append(constraint)
                if not any(constraints):
                    constraints = None
                all_constraints.append(constraints)
        # BUGFIX: the original set all_constraints to None here and then called
        # enumerate(None) below (TypeError); run unconstrained instead.
        if not all_constraints:
            all_constraints = [None] * len(model_inputs)

        generated_answers_encoded = []
        encoder_attentions_list = []
        # BUGFIX: the original iterated `for i, contraints in ...` but passed the
        # stale outer `constraints` variable (the LAST context's constraints) to
        # every generate call; pass each example's own constraints.
        for i, example_constraints in enumerate(all_constraints):
            gen_kwargs = {}
            inputs = model_inputs[i]
            if "input_commonsense_relations" in inputs:
                gen_kwargs["relation_inputs"] = inputs.get("input_commonsense_relations").to(self.device)

            gen = self.model.generate(input_ids=inputs["input_ids"].to(self.device),
                                      attention_mask=inputs["attention_mask"].to(self.device),
                                      constraints=example_constraints,
                                      min_length=1,
                                      max_length=self.max_length,
                                      do_sample=False,
                                      early_stopping=True,
                                      num_beams=8,
                                      temperature=1.0,
                                      top_k=None,
                                      top_p=None,
                                      no_repeat_ngram_size=2,
                                      num_return_sequences=1,
                                      return_dict_in_generate=True,
                                      output_attentions=True,
                                      output_scores=True,
                                      **gen_kwargs,
                                      )

            generated_answers_encoded.append(gen['sequences'][0].detach().cpu())
            encoder_attentions_list.append(gen['encoder_attentions'][0].detach().cpu())

        text_results = self.tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                                   clean_up_tokenization_spaces=True)
        return text_results, encoder_attentions_list, model_inputs

    def prepare_context_for_visualization(self, context):
        """Generate from ``context`` and package, per example, the display tokens,
        attention tensor and relation inputs for a visualisation front-end.

        :returns: tuple of (decoded responses, list of ``{'words', 'attentions'}``
            dicts, list of per-example relation input tensors).
        """
        examples, relations = [], []
        response, encoder_outputs, model_input = self.generate_based_on_context(context)
        input_commonsense_relations = model_input.get("input_commonsense_relations")
        # encoder_outputs is a tuple with one tensor per layer; stacking yields
        # shape (n_layers, batch_size, n_heads, src, tgt).
        encoder_outputs = torch.stack(encoder_outputs)
        n_layers, batch_size, n_heads, src, tgt = encoder_outputs.size()
        print(encoder_outputs.size())
        # BUGFIX: the original used .view(batch_size, n_layers, ...), which only
        # reinterprets storage and scrambles the layer/batch axes whenever
        # batch_size > 1; permute performs the intended axis swap and is
        # equivalent for batch_size == 1 (the case exercised here).
        encoder_attentions = encoder_outputs.permute(1, 0, 2, 3, 4)
        for i, ex in enumerate(encoder_attentions):
            d = {}
            indices = model_input['input_ids'][i].detach().cpu()
            all_tokens = self.tokenizer.convert_ids_to_tokens(indices)
            # Drop padding positions, then strip the BPE space marker for display.
            useful_indeces = indices != self.tokenizer.pad_token_id
            all_tokens = np.array(all_tokens)[useful_indeces]
            all_tokens = [tok.replace('Ġ', '') for tok in all_tokens]
            d['words'] = all_tokens
            d['attentions'] = ex.detach().cpu().numpy()
            examples.append(d)
            relations.append(input_commonsense_relations[i])
            print(d['words'])
        return response, examples, relations
|
|