# wide_analysis_space/model_predict.py
# Use a text-classification pipeline to score the input text against the
# deletion-discussion outcome labels defined below.
from transformers import pipeline, AutoTokenizer
# Maps each human-readable outcome to its [class index, pipeline label id].
label_mapping = {
    'delete': [0, 'LABEL_0'],
    'keep': [1, 'LABEL_1'],
    'merge': [2, 'LABEL_2'],
    'no consensus': [3, 'LABEL_3'],
    'speedy keep': [4, 'LABEL_4'],
    'speedy delete': [5, 'LABEL_5'],
    'redirect': [6, 'LABEL_6'],
    'withdrawn': [7, 'LABEL_7']
}
def predict_text(text, model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # return_all_scores is deprecated in newer transformers releases
    # (top_k=None is the modern equivalent) but is kept here because it
    # preserves the nested output shape consumed below.
    model = pipeline("text-classification", model=model_name, return_all_scores=True)
    # Tokenize and truncate the text to the model's 512-token limit, then
    # decode back to a string so the pipeline never sees an over-long input.
    tokens = tokenizer(text, truncation=True, max_length=512)
    truncated_text = tokenizer.decode(tokens['input_ids'], skip_special_tokens=True)
    results = model(truncated_text)
    # Map each pipeline label (e.g. 'LABEL_0') back to its readable name.
    final_scores = {key: 0.0 for key in label_mapping}
    for result in results[0]:
        for key, value in label_mapping.items():
            if result['label'] == value[1]:
                final_scores[key] = result['score']
                break
    return final_scores
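

# Minimal usage sketch. The checkpoint name below is a placeholder, not the
# space's actual fine-tuned model; any sequence-classification checkpoint
# exposing the eight LABEL_0..LABEL_7 outputs above will map cleanly.
if __name__ == "__main__":
    example_model = "bert-base-uncased"  # placeholder checkpoint
    scores = predict_text(
        "Non-notable subject with no independent coverage; recommend deletion.",
        example_model,
    )
    # Print the highest-scoring outcome alongside the full score dictionary.
    print(max(scores, key=scores.get), scores)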
# Alternative implementation (kept for reference): runs the model directly so
# that attention weights can be used to highlight influential tokens. It
# reuses the same label_mapping defined above.
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
# import torch
#
# def predict_text(text, model_name):
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)
#     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
#     outputs = model(**inputs)
#     predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
#     final_scores = {key: 0.0 for key in label_mapping}
#     for i, score in enumerate(predictions[0]):
#         for key, value in label_mapping.items():
#             if i == value[0]:
#                 final_scores[key] = score.item()
#                 break
#     # Average the attention weights: stack the per-layer tensors into shape
#     # (num_layers, batch, heads, seq, seq), average over layers (dim=0),
#     # then over heads (dim=1), and take the first (only) batch item.
#     attentions = outputs.attentions
#     avg_attentions = torch.mean(torch.stack(attentions), dim=0)  # average over layers
#     avg_attentions = avg_attentions.mean(dim=1)[0]               # average over heads
#     token_importance = avg_attentions.mean(dim=0)                # per-token importance
#     # Decode tokens and bold the ones with above-average importance.
#     tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
#     highlighted_text = []
#     for token, importance in zip(tokens, token_importance):
#         if importance > token_importance.mean():
#             highlighted_text.append(f"<b>{token}</b>")
#         else:
#             highlighted_text.append(token)
#     highlighted_text = " ".join(highlighted_text).replace("##", "")
#     return final_scores, highlighted_text
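#
# Example usage if the block above is re-enabled (returns the score dict plus
# an HTML string with high-attention tokens wrapped in <b> tags; the
# checkpoint name is again a placeholder, not the space's real model):
#     scores, highlighted = predict_text("Fails notability guidelines; delete.", "bert-base-uncased")
#     print(highlighted)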