# Use Hugging Face text-classification pipelines to predict outcome, stance, policy, sentiment, and offensive language for the input text
import pandas as pd
from transformers import pipeline, AutoTokenizer
import pysbd
#-----------------Outcome Prediction-----------------
def outcome(text):
    # Map human-readable outcomes to the model's pipeline label names
    label_mapping = {
        'delete': [0, 'LABEL_0'],
        'keep': [1, 'LABEL_1'],
        'merge': [2, 'LABEL_2'],
        'no consensus': [3, 'LABEL_3'],
        'speedy keep': [4, 'LABEL_4'],
        'speedy delete': [5, 'LABEL_5'],
        'redirect': [6, 'LABEL_6'],
        'withdrawn': [7, 'LABEL_7']
    }
    model_name = "research-dump/roberta-large_deletion_multiclass_complete_final"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # top_k=None returns scores for all labels (replaces the deprecated return_all_scores=True)
    model = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None)

    # Tokenize and truncate the text to the model's maximum input length
    tokens = tokenizer(text, truncation=True, max_length=512)
    truncated_text = tokenizer.decode(tokens['input_ids'], skip_special_tokens=True)

    results = model(truncated_text)

    res_list = []
    for result in results[0]:
        for key, value in label_mapping.items():
            if result['label'] == value[1]:
                res_list.append({'sentence': truncated_text, 'outcome': key, 'score': result['score']})
                break
    return res_list
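# Illustrative call (the scores below are made up, shown only for the output shape):
# outcome("The subject has no significant coverage in independent sources.")
# -> [{'sentence': '...', 'outcome': 'delete', 'score': 0.91},
#     {'sentence': '...', 'outcome': 'keep', 'score': 0.04}, ...]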
#-----------------Stance Prediction-----------------
def extract_response(text, model_name, label_mapping):
    # Run a text-classification model on one sentence and return a {label: score} dict
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    pipe = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None)

    tokens = tokenizer(text, truncation=True, max_length=512)
    truncated_text = tokenizer.decode(tokens['input_ids'], skip_special_tokens=True)

    results = pipe(truncated_text)

    final_scores = {key: 0.0 for key in label_mapping}
    for result in results[0]:
        for key, value in label_mapping.items():
            if result['label'] == f'LABEL_{value}':
                final_scores[key] = result['score']
                break
    return final_scores
def get_stance(text):
    label_mapping = {
        'delete': 0,
        'keep': 1,
        'merge': 2,
        'comment': 3
    }
    # Split the text into sentences and classify the stance of each one
    seg = pysbd.Segmenter(language="en", clean=False)
    text_list = seg.segment(text)
    model = 'research-dump/bert-large-uncased_wikistance_v1'
    res_list = []
    for t in text_list:
        res = extract_response(t, model, label_mapping)
        highest_key = max(res, key=res.get)
        highest_score = res[highest_key]
        result = {'sentence': t, 'stance': highest_key, 'score': highest_score}
        res_list.append(result)
    return res_list
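# Illustrative call (scores and predictions are made up): each sentence gets its own stance label, e.g.
# get_stance("Delete. Merge any sourced content into the parent article.")
# -> [{'sentence': 'Delete.', 'stance': 'delete', 'score': 0.97},
#     {'sentence': 'Merge any sourced content into the parent article.', 'stance': 'merge', 'score': 0.93}]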
#-----------------Policy Prediction-----------------
def get_policy(text):
    label_mapping = {
        'Wikipedia:Notability': 0,
        'Wikipedia:What Wikipedia is not': 1,
        'Wikipedia:Neutral point of view': 2,
        'Wikipedia:Verifiability': 3,
        'Wikipedia:Wikipedia is not a dictionary': 4,
        'Wikipedia:Wikipedia is not for things made up one day': 5,
        'Wikipedia:Criteria for speedy deletion': 6,
        'Wikipedia:Deletion policy': 7,
        'Wikipedia:No original research': 8,
        'Wikipedia:Biographies of living persons': 9,
        'Wikipedia:Arguments to avoid in deletion discussions': 10,
        'Wikipedia:Conflict of interest': 11,
        'Wikipedia:Articles for deletion': 12
    }
    # Split the text into sentences and predict the most relevant policy for each one
    seg = pysbd.Segmenter(language="en", clean=False)
    text_list = seg.segment(text)
    model = 'research-dump/bert-large-uncased_wikistance_policy_v1'
    res_list = []
    for t in text_list:
        res = extract_response(t, model, label_mapping)
        highest_key = max(res, key=res.get)
        highest_score = res[highest_key]
        result = {'sentence': t, 'policy': highest_key, 'score': highest_score}
        res_list.append(result)
    return res_list
#-----------------Sentiment Analysis-----------------
def extract_highest_score_label(res):
    # Flatten the nested pipeline output and return the label with the highest score
    flat_res = [item for sublist in res for item in sublist]
    highest_score_item = max(flat_res, key=lambda x: x['score'])
    highest_score_label = highest_score_item['label']
    highest_score_value = highest_score_item['score']
    return highest_score_label, highest_score_value

def get_sentiment(text):
    # Sentiment analysis model
    model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None)

    # Sentence-tokenize the text using pysbd
    seg = pysbd.Segmenter(language="en", clean=False)
    text_list = seg.segment(text)

    res = []
    for t in text_list:
        results = model(t)
        highest_label, highest_score = extract_highest_score_label(results)
        result = {'sentence': t, 'sentiment': highest_label, 'score': highest_score}
        res.append(result)
    return res
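# Illustrative call (score is made up); this model emits human-readable sentiment labels
# (e.g. 'positive') rather than LABEL_k names:
# get_sentiment("This article is well sourced and clearly notable.")
# -> [{'sentence': '...', 'sentiment': 'positive', 'score': 0.95}]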
#-----------------Toxicity Prediction-----------------
def get_offensive_label(text):
    # Offensive language detection model
    model_name = "cardiffnlp/twitter-roberta-base-offensive"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None)

    # Sentence-tokenize the text using pysbd
    seg = pysbd.Segmenter(language="en", clean=False)
    text_list = seg.segment(text)

    res = []
    for t in text_list:
        results = model(t)
        highest_label, highest_score = extract_highest_score_label(results)
        result = {'sentence': t, 'offensive_label': highest_label, 'score': highest_score}
        res.append(result)
    return res
# Anchor function: dispatch the input text to the requested task
def predict_text(text, model_name):
    if model_name == 'outcome':
        return outcome(text)
    elif model_name == 'stance':
        return get_stance(text)
    elif model_name == 'policy':
        return get_policy(text)
    elif model_name == 'sentiment':
        return get_sentiment(text)
    elif model_name == 'offensive':
        return get_offensive_label(text)
    else:
        return "Invalid task name"