# wide_analysis_space/model_predict.py
# Use Hugging Face text-classification pipelines to run predictions on the input text.
from transformers import pipeline, AutoTokenizer
import pysbd

#-----------------Outcome Prediction-----------------
def outcome(text):
    label_mapping = {
        'delete': [0, 'LABEL_0'],
        'keep': [1, 'LABEL_1'],
        'merge': [2, 'LABEL_2'],
        'no consensus': [3, 'LABEL_3'],
        'speedy keep': [4, 'LABEL_4'],
        'speedy delete': [5, 'LABEL_5'],
        'redirect': [6, 'LABEL_6'],
        'withdrawn': [7, 'LABEL_7']
    }
    model_name = "research-dump/roberta-large_deletion_multiclass_complete_final"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None)

    # Truncate the input to the model's 512-token limit before classification
    tokens = tokenizer(text, truncation=True, max_length=512)
    truncated_text = tokenizer.decode(tokens['input_ids'], skip_special_tokens=True)

    results = model(truncated_text)

    # Map each raw LABEL_i back to its human-readable outcome name
    res_list = []
    for result in results[0]:
        for key, value in label_mapping.items():
            if result['label'] == value[1]:
                res_list.append({'sentence': truncated_text, 'outcome': key, 'score': result['score']})
                break
    return res_list

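# Illustrative output shape for outcome(); the scores below are placeholders, not real predictions:
# [{'sentence': '...', 'outcome': 'delete', 'score': 0.91},
#  {'sentence': '...', 'outcome': 'keep', 'score': 0.04}, ...]
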
#-----------------Stance Prediction-----------------
def extract_response(text, model_name, label_mapping):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    pipe = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None)

    # Truncate the input to the model's 512-token limit before classification
    tokens = tokenizer(text, truncation=True, max_length=512)
    truncated_text = tokenizer.decode(tokens['input_ids'], skip_special_tokens=True)

    results = pipe(truncated_text)

    # Collect a score for every label in label_mapping (0.0 if the label is missing)
    final_scores = {key: 0.0 for key in label_mapping}
    for result in results[0]:
        for key, value in label_mapping.items():
            if result['label'] == f'LABEL_{value}':
                final_scores[key] = result['score']
                break
    return final_scores

def get_stance(text):
    label_mapping = {
        'delete': 0,
        'keep': 1,
        'merge': 2,
        'comment': 3
    }
    # Split the input into sentences and classify the stance of each one
    seg = pysbd.Segmenter(language="en", clean=False)
    text_list = seg.segment(text)
    model = 'research-dump/bert-large-uncased_wikistance_v1'
    res_list = []
    for t in text_list:
        res = extract_response(t, model, label_mapping)
        highest_key = max(res, key=res.get)
        highest_score = res[highest_key]
        result = {'sentence': t, 'stance': highest_key, 'score': highest_score}
        res_list.append(result)
    return res_list

#-----------------Policy Prediction-----------------
def get_policy(text):
    label_mapping = {
        'Wikipedia:Notability': 0,
        'Wikipedia:What Wikipedia is not': 1,
        'Wikipedia:Neutral point of view': 2,
        'Wikipedia:Verifiability': 3,
        'Wikipedia:Wikipedia is not a dictionary': 4,
        'Wikipedia:Wikipedia is not for things made up one day': 5,
        'Wikipedia:Criteria for speedy deletion': 6,
        'Wikipedia:Deletion policy': 7,
        'Wikipedia:No original research': 8,
        'Wikipedia:Biographies of living persons': 9,
        'Wikipedia:Arguments to avoid in deletion discussions': 10,
        'Wikipedia:Conflict of interest': 11,
        'Wikipedia:Articles for deletion': 12
    }
    seg = pysbd.Segmenter(language="en", clean=False)
    text_list = seg.segment(text)
    model = 'research-dump/bert-large-uncased_wikistance_policy_v1'
    res_list = []
    for t in text_list:
        res = extract_response(t, model, label_mapping)
        highest_key = max(res, key=res.get)
        highest_score = res[highest_key]
        result = {'sentence': t, 'policy': highest_key, 'score': highest_score}
        res_list.append(result)
    return res_list

#-----------------Sentiment Analysis-----------------
def extract_highest_score_label(res):
    # Flatten the nested pipeline output and return the top-scoring label with its score
    flat_res = [item for sublist in res for item in sublist]
    highest_score_item = max(flat_res, key=lambda x: x['score'])
    highest_score_label = highest_score_item['label']
    highest_score_value = highest_score_item['score']
    return highest_score_label, highest_score_value

def get_sentiment(text):
    # Sentiment analysis model
    model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None)

    # Sentence-tokenize the text using pysbd
    seg = pysbd.Segmenter(language="en", clean=False)
    text_list = seg.segment(text)
    res = []
    for t in text_list:
        results = model(t)
        highest_label, highest_score = extract_highest_score_label(results)
        result = {'sentence': t, 'sentiment': highest_label, 'score': highest_score}
        res.append(result)
    return res

#-----------------Toxicity Prediction-----------------
def get_offensive_label(text):
    # Offensive language detection model
    model_name = "cardiffnlp/twitter-roberta-base-offensive"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None)

    # Sentence-tokenize the text using pysbd
    seg = pysbd.Segmenter(language="en", clean=False)
    text_list = seg.segment(text)
    res = []
    for t in text_list:
        results = model(t)
        highest_label, highest_score = extract_highest_score_label(results)
        result = {'sentence': t, 'offensive_label': highest_label, 'score': highest_score}
        res.append(result)
    return res

# Anchor function: dispatch the input text to the task-specific predictor
def predict_text(text, model_name):
    if model_name == 'outcome':
        return outcome(text)
    elif model_name == 'stance':
        return get_stance(text)
    elif model_name == 'policy':
        return get_policy(text)
    elif model_name == 'sentiment':
        return get_sentiment(text)
    elif model_name == 'offensive':
        return get_offensive_label(text)
    else:
        return "Invalid task name"
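
# Minimal usage sketch (illustrative, not part of the original interface): the task names
# mirror the branches of predict_text above; the sample sentence is made up for demonstration.
if __name__ == "__main__":
    sample = "Delete. The article fails the notability guideline and cites no reliable sources."
    for task in ['outcome', 'stance', 'policy', 'sentiment', 'offensive']:
        print(task, predict_text(sample, task))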