In [ ]:
import os

import torch
from transformers import pipeline

# The checkpoint path below is relative to the project root. Step up one
# directory only when we are not already there, so re-running this cell does
# not keep climbing the directory tree.
if not os.path.isdir('finetuned_entity_categorical_classification'):
    os.chdir('..')

# NOTE(review): this loads checkpoint-23355, while the "without pipelines"
# section loads checkpoint-3212 — confirm which checkpoint is intended.
classifier = pipeline(
    "text-classification",
    model="finetuned_entity_categorical_classification/checkpoint-23355",
    # Fall back to CPU instead of failing outright when no GPU is available.
    device="cuda" if torch.cuda.is_available() else "cpu",
)
In [ ]:
# Pipeline probe: compound product query mixing an animal word with electronics.
classifier("cat ear shaped headphones")
In [ ]:
# Pipeline probe: single compound word.
classifier("catfood")
In [ ]:
# Pipeline probe: bare product noun, for comparison with the compound query.
classifier("headphones")
In [ ]:
Inference Without Pipelines (calling the tokenizer and model directly)
In [ ]:
import os; os.chdir('..')
%pwd
Out[ ]:
'/home/ubuntu/SentenceStructureComparision'
In [ ]:
import json

# Category name -> integer id mapping. The context manager closes the file
# handle, which the bare open() call previously leaked.
with open('data/categories_refined.json', 'r', encoding='utf-8') as f:
    label2id = json.load(f)

# Invert to id -> category name so logit indices can be mapped back to names.
id2label = {label_id: label for label, label_id in label2id.items()}
# The inversion is only lossless if every id is unique.
assert len(id2label) == len(label2id), "duplicate ids in data/categories_refined.json"
In [ ]:
import torch
from torch.nn import functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# NOTE(review): this is checkpoint-3212, whereas the pipeline cell at the top
# of the notebook loads checkpoint-23355 — confirm which one is intended.
model_name = "finetuned_entity_categorical_classification/checkpoint-3212"

# Tokenizer and classification head come from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
In [ ]:
def predict(sentence: str):
    """Classify ``sentence`` into one of the entity categories.

    Prints the top predicted class (name taken from ``model.config.id2label``)
    together with its probability, then returns the full distribution.

    Relies on the notebook-level ``tokenizer``, ``model`` and ``id2label``.

    Args:
        sentence: Text to classify.

    Returns:
        dict mapping ``"P(<category>)"`` to that category's softmax probability,
        rounded to 3 decimals, in label-id order. Category names come from the
        notebook-level ``id2label`` (built from data/categories_refined.json).
    """
    # Truncate over-long inputs rather than erroring past the model's max
    # length, and put the tensors on the same device as the model.
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True).to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_id = logits.argmax().item()
    # Softmax over the class dimension; .cpu() so .numpy() also works when the
    # model runs on a GPU. [0] selects the single sentence in the batch.
    probabilities_scores = F.softmax(logits, dim=-1).cpu().numpy()[0]

    # Size the loop from the model output instead of a hard-coded class count,
    # so a different number of categories neither crashes nor drops classes.
    # float() yields plain Python floats (clean display, JSON-serialisable).
    probabilities = {
        f'P({id2label[i]})': round(float(probabilities_scores[i]), 3)
        for i in range(len(probabilities_scores))
    }
    print("Predicted Class: ", model.config.id2label[predicted_class_id], f"\nprobabilities_scores: {probabilities_scores[predicted_class_id]}\n")
    return probabilities
