# Legacy version: using a pipeline to predict the input text
# from transformers import pipeline, AutoTokenizer
# import torch
# label_mapping = {
#     'delete': [0, 'LABEL_0'],
#     'keep': [1, 'LABEL_1'],
#     'merge': [2, 'LABEL_2'],
#     'no consensus': [3, 'LABEL_3'],
#     'speedy keep': [4, 'LABEL_4'],
#     'speedy delete': [5, 'LABEL_5'],
#     'redirect': [6, 'LABEL_6'],
#     'withdrawn': [7, 'LABEL_7']
# }
# def predict_text(text, model_name):
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     model = pipeline("text-classification", model=model_name, return_all_scores=True)
#     # Tokenize and truncate the text
#     tokens = tokenizer(text, truncation=True, max_length=512)
#     truncated_text = tokenizer.decode(tokens['input_ids'], skip_special_tokens=True)
#     results = model(truncated_text)
#     final_scores = {key: 0.0 for key in label_mapping}
#     for result in results[0]:
#         for key, value in label_mapping.items():
#             if result['label'] == value[1]:
#                 final_scores[key] = result['score']
#                 break
#     return final_scores
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

label_mapping = {
    'delete': [0, 'LABEL_0'],
    'keep': [1, 'LABEL_1'],
    'merge': [2, 'LABEL_2'],
    'no consensus': [3, 'LABEL_3'],
    'speedy keep': [4, 'LABEL_4'],
    'speedy delete': [5, 'LABEL_5'],
    'redirect': [6, 'LABEL_6'],
    'withdrawn': [7, 'LABEL_7']
}

def predict_text(text, model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)
    # Tokenize (truncating to the 512-token limit, as in the legacy version) and run the model
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # Map each class index to its human-readable label
    final_scores = {key: 0.0 for key in label_mapping}
    for i, score in enumerate(predictions[0]):
        for key, value in label_mapping.items():
            if i == value[0]:
                final_scores[key] = score.item()
                break
    # Calculate average attention per token
    attentions = outputs.attentions  # tuple of (batch, heads, seq_len, seq_len), one per layer
    avg_attentions = torch.mean(torch.stack(attentions), dim=0)  # average over all layers
    avg_attentions = avg_attentions.mean(dim=1)[0]               # average over heads, drop batch dim
    token_importance = avg_attentions.mean(dim=0)                # average attention each token receives
    # Decode tokens and highlight the ones with above-average attention
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    highlighted_text = []
    for token, importance in zip(tokens, token_importance):
        if importance > token_importance.mean():
            highlighted_text.append(f"<b>{token}</b>")
        else:
            highlighted_text.append(token)
    highlighted_text = " ".join(highlighted_text)
    # Strip WordPiece continuation markers ("##") as a rough detokenization
    highlighted_text = highlighted_text.replace("##", "")
    return final_scores, highlighted_text
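
# Minimal usage sketch: the model ID below is a placeholder, not a real checkpoint.
# Substitute any fine-tuned 8-label sequence-classification model whose label order
# matches label_mapping above.
if __name__ == "__main__":
    sample = "This article cites no independent sources and fails the notability guideline."
    scores, highlighted = predict_text(sample, "your-username/afd-outcome-classifier")  # hypothetical model ID
    print(scores)       # e.g. {'delete': 0.91, 'keep': 0.03, ...} -- illustrative values only
    print(highlighted)  # input tokens, with high-attention ones wrapped in <b></b>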