{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os; os.chdir('..')\n", "from transformers import pipeline\n", "\n", "classifier = pipeline(\"text-classification\", model=\"finetuned_entity_categorical_classification/checkpoint-23355\", device=\"cuda\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "classifier(\n", " 'cat ear shaped headphones'\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "classifier(\n", " 'catfood'\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "classifier(\n", " 'headphones'\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference Without Pipes" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/home/ubuntu/SentenceStructureComparision'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os; os.chdir('..')\n", "%pwd\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import json\n", "label2id= json.load(\n", " open('data/categories_refined.json', 'r')\n", ")\n", "id2label= {}\n", "for key in label2id.keys():\n", " id2label[label2id[key]] = key\n", " " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } ], "source": [ "from transformers import AutoTokenizer\n", "from transformers import AutoModelForSequenceClassification\n", "import torch\n", "from torch.nn import functional as F\n", "import numpy as np\n", "\n", "\n", "\n", "model_name= \"finetuned_entity_categorical_classification/checkpoint-3338\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "\n", "model = AutoModelForSequenceClassification.from_pretrained(model_name)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# probabilities = 1 / (1 + np.exp(-logit_score))\n", "def logit2prob(logit):\n", " # odds =np.exp(logit)\n", " # prob = odds / (1 + odds)\n", " prob= 1/(1+ np.exp(-logit))\n", " return np.round(prob, 3)\n", "\n", "\n", "\n", "\n", "def predict(sentence: str):\n", " '''\n", " Returns (probability_human, probability_AI, label)\n", " '''\n", " inputs = tokenizer(sentence, return_tensors=\"pt\")\n", " with torch.no_grad():\n", " logits = model(**inputs).logits\n", " \n", " # print(\"logits: \", logits)\n", " predicted_class_id = logits.argmax().item()\n", " \n", " # get probabilities using softmax from logit score and convert it to numpy array\n", " probabilities_scores = F.softmax(logits, dim = -1).numpy()[0]\n", " individual_probabilities_scores = logit2prob(logits.numpy()[0])\n", " \n", " \n", " d= {}\n", " d_ind= {}\n", " # d_ind= {}\n", " for i in range(27):\n", " # print(f\"P({id2label[i]}): {probabilities_scores[i]}\")\n", " # d[f'P({id2label[i]})']= format(probabilities_scores[i], '.2f')\n", " d[f'P({id2label[i]})']= round(probabilities_scores[i], 3)\n", " \n", " \n", " for i in range(27):\n", " # print(f\"P({id2label[i]}): {probabilities_scores[i]}\")\n", " # d[f'P({id2label[i]})']= format(probabilities_scores[i], '.2f')\n", " d_ind[f'P({id2label[i]})']= (individual_probabilities_scores[i])\n", " \n", " \n", "\n", " print(\"Predicted Class: \", model.config.id2label[predicted_class_id], f\"\\nprobabilities_scores: {individual_probabilities_scores[predicted_class_id]}\\n\")\n", " return d_ind\n", " \n", " \n", " " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Computers_and_Electronics \n", "probabilities_scores: 1.0\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.071,\n", " 'P(News)': 0.004,\n", " 'P(Science)': 0.02,\n", " 'P(Autos_and_Vehicles)': 0.023,\n", " 'P(Health)': 0.012,\n", " 'P(Pets_and_Animals)': 0.005,\n", " 'P(Adult)': 0.062,\n", " 'P(Computers_and_Electronics)': 1.0,\n", " 'P(Online Communities)': 0.076,\n", " 'P(Beauty_and_Fitness)': 0.013,\n", " 'P(People_and_Society)': 0.0,\n", " 'P(Business_and_Industrial)': 0.003,\n", " 'P(Reference)': 0.043,\n", " 'P(Shopping)': 0.233,\n", " 'P(Travel_and_Transportation)': 0.003,\n", " 'P(Food_and_Drink)': 0.013,\n", " 'P(Law_and_Government)': 0.06,\n", " 'P(Books_and_Literature)': 0.004,\n", " 'P(Finance)': 0.035,\n", " 'P(Games)': 0.044,\n", " 'P(Home_and_Garden)': 0.013,\n", " 'P(Jobs_and_Education)': 0.003,\n", " 'P(Arts_and_Entertainment)': 0.012,\n", " 'P(Sensitive Subjects)': 0.003,\n", " 'P(Real Estate)': 0.027,\n", " 'P(Internet_and_Telecom)': 0.045,\n", " 'P(Sports)': 0.016}" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict(\"cat ear headphones\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Food_and_Drink \n", "probabilities_scores: 1.0\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.025,\n", " 'P(News)': 0.137,\n", " 'P(Science)': 0.023,\n", " 'P(Autos_and_Vehicles)': 0.01,\n", " 'P(Health)': 0.272,\n", " 'P(Pets_and_Animals)': 0.148,\n", " 'P(Adult)': 0.005,\n", " 'P(Computers_and_Electronics)': 0.089,\n", " 'P(Online Communities)': 0.072,\n", " 'P(Beauty_and_Fitness)': 0.105,\n", " 'P(People_and_Society)': 0.005,\n", " 'P(Business_and_Industrial)': 0.011,\n", " 'P(Reference)': 0.011,\n", " 'P(Shopping)': 0.037,\n", " 'P(Travel_and_Transportation)': 0.016,\n", " 'P(Food_and_Drink)': 1.0,\n", " 'P(Law_and_Government)': 0.006,\n", " 'P(Books_and_Literature)': 0.024,\n", " 'P(Finance)': 0.013,\n", " 'P(Games)': 0.044,\n", " 'P(Home_and_Garden)': 0.012,\n", " 'P(Jobs_and_Education)': 0.011,\n", " 'P(Arts_and_Entertainment)': 0.161,\n", " 'P(Sensitive Subjects)': 0.032,\n", " 'P(Real Estate)': 0.006,\n", " 'P(Internet_and_Telecom)': 0.009,\n", " 'P(Sports)': 0.02}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict('catfood')" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Food_and_Drink \n", "probabilities_scores: 1.0\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.048,\n", " 'P(News)': 0.202,\n", " 'P(Science)': 0.025,\n", " 'P(Autos_and_Vehicles)': 0.095,\n", " 'P(Health)': 0.094,\n", " 'P(Pets_and_Animals)': 0.006,\n", " 'P(Adult)': 0.016,\n", " 'P(Computers_and_Electronics)': 0.129,\n", " 'P(Online Communities)': 0.078,\n", " 'P(Beauty_and_Fitness)': 0.122,\n", " 'P(People_and_Society)': 0.008,\n", " 'P(Business_and_Industrial)': 0.022,\n", " 'P(Reference)': 0.014,\n", " 'P(Shopping)': 0.046,\n", " 'P(Travel_and_Transportation)': 0.024,\n", " 'P(Food_and_Drink)': 1.0,\n", " 'P(Law_and_Government)': 0.013,\n", " 'P(Books_and_Literature)': 0.038,\n", " 'P(Finance)': 0.026,\n", " 'P(Games)': 0.091,\n", " 'P(Home_and_Garden)': 0.025,\n", " 'P(Jobs_and_Education)': 0.033,\n", " 'P(Arts_and_Entertainment)': 0.233,\n", " 'P(Sensitive Subjects)': 0.022,\n", " 'P(Real Estate)': 0.005,\n", " 'P(Internet_and_Telecom)': 0.003,\n", " 'P(Sports)': 0.039}" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict(\"food for cats\")" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Food_and_Drink \n", "probabilities_scores: 0.9980000257492065\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.113,\n", " 'P(News)': 0.037,\n", " 'P(Science)': 0.024,\n", " 'P(Autos_and_Vehicles)': 0.05,\n", " 'P(Health)': 0.039,\n", " 'P(Pets_and_Animals)': 0.444,\n", " 'P(Adult)': 0.003,\n", " 'P(Computers_and_Electronics)': 0.022,\n", " 'P(Online Communities)': 0.12,\n", " 'P(Beauty_and_Fitness)': 0.114,\n", " 'P(People_and_Society)': 0.001,\n", " 'P(Business_and_Industrial)': 0.008,\n", " 'P(Reference)': 0.003,\n", " 'P(Shopping)': 0.014,\n", " 'P(Travel_and_Transportation)': 0.009,\n", " 'P(Food_and_Drink)': 0.998,\n", " 'P(Law_and_Government)': 0.005,\n", " 'P(Books_and_Literature)': 0.006,\n", " 'P(Finance)': 0.009,\n", " 'P(Games)': 0.052,\n", " 'P(Home_and_Garden)': 0.006,\n", " 'P(Jobs_and_Education)': 0.005,\n", " 'P(Arts_and_Entertainment)': 0.199,\n", " 'P(Sensitive Subjects)': 0.033,\n", " 'P(Real Estate)': 0.003,\n", " 'P(Internet_and_Telecom)': 0.001,\n", " 'P(Sports)': 0.123}" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict('cat edible foods')" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Computers_and_Electronics \n", "probabilities_scores: 1.0\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.134,\n", " 'P(News)': 0.002,\n", " 'P(Science)': 0.027,\n", " 'P(Autos_and_Vehicles)': 0.061,\n", " 'P(Health)': 0.008,\n", " 'P(Pets_and_Animals)': 0.006,\n", " 'P(Adult)': 0.069,\n", " 'P(Computers_and_Electronics)': 1.0,\n", " 'P(Online Communities)': 0.16,\n", " 'P(Beauty_and_Fitness)': 0.015,\n", " 'P(People_and_Society)': 0.0,\n", " 'P(Business_and_Industrial)': 0.003,\n", " 'P(Reference)': 0.019,\n", " 'P(Shopping)': 0.147,\n", " 'P(Travel_and_Transportation)': 0.005,\n", " 'P(Food_and_Drink)': 0.023,\n", " 'P(Law_and_Government)': 0.115,\n", " 'P(Books_and_Literature)': 0.007,\n", " 'P(Finance)': 0.037,\n", " 'P(Games)': 0.042,\n", " 'P(Home_and_Garden)': 0.032,\n", " 'P(Jobs_and_Education)': 0.003,\n", " 'P(Arts_and_Entertainment)': 0.01,\n", " 'P(Sensitive Subjects)': 0.003,\n", " 'P(Real Estate)': 0.012,\n", " 'P(Internet_and_Telecom)': 0.016,\n", " 'P(Sports)': 0.015}" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict('feline ear shaped headphones')" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Food_and_Drink \n", "probabilities_scores: 0.9909999966621399\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.02,\n", " 'P(News)': 0.017,\n", " 'P(Science)': 0.008,\n", " 'P(Autos_and_Vehicles)': 0.06,\n", " 'P(Health)': 0.032,\n", " 'P(Pets_and_Animals)': 0.004,\n", " 'P(Adult)': 0.022,\n", " 'P(Computers_and_Electronics)': 0.989,\n", " 'P(Online Communities)': 0.056,\n", " 'P(Beauty_and_Fitness)': 0.026,\n", " 'P(People_and_Society)': 0.0,\n", " 'P(Business_and_Industrial)': 0.008,\n", " 'P(Reference)': 0.052,\n", " 'P(Shopping)': 0.105,\n", " 'P(Travel_and_Transportation)': 0.012,\n", " 'P(Food_and_Drink)': 0.991,\n", " 'P(Law_and_Government)': 0.007,\n", " 'P(Books_and_Literature)': 0.009,\n", " 'P(Finance)': 0.014,\n", " 'P(Games)': 0.284,\n", " 'P(Home_and_Garden)': 0.015,\n", " 'P(Jobs_and_Education)': 0.017,\n", " 'P(Arts_and_Entertainment)': 0.031,\n", " 'P(Sensitive Subjects)': 0.014,\n", " 'P(Real Estate)': 0.003,\n", " 'P(Internet_and_Telecom)': 0.003,\n", " 'P(Sports)': 0.021}" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict(\"apple \")" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Computers_and_Electronics \n", "probabilities_scores: 1.0\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.054,\n", " 'P(News)': 0.003,\n", " 'P(Science)': 0.011,\n", " 'P(Autos_and_Vehicles)': 0.122,\n", " 'P(Health)': 0.01,\n", " 'P(Pets_and_Animals)': 0.004,\n", " 'P(Adult)': 0.054,\n", " 'P(Computers_and_Electronics)': 1.0,\n", " 'P(Online Communities)': 0.081,\n", " 'P(Beauty_and_Fitness)': 0.016,\n", " 'P(People_and_Society)': 0.0,\n", " 'P(Business_and_Industrial)': 0.005,\n", " 'P(Reference)': 0.064,\n", " 'P(Shopping)': 0.224,\n", " 'P(Travel_and_Transportation)': 0.006,\n", " 'P(Food_and_Drink)': 0.172,\n", " 'P(Law_and_Government)': 0.051,\n", " 'P(Books_and_Literature)': 0.006,\n", " 'P(Finance)': 0.025,\n", " 'P(Games)': 0.138,\n", " 'P(Home_and_Garden)': 0.03,\n", " 'P(Jobs_and_Education)': 0.006,\n", " 'P(Arts_and_Entertainment)': 0.008,\n", " 'P(Sensitive Subjects)': 0.003,\n", " 'P(Real Estate)': 0.006,\n", " 'P(Internet_and_Telecom)': 0.004,\n", " 'P(Sports)': 0.018}" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict('apple iphone')" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Computers_and_Electronics \n", "probabilities_scores: 1.0\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.077,\n", " 'P(News)': 0.005,\n", " 'P(Science)': 0.009,\n", " 'P(Autos_and_Vehicles)': 0.077,\n", " 'P(Health)': 0.015,\n", " 'P(Pets_and_Animals)': 0.003,\n", " 'P(Adult)': 0.073,\n", " 'P(Computers_and_Electronics)': 1.0,\n", " 'P(Online Communities)': 0.086,\n", " 'P(Beauty_and_Fitness)': 0.022,\n", " 'P(People_and_Society)': 0.0,\n", " 'P(Business_and_Industrial)': 0.004,\n", " 'P(Reference)': 0.021,\n", " 'P(Shopping)': 0.203,\n", " 'P(Travel_and_Transportation)': 0.003,\n", " 'P(Food_and_Drink)': 0.241,\n", " 'P(Law_and_Government)': 0.009,\n", " 'P(Books_and_Literature)': 0.003,\n", " 'P(Finance)': 0.029,\n", " 'P(Games)': 0.195,\n", " 'P(Home_and_Garden)': 0.044,\n", " 'P(Jobs_and_Education)': 0.004,\n", " 'P(Arts_and_Entertainment)': 0.013,\n", " 'P(Sensitive Subjects)': 0.003,\n", " 'P(Real Estate)': 0.012,\n", " 'P(Internet_and_Telecom)': 0.004,\n", " 'P(Sports)': 0.017}" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict(\n", " 'razer kraken'\n", ")" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Online Communities \n", "probabilities_scores: 0.9990000128746033\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.009,\n", " 'P(News)': 0.037,\n", " 'P(Science)': 0.014,\n", " 'P(Autos_and_Vehicles)': 0.004,\n", " 'P(Health)': 0.007,\n", " 'P(Pets_and_Animals)': 0.048,\n", " 'P(Adult)': 0.287,\n", " 'P(Computers_and_Electronics)': 0.536,\n", " 'P(Online Communities)': 0.999,\n", " 'P(Beauty_and_Fitness)': 0.002,\n", " 'P(People_and_Society)': 0.001,\n", " 'P(Business_and_Industrial)': 0.002,\n", " 'P(Reference)': 0.006,\n", " 'P(Shopping)': 0.038,\n", " 'P(Travel_and_Transportation)': 0.016,\n", " 'P(Food_and_Drink)': 0.012,\n", " 'P(Law_and_Government)': 0.024,\n", " 'P(Books_and_Literature)': 0.059,\n", " 'P(Finance)': 0.001,\n", " 'P(Games)': 0.025,\n", " 'P(Home_and_Garden)': 0.377,\n", " 'P(Jobs_and_Education)': 0.018,\n", " 'P(Arts_and_Entertainment)': 0.028,\n", " 'P(Sensitive Subjects)': 0.072,\n", " 'P(Real Estate)': 0.002,\n", " 'P(Internet_and_Telecom)': 0.003,\n", " 'P(Sports)': 0.006}" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict(\"facebook\")" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Computers_and_Electronics \n", "probabilities_scores: 1.0\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.054,\n", " 'P(News)': 0.003,\n", " 'P(Science)': 0.011,\n", " 'P(Autos_and_Vehicles)': 0.122,\n", " 'P(Health)': 0.01,\n", " 'P(Pets_and_Animals)': 0.004,\n", " 'P(Adult)': 0.054,\n", " 'P(Computers_and_Electronics)': 1.0,\n", " 'P(Online Communities)': 0.081,\n", " 'P(Beauty_and_Fitness)': 0.016,\n", " 'P(People_and_Society)': 0.0,\n", " 'P(Business_and_Industrial)': 0.005,\n", " 'P(Reference)': 0.064,\n", " 'P(Shopping)': 0.224,\n", " 'P(Travel_and_Transportation)': 0.006,\n", " 'P(Food_and_Drink)': 0.172,\n", " 'P(Law_and_Government)': 0.051,\n", " 'P(Books_and_Literature)': 0.006,\n", " 'P(Finance)': 0.025,\n", " 'P(Games)': 0.138,\n", " 'P(Home_and_Garden)': 0.03,\n", " 'P(Jobs_and_Education)': 0.006,\n", " 'P(Arts_and_Entertainment)': 0.008,\n", " 'P(Sensitive Subjects)': 0.003,\n", " 'P(Real Estate)': 0.006,\n", " 'P(Internet_and_Telecom)': 0.004,\n", " 'P(Sports)': 0.018}" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict('apple iphone')" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Computers_and_Electronics \n", "probabilities_scores: 1.0\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.186,\n", " 'P(News)': 0.003,\n", " 'P(Science)': 0.009,\n", " 'P(Autos_and_Vehicles)': 0.512,\n", " 'P(Health)': 0.002,\n", " 'P(Pets_and_Animals)': 0.002,\n", " 'P(Adult)': 0.039,\n", " 'P(Computers_and_Electronics)': 1.0,\n", " 'P(Online Communities)': 0.061,\n", " 'P(Beauty_and_Fitness)': 0.003,\n", " 'P(People_and_Society)': 0.0,\n", " 'P(Business_and_Industrial)': 0.001,\n", " 'P(Reference)': 0.015,\n", " 'P(Shopping)': 0.274,\n", " 'P(Travel_and_Transportation)': 0.002,\n", " 'P(Food_and_Drink)': 0.009,\n", " 'P(Law_and_Government)': 0.058,\n", " 'P(Books_and_Literature)': 0.002,\n", " 'P(Finance)': 0.033,\n", " 'P(Games)': 0.151,\n", " 'P(Home_and_Garden)': 0.027,\n", " 'P(Jobs_and_Education)': 0.002,\n", " 'P(Arts_and_Entertainment)': 0.005,\n", " 'P(Sensitive Subjects)': 0.001,\n", " 'P(Real Estate)': 0.035,\n", " 'P(Internet_and_Telecom)': 0.001,\n", " 'P(Sports)': 0.008}" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict('best vr')" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Computers_and_Electronics \n", "probabilities_scores: 1.0\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.186,\n", " 'P(News)': 0.003,\n", " 'P(Science)': 0.009,\n", " 'P(Autos_and_Vehicles)': 0.512,\n", " 'P(Health)': 0.002,\n", " 'P(Pets_and_Animals)': 0.002,\n", " 'P(Adult)': 0.039,\n", " 'P(Computers_and_Electronics)': 1.0,\n", " 'P(Online Communities)': 0.061,\n", " 'P(Beauty_and_Fitness)': 0.003,\n", " 'P(People_and_Society)': 0.0,\n", " 'P(Business_and_Industrial)': 0.001,\n", " 'P(Reference)': 0.015,\n", " 'P(Shopping)': 0.274,\n", " 'P(Travel_and_Transportation)': 0.002,\n", " 'P(Food_and_Drink)': 0.009,\n", " 'P(Law_and_Government)': 0.058,\n", " 'P(Books_and_Literature)': 0.002,\n", " 'P(Finance)': 0.033,\n", " 'P(Games)': 0.151,\n", " 'P(Home_and_Garden)': 0.027,\n", " 'P(Jobs_and_Education)': 0.002,\n", " 'P(Arts_and_Entertainment)': 0.005,\n", " 'P(Sensitive Subjects)': 0.001,\n", " 'P(Real Estate)': 0.035,\n", " 'P(Internet_and_Telecom)': 0.001,\n", " 'P(Sports)': 0.008}" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict(\"best vr\")" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Adult \n", "probabilities_scores: 0.7149999737739563\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.684,\n", " 'P(News)': 0.009,\n", " 'P(Science)': 0.001,\n", " 'P(Autos_and_Vehicles)': 0.004,\n", " 'P(Health)': 0.001,\n", " 'P(Pets_and_Animals)': 0.0,\n", " 'P(Adult)': 0.715,\n", " 'P(Computers_and_Electronics)': 0.274,\n", " 'P(Online Communities)': 0.246,\n", " 'P(Beauty_and_Fitness)': 0.003,\n", " 'P(People_and_Society)': 0.001,\n", " 'P(Business_and_Industrial)': 0.0,\n", " 'P(Reference)': 0.0,\n", " 'P(Shopping)': 0.022,\n", " 'P(Travel_and_Transportation)': 0.001,\n", " 'P(Food_and_Drink)': 0.002,\n", " 'P(Law_and_Government)': 0.021,\n", " 'P(Books_and_Literature)': 0.007,\n", " 'P(Finance)': 0.003,\n", " 'P(Games)': 0.012,\n", " 'P(Home_and_Garden)': 0.178,\n", " 'P(Jobs_and_Education)': 0.002,\n", " 'P(Arts_and_Entertainment)': 0.01,\n", " 'P(Sensitive Subjects)': 0.001,\n", " 'P(Real Estate)': 0.026,\n", " 'P(Internet_and_Telecom)': 0.0,\n", " 'P(Sports)': 0.02}" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict(\"pa best views\")" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class: Hobbies_and_Leisure \n", "probabilities_scores: 0.9950000047683716\n", "\n" ] }, { "data": { "text/plain": [ "{'P(Hobbies_and_Leisure)': 0.995,\n", " 'P(News)': 0.003,\n", " 'P(Science)': 0.008,\n", " 'P(Autos_and_Vehicles)': 0.026,\n", " 'P(Health)': 0.003,\n", " 'P(Pets_and_Animals)': 0.003,\n", " 'P(Adult)': 0.075,\n", " 'P(Computers_and_Electronics)': 0.127,\n", " 'P(Online Communities)': 0.156,\n", " 'P(Beauty_and_Fitness)': 0.026,\n", " 'P(People_and_Society)': 0.001,\n", " 'P(Business_and_Industrial)': 0.0,\n", " 'P(Reference)': 0.0,\n", " 'P(Shopping)': 0.046,\n", " 'P(Travel_and_Transportation)': 0.003,\n", " 'P(Food_and_Drink)': 0.002,\n", " 'P(Law_and_Government)': 0.041,\n", " 'P(Books_and_Literature)': 0.012,\n", " 'P(Finance)': 0.011,\n", " 'P(Games)': 0.002,\n", " 'P(Home_and_Garden)': 0.062,\n", " 'P(Jobs_and_Education)': 0.003,\n", " 'P(Arts_and_Entertainment)': 0.029,\n", " 'P(Sensitive Subjects)': 0.0,\n", " 'P(Real Estate)': 0.146,\n", " 'P(Internet_and_Telecom)': 0.0,\n", " 'P(Sports)': 0.007}" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict(\n", " \"best ac dharmashala in vrindavan\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "inputs = tokenizer(\"best cat ear headphones\", return_tensors=\"pt\")\n", "with torch.no_grad():\n", " logits = model(**inputs).logits" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-1.353771 , -5.8301578, -4.050355 , -1.9018538, -5.129807 ,\n", " -5.2707334, -2.696651 , 8.821061 , -2.0982835, -4.4173856,\n", " -9.076361 , -5.888918 , -3.7155762, -1.0305756, -5.5817475,\n", " -3.987473 , -2.4096951, -5.1136127, -3.217719 , -2.938894 ,\n", " -3.7113686, -5.8976064, -4.788314 , -6.4181705, -3.5685277,\n", " -4.5266075, -4.3206973], dtype=float32)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "l= logits.numpy()[0]\n", "l" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# logit2prob <- function(logit){\n", "# odds <- exp(logit)\n", "# prob <- odds / (1 + odds)\n", "# return(prob)\n", "# }\n", "def logit2prob(logit):\n", " odds =np.exp(logit)\n", " prob = odds / (1 + odds)\n", " return np.round(prob, 2)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.21\n", "0.0\n", "0.02\n", "0.13\n", "0.01\n", "0.01\n", "0.06\n", "1.0\n", "0.11\n", "0.01\n", "0.0\n", "0.0\n", "0.02\n", "0.26\n", "0.0\n", "0.02\n", "0.08\n", "0.01\n", "0.04\n", "0.05\n", "0.02\n", "0.0\n", "0.01\n", "0.0\n", "0.03\n", "0.01\n", "0.01\n" ] } ], "source": [ "for i in l:\n", " print(round(logit2prob(i), 2))" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.21, 0. , 0.02, 0.13, 0.01, 0.01, 0.06, 1. , 0.11, 0.01, 0. ,\n", " 0. , 0.02, 0.26, 0. , 0.02, 0.08, 0.01, 0.04, 0.05, 0.02, 0. ,\n", " 0.01, 0. , 0.03, 0.01, 0.01], dtype=float32)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logit2prob(l)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n" ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }