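# Gradio demo server for Scorpius: given a drug to promote (and optionally a
# target disease), it generates a malicious knowledge-graph link plus a
# supporting abstract drafted by ChatGPT and polished with BioBART.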
#%%
import gradio as gr
import time
import sys
import os
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import json
import networkx as nx
import spacy
# Model names use underscores; pinning a version needs the --direct flag.
os.system("python -m spacy download en_core_web_sm-3.6.0 --direct")
import pickle as pkl
#%%
from torch.nn.modules.loss import CrossEntropyLoss
from transformers import AutoTokenizer
from transformers import BioGptForCausalLM, BartForConditionalGeneration
import server_utils
sys.path.append("..")
import Parameters
from Openai.chat import generate_abstract
sys.path.append("../DiseaseSpecific")
import utils, attack
from attack import calculate_edge_bound, get_model_loss_without_softmax

specific_model = None

def capitalize_the_first_letter(s):
    return s[0].upper() + s[1:]

parser = utils.get_argument_parser()
parser = utils.add_attack_parameters(parser)
parser.add_argument('--init-mode', type=str, default='single', help='How to select target nodes')  # 'single' for case study
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
args.device = device
args.device1 = device
if torch.cuda.device_count() >= 2:
    args.device = "cuda:0"
    args.device1 = "cuda:1"
utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False
model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
model_path = '../DiseaseSpecific/saved_models/{0}_{1}.model'.format(args.data, model_name)
data_path = os.path.join('../DiseaseSpecific/processed_data', args.data)

data = utils.load_data(os.path.join(data_path, 'all.txt'))
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)

with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
    filters = pkl.load(fl)
with open(os.path.join(data_path, 'entityid_to_nodetype.json'), 'r') as fl:
    entityid_to_nodetype = json.load(fl)
with open(os.path.join(data_path, 'edge_nghbrs.pickle'), 'rb') as fl:
    edge_nghbrs = pkl.load(fl)
with open(os.path.join(data_path, 'disease_meshid.pickle'), 'rb') as fl:
    disease_meshid = pkl.load(fl)
with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
    entity_to_id = json.load(fl)
with open(Parameters.GNBRfile + 'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(os.path.join(data_path, 'entities_reverse_dict.json'), 'r') as fl:
    id_to_entity = json.load(fl)
id_to_meshid = id_to_entity.copy()
with open(Parameters.GNBRfile + 'retieve_sentence_through_edgetype', 'rb') as fl:  # filename as shipped with the GNBR data
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile + 'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)
with open(Parameters.UMLSfile + 'drug_term', 'rb') as fl:
    drug_term = pkl.load(fl)

drug_dict = {}
disease_dict = {}
for k, v in entity_raw_name.items():
    # e.g. chemical_mesh:c050048
    tp = k.split('_')[0]
    v = capitalize_the_first_letter(v)
    if len(v) <= 2:
        continue
    if tp == 'chemical':
        drug_dict[v] = k
    elif tp == 'disease':
        disease_dict[v] = k
drug_list = list(drug_dict.keys())
disease_list = list(disease_dict.keys())
drug_list.sort()
disease_list.sort()

init_mask = np.asarray([0] * n_ent).astype('int64')
init_mask = (init_mask == 1)
for k, v in filters.items():
    for kk, vv in v.items():
        tmp = init_mask.copy()
        tmp[np.asarray(vv)] = True
        t = torch.ByteTensor(tmp).to(args.device)
        filters[k][kk] = t

gpt_tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
gpt_model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=gpt_tokenizer.eos_token_id)
gpt_model.eval()

specific_model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
specific_model.eval()
divide_bound, data_mean, data_std = attack.calculate_edge_bound(data, specific_model, args.device, n_ent)

nlp = spacy.load("en_core_web_sm")
bart_model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
bart_model.eval()
bart_tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')
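# Refine ChatGPT's draft abstract with BioBART: mask the non-entity spans of the
# template sentence, splice the masked prompt into the draft at every sentence
# position, and let BioBART rewrite the masks so the malicious triplet and its
# dependency-path entities stay intact.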
def tune_chatgpt(draft, attack_data, dpath):
    dpath_i = 0
    bart_model.to(args.device1)
    for i, v in enumerate(draft):
        input = v['in'].replace('\n', '')
        output = v['out'].replace('\n', '')
        s, r, o = attack_data[i]
        path_text = dpath[dpath_i].replace('\n', '')
        dpath_i += 1

        text_s = entity_raw_name[id_to_meshid[s]]
        text_o = entity_raw_name[id_to_meshid[o]]
        doc = nlp(output)
        words = input.split(' ')
        tokenized_sens = [sen for sen in doc.sents]
        sens = np.array([sen.text for sen in doc.sents])

        checkset = set([text_s, text_o])
        e_entity = set(['start_entity', 'end_entity'])
        for path in path_text.split(' '):
            a, b, c = path.split('|')
            if a not in e_entity:
                checkset.add(a)
            if c not in e_entity:
                checkset.add(c)

        vec = []
        l = 0
        while l < len(words):
            bo = False
            for j in range(len(words), l, -1):  # try the longest span first; the order matters
                cc = ' '.join(words[l:j])
                if cc in checkset:
                    vec += [True] * (j - l)
                    l = j
                    bo = True
                    break
            if not bo:
                vec.append(False)
                l += 1
        vec, span = server_utils.find_mini_span(vec, words, checkset)
        # vec = np.vectorize(lambda x: x in checkset)(words)
        vec[-1] = True

        prompt = []
        mask_num = 0
        for j, bo in enumerate(vec):
            if not bo:
                mask_num += 1
            else:
                if mask_num > 0:
                    # mask_num = mask_num // 3  # span length ~ Poisson distribution (lambda = 3)
                    mask_num = max(mask_num, 1)
                    mask_num = min(8, mask_num)
                    prompt += ['<mask>'] * mask_num
                prompt.append(words[j])
                mask_num = 0
        prompt = ' '.join(prompt)

        Text = []
        Assist = []
        for j in range(len(sens)):
            Bart_input = list(sens[:j]) + [prompt] + list(sens[j+1:])
            assist = list(sens[:j]) + [input] + list(sens[j+1:])
            Text.append(' '.join(Bart_input))
            Assist.append(' '.join(assist))
        for j in range(len(sens)):
            Bart_input = server_utils.mask_func(tokenized_sens[:j]) + [input] + server_utils.mask_func(tokenized_sens[j+1:])
            assist = list(sens[:j]) + [input] + list(sens[j+1:])
            Text.append(' '.join(Bart_input))
            Assist.append(' '.join(assist))

        batch_size = 8
        Outs = []
        for l in range(0, len(Text), batch_size):
            R = min(len(Text), l + batch_size)
            A = bart_tokenizer(Text[l:R],
                               truncation=True,
                               padding=True,
                               max_length=1024,
                               return_tensors="pt")
            input_ids = A['input_ids'].to(args.device1)
            attention_mask = A['attention_mask'].to(args.device1)
            aaid = bart_model.generate(input_ids, attention_mask=attention_mask, num_beams=5, max_length=1024)
            outs = bart_tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            Outs += outs
    bart_model.to('cpu')
    return span, prompt, Outs, Text, Assist
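# Rank the BioBART candidates (plus the raw ChatGPT draft) by BioGPT
# perplexity and return the most fluent abstract. The first half of sen_list
# is first re-aligned sentence-by-sentence against the ChatGPT output.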
def score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, v):
    criterion = CrossEntropyLoss(reduction="none")
    text_s = entity_raw_name[id_to_meshid[s]]
    text_o = entity_raw_name[id_to_meshid[o]]
    sen_list = [server_utils.process(text) for text in sen_list]
    path_text = dpath[0].replace('\n', '')

    checkset = set([text_s, text_o])
    e_entity = set(['start_entity', 'end_entity'])
    for path in path_text.split(' '):
        a, b, c = path.split('|')
        if a not in e_entity:
            checkset.add(a)
        if c not in e_entity:
            checkset.add(c)

    input = v['in'].replace('\n', '')
    output = v['out'].replace('\n', '')
    doc = nlp(output)
    gpt_sens = [sen.text for sen in doc.sents]
    assert len(gpt_sens) == len(sen_list) // 2

    word_sets = []
    for sen in gpt_sens:
        word_sets.append(set(sen.split(' ')))

    def sen_align(word_sets, modified_word_sets):
        # Locate the window in which the two sentence lists disagree.
        l = 0
        while l < len(modified_word_sets):
            if len(word_sets[l].intersection(modified_word_sets[l])) > len(word_sets[l]) * 0.8:
                l += 1
            else:
                break
        if l == len(modified_word_sets):
            return -1, -1, -1, -1
        r = l + 1
        r1 = None
        r2 = None
        for pos1 in range(r, len(word_sets)):
            for pos2 in range(r, len(modified_word_sets)):
                if len(word_sets[pos1].intersection(modified_word_sets[pos2])) > len(word_sets[pos1]) * 0.8:
                    r1 = pos1
                    r2 = pos2
                    break
            if r1 is not None:
                break
        if r1 is None:
            r1 = len(word_sets)
            r2 = len(modified_word_sets)
        return l, r1, l, r2

    replace_sen_list = []
    boundary = []
    assert len(sen_list) % 2 == 0
    for j in range(len(sen_list) // 2):
        doc = nlp(sen_list[j])
        sens = [sen.text for sen in doc.sents]
        modified_word_sets = [set(sen.split(' ')) for sen in sens]
        l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
        boundary.append((l1, r1, l2, r2))
        if l1 == -1:
            replace_sen_list.append(sen_list[j])
            continue
        check_text = ' '.join(sens[l2:r2])
        replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
    sen_list = replace_sen_list + sen_list[len(sen_list) // 2:]

    gpt_model.to(args.device1)
    sen_list.append(output)
    tokens = gpt_tokenizer(sen_list,
                           truncation=True,
                           padding=True,
                           max_length=1024,
                           return_tensors="pt")
    target_ids = tokens['input_ids'].to(args.device1)
    attention_mask = tokens['attention_mask'].to(args.device1)
    L = len(sen_list)
    ret_log_L = []
    for l in range(0, L, 5):
        R = min(L, l + 5)
        target = target_ids[l:R, :]
        attention = attention_mask[l:R, :]
        outputs = gpt_model(input_ids=target,
                            attention_mask=attention,
                            labels=target)
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = target[..., 1:].contiguous()
        Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
        Loss = Loss.view(-1, shift_logits.shape[1])
        attention = attention[..., 1:].contiguous()
        log_Loss = (torch.mean(Loss * attention.float(), dim=1) / torch.mean(attention.float(), dim=1))
        ret_log_L.append(log_Loss.detach())
    log_Loss = torch.cat(ret_log_L, -1).cpu().numpy()
    gpt_model.to('cpu')
    p = np.argmin(log_Loss)
    return sen_list[p]
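# Build a template sentence for the attack triplet: sample GNBR sentences whose
# dependency path matches the relation, substitute in the target entity names,
# and keep the candidate that BioGPT scores as most fluent.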
def generate_template_for_triplet(attack_data):
    criterion = CrossEntropyLoss(reduction="none")
    gpt_model.to(args.device1)
    print('Generating template ...')
    GPT_batch_size = 8
    single_sentence = []
    test_text = []
    test_dp = []
    test_parse = []

    s, r, o = attack_data[0]
    dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
    candidate_sen = []
    Dp_path = []
    L = len(dependency_sen_dict.keys())
    bound = max(1, 500 // L)
    for dp_path, sen_list in dependency_sen_dict.items():
        if len(sen_list) > bound:
            index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
            sen_list = [sen_list[aa] for aa in index]
        ssen_list = []
        for aa in range(len(sen_list)):
            paper_id, sen_id = sen_list[aa]
            if raw_text_sen[paper_id][sen_id]['start_formatted'] == raw_text_sen[paper_id][sen_id]['end_formatted']:
                continue
            ssen_list.append(sen_list[aa])
        sen_list = ssen_list
        candidate_sen += sen_list
        Dp_path += [dp_path] * len(sen_list)

    text_s = entity_raw_name[id_to_meshid[s]]
    text_o = entity_raw_name[id_to_meshid[o]]
    candidate_text_sen = []
    candidate_ori_sen = []
    candidate_parse_sen = []
    for paper_id, sen_id in candidate_sen:
        sen = raw_text_sen[paper_id][sen_id]
        text = sen['text']
        candidate_ori_sen.append(text)
        ss = sen['start_formatted']
        oo = sen['end_formatted']
        # Restore the bracket tokens used by the parses.
        text = text.replace('-LRB-', '(')
        text = text.replace('-RRB-', ')')
        text = text.replace('-LSB-', '[')
        text = text.replace('-RSB-', ']')
        text = text.replace('-LCB-', '{')
        text = text.replace('-RCB-', '}')
        parse_text = text
        parse_text = parse_text.replace(ss, text_s.replace(' ', '_'))
        parse_text = parse_text.replace(oo, text_o.replace(' ', '_'))
        text = text.replace(ss, text_s)
        text = text.replace(oo, text_o)
        text = text.replace('_', ' ')
        candidate_text_sen.append(text)
        candidate_parse_sen.append(parse_text)

    tokens = gpt_tokenizer(candidate_text_sen,
                           truncation=True,
                           padding=True,
                           max_length=300,
                           return_tensors="pt")
    target_ids = tokens['input_ids'].to(args.device1)
    attention_mask = tokens['attention_mask'].to(args.device1)
    L = len(candidate_text_sen)
    assert L > 0
    ret_log_L = []
    for l in range(0, L, GPT_batch_size):
        R = min(L, l + GPT_batch_size)
        target = target_ids[l:R, :]
        attention = attention_mask[l:R, :]
        outputs = gpt_model(input_ids=target,
                            attention_mask=attention,
                            labels=target)
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = target[..., 1:].contiguous()
        Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
        Loss = Loss.view(-1, shift_logits.shape[1])
        attention = attention[..., 1:].contiguous()
        log_Loss = (torch.mean(Loss * attention.float(), dim=1) / torch.mean(attention.float(), dim=1))
        ret_log_L.append(log_Loss.detach())
    ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
    sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
    sen_score.sort(key=lambda x: x[1])
    test_text.append(sen_score[0][2])
    test_dp.append(sen_score[0][3])
    test_parse.append(sen_score[0][4])
    single_sentence.append(sen_score[0][0])
    gpt_model.to('cpu')
    return single_sentence, test_text, test_dp, test_parse
meshids = list(id_to_meshid.values())
cal = {
    'chemical': 0,
    'disease': 0,
    'gene': 0
}
for meshid in meshids:
    cal[meshid.split('_')[0]] += 1
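# A candidate edge is "reasonable" if its normalized KGE loss maps to a
# probability above 1 - reasonable_rate, i.e. it would survive the
# plausibility defender.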
def check_reasonable(s, r, o):
    train_trip = np.asarray([[s, r, o]])
    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
    # edge_losse_log_prob = torch.log(F.softmax(-edge_loss, dim = -1))
    edge_loss = edge_loss.item()
    edge_loss = (edge_loss - data_mean) / data_std
    edge_losses_prob = 1 / (1 + np.exp(edge_loss - divide_bound))
    bound = 1 - args.reasonable_rate
    return (edge_losses_prob > bound), edge_losses_prob
edgeid_to_edgetype = {}
edgeid_to_reversemask = {}
for k, id_list in Parameters.edge_type_to_id.items():
    for iid, mask in zip(id_list, Parameters.reverse_mask[k]):
        edgeid_to_edgetype[str(iid)] = k
        edgeid_to_reversemask[str(iid)] = mask

reverse_tot = 0
G = nx.DiGraph()
for s, r, o in data:
    assert id_to_meshid[s].split('_')[0] == edgeid_to_edgetype[r].split('-')[0]
    if edgeid_to_reversemask[r] == 1:
        reverse_tot += 1
        G.add_edge(int(o), int(s))
    else:
        G.add_edge(int(s), int(o))

print('Page ranking ...')
pagerank_value_1 = nx.pagerank(G, max_iter=200, tol=1.0e-7)

drug_meshid = []
drug_list = []
for meshid, nm in entity_raw_name.items():
    if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical':
        drug_meshid.append(meshid)
        drug_list.append(capitalize_the_first_letter(nm))
drug_list = list(set(drug_list))
drug_list.sort()
drug_meshid = set(drug_meshid)

pr = list(pagerank_value_1.items())
pr.sort(key=lambda x: x[1])
sorted_rank = {'chemical': [],
               'gene': [],
               'disease': [],
               'merged': []}
for iid, score in pr:
    tp = id_to_meshid[str(iid)].split('_')[0]
    if tp == 'chemical':
        if id_to_meshid[str(iid)] in drug_meshid:
            sorted_rank[tp].append((iid, score))
    else:
        sorted_rank[tp].append((iid, score))
    sorted_rank['merged'].append((iid, score))
llen = len(sorted_rank['merged'])
sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4:]
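# Targeted attack: find the single adversarial edge that most promotes
# "drug treats disease" for the chosen pair, via the influence-based
# addition attack.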
def generate_specific_attack_edge(start_entity, end_entity):
    global specific_model
    specific_model.to(device)
    start_meshid = drug_dict[start_entity]
    end_meshid = disease_dict[end_entity]
    start_entity = entity_to_id[start_meshid]
    end_entity = entity_to_id[end_meshid]
    target_data = np.array([[start_entity, '10', end_entity]])  # relation id 10: treatment
    neighbors = attack.generate_nghbrs(target_data, edge_nghbrs, args)
    ret = f'Generating malicious link for {start_meshid}_treatment_{end_meshid}', 'Generating malicious text ...'
    param_optimizer = list(specific_model.named_parameters())
    param_influence = []
    for n, p in param_optimizer:
        param_influence.append(p)
    len_list = []
    for v in neighbors.values():
        len_list.append(len(v))
    mean_len = np.mean(len_list)
    attack_trip, score_record = attack.addition_attack(param_influence, args.device, n_rel, data, target_data, neighbors, specific_model, filters, entityid_to_nodetype, args.attack_batch_size, args, load_Record=args.load_existed, divide_bound=divide_bound, data_mean=data_mean, data_std=data_std, cache_intermidiate=False)
    s, r, o = attack_trip[0]
    specific_model.to('cpu')
    return s, r, o
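# Target-agnostic attack: for each target drug, scan high-PageRank entities
# with no existing edge to the target, pick the lowest-loss chemical relation
# that passes the defender check, and rank candidates by PageRank x loss.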
def generate_agnostic_attack_edge(targets):
    specific_model.to(device)
    attack_edge_list = []
    for target in targets:
        candidate_list = []
        score_list = []
        loss_list = []
        main_dict = {}
        for iid, score in sorted_rank['merged']:
            a = G.number_of_edges(iid, target) + 1
            if a != 1:
                continue
            b = G.out_degree(iid) + 1
            tp = id_to_meshid[str(iid)].split('_')[0]
            edge_losses = []
            r_list = []
            for r in range(len(edgeid_to_edgetype)):
                r_tp = edgeid_to_edgetype[str(r)]
                if (edgeid_to_reversemask[str(r)] == 0 and r_tp.split('-')[0] == tp and r_tp.split('-')[1] == 'chemical'):
                    train_trip = np.array([[iid, r, target]])
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
                elif (edgeid_to_reversemask[str(r)] == 1 and r_tp.split('-')[0] == 'chemical' and r_tp.split('-')[1] == tp):
                    train_trip = np.array([[iid, r, target]])  # add batch dim
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
            if len(edge_losses) == 0:
                continue
            min_index = torch.argmin(torch.cat(edge_losses, dim=0))
            r = r_list[min_index]
            r_tp = edgeid_to_edgetype[str(r)]
            old_len = len(candidate_list)
            if edgeid_to_reversemask[str(r)] == 0:
                bo, prob = check_reasonable(iid, r, target)
                if bo:
                    candidate_list.append((iid, r, target))
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())
            if edgeid_to_reversemask[str(r)] == 1:
                bo, prob = check_reasonable(target, r, iid)
                if bo:
                    candidate_list.append((target, r, iid))
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())
        if len(candidate_list) == 0:
            if args.added_edge_num == '' or int(args.added_edge_num) == 1:
                attack_edge_list.append((-1, -1, -1))
            else:
                attack_edge_list.append([])
            continue
        norm_score = np.array(score_list) / np.sum(score_list)
        norm_loss = np.exp(-np.array(loss_list)) / np.sum(np.exp(-np.array(loss_list)))
        total_score = norm_score * norm_loss
        total_score_index = list(zip(range(len(total_score)), total_score))
        total_score_index.sort(key=lambda x: x[1], reverse=True)
        total_index = np.argsort(total_score)[::-1]
        assert total_index[0] == total_score_index[0][0]
        # find rank of main index
        max_index = np.argmax(total_score)
        assert max_index == total_score_index[0][0]
        tmp_add = []
        add_num = 1
        if args.added_edge_num == '' or int(args.added_edge_num) == 1:
            attack_edge_list.append(candidate_list[max_index])
        else:
            add_num = int(args.added_edge_num)
            for i in range(add_num):
                tmp_add.append(candidate_list[total_score_index[i][0]])
            attack_edge_list.append(tmp_add)
    specific_model.to('cpu')
    return attack_edge_list[0]
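# Gradio callback for the "Target specific" tab: generate the malicious link,
# reuse a cached template sentence when one exists, then draft the abstract
# with ChatGPT and polish it with BioBART.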
def specific_func(start_entity, end_entity):
    args.reasonable_rate = 0.5
    s, r, o = generate_specific_attack_edge(start_entity, end_entity)
    if int(s) == -1:
        return 'All candidate links are filtered out by the defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    s_name = entity_raw_name[id_to_entity[str(s)]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[str(o)]]
    attack_data = np.array([[s, r, o]])

    path_list = []
    with open(f'../DiseaseSpecific/generate_abstract/path/random_{args.reasonable_rate}_path.json', 'r') as fl:
        for line in fl.readlines():
            line = line.replace('\n', '')
            path_list.append(line)
    with open(f'../DiseaseSpecific/generate_abstract/random_{args.reasonable_rate}_sentence.json', 'r') as fl:
        sentence_dict = json.load(fl)
    dpath = []
    for k, v in sentence_dict.items():
        if f'{s}_{r}_{o}' in k:
            single_sentence = [v]
            dpath = [path_list[int(k.split('_')[-1])]]
            break
    if len(dpath) == 0:
        single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
    elif not (s_name in single_sentence[0] and o_name in single_sentence[0]):
        single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)

    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])
    print('Using BioBART for tuning...')
    span, prompt, sen_list, BART_in, Assist = tune_chatgpt([{'in': single_sentence[0], 'out': draft}], attack_data, dpath)
    text = score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, {'in': single_sentence[0], 'out': draft})
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
    # f'The sentence is: {single_sentence[0]}\n The path is: {dpath[0]}'
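# Gradio callback for the "Target agnostic" tab: same pipeline, but the
# adversarial edge is chosen from the PageRank-ranked candidate pool.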
def agnostic_func(agnostic_entity):
    args.reasonable_rate = 0.7
    target_id = entity_to_id[drug_dict[agnostic_entity]]
    s = generate_agnostic_attack_edge([int(target_id)])
    if len(s) == 0:
        return 'All candidate links are filtered out by the defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    if int(s[0]) == -1:
        return 'All candidate links are filtered out by the defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    s, r, o = str(s[0]), str(s[1]), str(s[2])
    s_name = entity_raw_name[id_to_entity[str(s)]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[str(o)]]
    attack_data = np.array([[s, r, o]])
    single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])
    print('Using BioBART for tuning...')
    span, prompt, sen_list, BART_in, Assist = tune_chatgpt([{'in': single_sentence[0], 'out': draft}], attack_data, dpath)
    text = score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, {'in': single_sentence[0], 'out': draft})
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
#%%
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("Poison scientific knowledge with Scorpius")
        # with gr.Column():
        with gr.Row():
            # Center
            with gr.Column():
                gr.Markdown("Select your poison target")
                with gr.Tab('Target specific'):
                    with gr.Column():
                        with gr.Row():
                            start_entity = gr.Dropdown(drug_list, label="Promoting drug")
                            end_entity = gr.Dropdown(disease_list, label="Target disease")
                        specific_generation_button = gr.Button('Poison!')
                with gr.Tab('Target agnostic'):
                    agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
                    agnostic_generation_button = gr.Button('Poison!')
            with gr.Column():
                gr.Markdown("Malicious link")
                malicious_link = gr.Textbox(lines=1, label="Malicious link")
                gr.Markdown("Malicious text")
                malicious_text = gr.Textbox(label="Malicious text", lines=5)
    specific_generation_button.click(specific_func, inputs=[start_entity, end_entity], outputs=[malicious_link, malicious_text])
    agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity], outputs=[malicious_link, malicious_text])

demo.launch(server_name="0.0.0.0", server_port=8000, debug=False)