In [ ]:
# Compound product query mixing an animal word with electronics.
predict('cat ear headphones')
Predicted Class: Computers_and_Electronics probabilities_scores: 0.9997648596763611
Out[ ]:
{'P(Hobbies_and_Leisure)': 0.0, 'P(News)': 0.0, 'P(Science)': 0.0, 'P(Autos_and_Vehicles)': 0.0, 'P(Health)': 0.0, 'P(Pets_and_Animals)': 0.0, 'P(Adult)': 0.0, 'P(Computers_and_Electronics)': 1.0, 'P(Online Communities)': 0.0, 'P(Beauty_and_Fitness)': 0.0, 'P(People_and_Society)': 0.0, 'P(Business_and_Industrial)': 0.0, 'P(Reference)': 0.0, 'P(Shopping)': 0.0, 'P(Travel_and_Transportation)': 0.0, 'P(Food_and_Drink)': 0.0, 'P(Law_and_Government)': 0.0, 'P(Books_and_Literature)': 0.0, 'P(Finance)': 0.0, 'P(Games)': 0.0, 'P(Home_and_Garden)': 0.0, 'P(Jobs_and_Education)': 0.0, 'P(Arts_and_Entertainment)': 0.0, 'P(Sensitive Subjects)': 0.0, 'P(Real Estate)': 0.0, 'P(Internet_and_Telecom)': 0.0, 'P(Sports)': 0.0}
In [ ]:
# Single compound word.
predict("catfood")
Predicted Class: Food_and_Drink probabilities_scores: 0.9993139505386353
Out[ ]:
{'P(Hobbies_and_Leisure)': 0.0, 'P(News)': 0.0, 'P(Science)': 0.0, 'P(Autos_and_Vehicles)': 0.0, 'P(Health)': 0.0, 'P(Pets_and_Animals)': 0.0, 'P(Adult)': 0.0, 'P(Computers_and_Electronics)': 0.0, 'P(Online Communities)': 0.0, 'P(Beauty_and_Fitness)': 0.0, 'P(People_and_Society)': 0.0, 'P(Business_and_Industrial)': 0.0, 'P(Reference)': 0.0, 'P(Shopping)': 0.0, 'P(Travel_and_Transportation)': 0.0, 'P(Food_and_Drink)': 0.999, 'P(Law_and_Government)': 0.0, 'P(Books_and_Literature)': 0.0, 'P(Finance)': 0.0, 'P(Games)': 0.0, 'P(Home_and_Garden)': 0.0, 'P(Jobs_and_Education)': 0.0, 'P(Arts_and_Entertainment)': 0.0, 'P(Sensitive Subjects)': 0.0, 'P(Real Estate)': 0.0, 'P(Internet_and_Telecom)': 0.0, 'P(Sports)': 0.0}
In [ ]:
# Paraphrase of the previous query as a phrase.
predict('food for cats')
Predicted Class: Food_and_Drink probabilities_scores: 0.9997541308403015
Out[ ]:
{'P(Hobbies_and_Leisure)': 0.0, 'P(News)': 0.0, 'P(Science)': 0.0, 'P(Autos_and_Vehicles)': 0.0, 'P(Health)': 0.0, 'P(Pets_and_Animals)': 0.0, 'P(Adult)': 0.0, 'P(Computers_and_Electronics)': 0.0, 'P(Online Communities)': 0.0, 'P(Beauty_and_Fitness)': 0.0, 'P(People_and_Society)': 0.0, 'P(Business_and_Industrial)': 0.0, 'P(Reference)': 0.0, 'P(Shopping)': 0.0, 'P(Travel_and_Transportation)': 0.0, 'P(Food_and_Drink)': 1.0, 'P(Law_and_Government)': 0.0, 'P(Books_and_Literature)': 0.0, 'P(Finance)': 0.0, 'P(Games)': 0.0, 'P(Home_and_Garden)': 0.0, 'P(Jobs_and_Education)': 0.0, 'P(Arts_and_Entertainment)': 0.0, 'P(Sensitive Subjects)': 0.0, 'P(Real Estate)': 0.0, 'P(Internet_and_Telecom)': 0.0, 'P(Sports)': 0.0}
In [ ]:
# Another paraphrase of the cat-food query.
predict("cat edible foods")
Predicted Class: Food_and_Drink probabilities_scores: 0.9963496923446655
Out[ ]:
{'P(Hobbies_and_Leisure)': 0.0, 'P(News)': 0.0, 'P(Science)': 0.0, 'P(Autos_and_Vehicles)': 0.0, 'P(Health)': 0.0, 'P(Pets_and_Animals)': 0.002, 'P(Adult)': 0.0, 'P(Computers_and_Electronics)': 0.0, 'P(Online Communities)': 0.0, 'P(Beauty_and_Fitness)': 0.0, 'P(People_and_Society)': 0.0, 'P(Business_and_Industrial)': 0.0, 'P(Reference)': 0.0, 'P(Shopping)': 0.0, 'P(Travel_and_Transportation)': 0.0, 'P(Food_and_Drink)': 0.996, 'P(Law_and_Government)': 0.0, 'P(Books_and_Literature)': 0.0, 'P(Finance)': 0.0, 'P(Games)': 0.0, 'P(Home_and_Garden)': 0.0, 'P(Jobs_and_Education)': 0.0, 'P(Arts_and_Entertainment)': 0.0, 'P(Sensitive Subjects)': 0.0, 'P(Real Estate)': 0.0, 'P(Internet_and_Telecom)': 0.0, 'P(Sports)': 0.0}
In [ ]:
# Synonym swap ("cat" -> "feline") on the headphones query.
predict("feline ear shaped headphones")
Predicted Class: Computers_and_Electronics probabilities_scores: 0.999832034111023
Out[ ]:
{'P(Hobbies_and_Leisure)': 0.0, 'P(News)': 0.0, 'P(Science)': 0.0, 'P(Autos_and_Vehicles)': 0.0, 'P(Health)': 0.0, 'P(Pets_and_Animals)': 0.0, 'P(Adult)': 0.0, 'P(Computers_and_Electronics)': 1.0, 'P(Online Communities)': 0.0, 'P(Beauty_and_Fitness)': 0.0, 'P(People_and_Society)': 0.0, 'P(Business_and_Industrial)': 0.0, 'P(Reference)': 0.0, 'P(Shopping)': 0.0, 'P(Travel_and_Transportation)': 0.0, 'P(Food_and_Drink)': 0.0, 'P(Law_and_Government)': 0.0, 'P(Books_and_Literature)': 0.0, 'P(Finance)': 0.0, 'P(Games)': 0.0, 'P(Home_and_Garden)': 0.0, 'P(Jobs_and_Education)': 0.0, 'P(Arts_and_Entertainment)': 0.0, 'P(Sensitive Subjects)': 0.0, 'P(Real Estate)': 0.0, 'P(Internet_and_Telecom)': 0.0, 'P(Sports)': 0.0}
In [ ]:
# Ambiguous query (fruit vs. brand). In the recorded run, probability is split
# between Food_and_Drink and Computers_and_Electronics.
predict('apple ')
Predicted Class: Food_and_Drink probabilities_scores: 0.5473537445068359
Out[ ]:
{'P(Hobbies_and_Leisure)': 0.0, 'P(News)': 0.0, 'P(Science)': 0.0, 'P(Autos_and_Vehicles)': 0.0, 'P(Health)': 0.0, 'P(Pets_and_Animals)': 0.0, 'P(Adult)': 0.0, 'P(Computers_and_Electronics)': 0.448, 'P(Online Communities)': 0.0, 'P(Beauty_and_Fitness)': 0.0, 'P(People_and_Society)': 0.0, 'P(Business_and_Industrial)': 0.0, 'P(Reference)': 0.0, 'P(Shopping)': 0.001, 'P(Travel_and_Transportation)': 0.0, 'P(Food_and_Drink)': 0.547, 'P(Law_and_Government)': 0.0, 'P(Books_and_Literature)': 0.0, 'P(Finance)': 0.0, 'P(Games)': 0.002, 'P(Home_and_Garden)': 0.0, 'P(Jobs_and_Education)': 0.0, 'P(Arts_and_Entertainment)': 0.0, 'P(Sensitive Subjects)': 0.0, 'P(Real Estate)': 0.0, 'P(Internet_and_Telecom)': 0.0, 'P(Sports)': 0.0}
In [ ]:
# Disambiguated brand query.
predict("apple iphone")
Predicted Class: Computers_and_Electronics probabilities_scores: 0.9997270703315735
Out[ ]:
{'P(Hobbies_and_Leisure)': 0.0, 'P(News)': 0.0, 'P(Science)': 0.0, 'P(Autos_and_Vehicles)': 0.0, 'P(Health)': 0.0, 'P(Pets_and_Animals)': 0.0, 'P(Adult)': 0.0, 'P(Computers_and_Electronics)': 1.0, 'P(Online Communities)': 0.0, 'P(Beauty_and_Fitness)': 0.0, 'P(People_and_Society)': 0.0, 'P(Business_and_Industrial)': 0.0, 'P(Reference)': 0.0, 'P(Shopping)': 0.0, 'P(Travel_and_Transportation)': 0.0, 'P(Food_and_Drink)': 0.0, 'P(Law_and_Government)': 0.0, 'P(Books_and_Literature)': 0.0, 'P(Finance)': 0.0, 'P(Games)': 0.0, 'P(Home_and_Garden)': 0.0, 'P(Jobs_and_Education)': 0.0, 'P(Arts_and_Entertainment)': 0.0, 'P(Sensitive Subjects)': 0.0, 'P(Real Estate)': 0.0, 'P(Internet_and_Telecom)': 0.0, 'P(Sports)': 0.0}
In [ ]:
# Brand + product-line name with no generic category word.
predict("razer kraken")
Predicted Class: Computers_and_Electronics probabilities_scores: 0.9997072815895081
Out[ ]:
{'P(Hobbies_and_Leisure)': 0.0, 'P(News)': 0.0, 'P(Science)': 0.0, 'P(Autos_and_Vehicles)': 0.0, 'P(Health)': 0.0, 'P(Pets_and_Animals)': 0.0, 'P(Adult)': 0.0, 'P(Computers_and_Electronics)': 1.0, 'P(Online Communities)': 0.0, 'P(Beauty_and_Fitness)': 0.0, 'P(People_and_Society)': 0.0, 'P(Business_and_Industrial)': 0.0, 'P(Reference)': 0.0, 'P(Shopping)': 0.0, 'P(Travel_and_Transportation)': 0.0, 'P(Food_and_Drink)': 0.0, 'P(Law_and_Government)': 0.0, 'P(Books_and_Literature)': 0.0, 'P(Finance)': 0.0, 'P(Games)': 0.0, 'P(Home_and_Garden)': 0.0, 'P(Jobs_and_Education)': 0.0, 'P(Arts_and_Entertainment)': 0.0, 'P(Sensitive Subjects)': 0.0, 'P(Real Estate)': 0.0, 'P(Internet_and_Telecom)': 0.0, 'P(Sports)': 0.0}
In [ ]:
# Single brand name.
predict('facebook')
Predicted Class: Online Communities probabilities_scores: 0.997126042842865
Out[ ]:
{'P(Hobbies_and_Leisure)': 0.0, 'P(News)': 0.0, 'P(Science)': 0.0, 'P(Autos_and_Vehicles)': 0.0, 'P(Health)': 0.0, 'P(Pets_and_Animals)': 0.0, 'P(Adult)': 0.0, 'P(Computers_and_Electronics)': 0.001, 'P(Online Communities)': 0.997, 'P(Beauty_and_Fitness)': 0.0, 'P(People_and_Society)': 0.0, 'P(Business_and_Industrial)': 0.0, 'P(Reference)': 0.0, 'P(Shopping)': 0.0, 'P(Travel_and_Transportation)': 0.0, 'P(Food_and_Drink)': 0.0, 'P(Law_and_Government)': 0.0, 'P(Books_and_Literature)': 0.0, 'P(Finance)': 0.0, 'P(Games)': 0.0, 'P(Home_and_Garden)': 0.001, 'P(Jobs_and_Education)': 0.0, 'P(Arts_and_Entertainment)': 0.0, 'P(Sensitive Subjects)': 0.0, 'P(Real Estate)': 0.0, 'P(Internet_and_Telecom)': 0.0, 'P(Sports)': 0.0}
In [ ]:
# Repeat of the earlier "apple iphone" query. The recorded output is identical,
# so inference is deterministic for this input.
predict("apple iphone")
Predicted Class: Computers_and_Electronics probabilities_scores: 0.9997270703315735
Out[ ]:
{'P(Hobbies_and_Leisure)': 0.0, 'P(News)': 0.0, 'P(Science)': 0.0, 'P(Autos_and_Vehicles)': 0.0, 'P(Health)': 0.0, 'P(Pets_and_Animals)': 0.0, 'P(Adult)': 0.0, 'P(Computers_and_Electronics)': 1.0, 'P(Online Communities)': 0.0, 'P(Beauty_and_Fitness)': 0.0, 'P(People_and_Society)': 0.0, 'P(Business_and_Industrial)': 0.0, 'P(Reference)': 0.0, 'P(Shopping)': 0.0, 'P(Travel_and_Transportation)': 0.0, 'P(Food_and_Drink)': 0.0, 'P(Law_and_Government)': 0.0, 'P(Books_and_Literature)': 0.0, 'P(Finance)': 0.0, 'P(Games)': 0.0, 'P(Home_and_Garden)': 0.0, 'P(Jobs_and_Education)': 0.0, 'P(Arts_and_Entertainment)': 0.0, 'P(Sensitive Subjects)': 0.0, 'P(Real Estate)': 0.0, 'P(Internet_and_Telecom)': 0.0, 'P(Sports)': 0.0}
In [ ]: