{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "a3278dc9-0d83-4a37-aece-e46ac416988f",
"metadata": {},
"outputs": [],
"source": [
"#| default_exp app"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d6810835-d62a-4f94-a52e-0e0cd163fb98",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"from fastai.vision.all import *\n",
"import gradio as gr\n",
"\n",
"# NOTE(review): `interpretation` and `enable_queue` are not referenced by any other cell\n",
"# in this notebook (possibly leftovers from a gr.Interface setup) -- confirm before removing.\n",
"interpretation, enable_queue = 'default', True\n",
"\n",
"# Human-readable name and one-line summary of the app.\n",
"title = \"FastAI - Big Cats Classifier\"\n",
"description = \"Classify big cats using all Resnet models available pre-trained in FastAI\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6092ad61-d5cd-40f7-b2d2-20a77b0c8b0f",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"# Dropdown label mapped to the path of the fastai Learner file loaded for that model.\n",
"learners = dict([\n",
"    (\"resnet-18\", 'models/resnet18-model.pkl'),\n",
"    (\"resnet-34\", 'models/resnet34-model.pkl'),\n",
"    (\"resnet-50\", 'models/resnet50-model.pkl'),\n",
"    (\"resnet-101\", 'models/resnet101-model.pkl'),\n",
"    (\"resnet-152\", 'models/resnet152-model.pkl'),\n",
"])\n",
"models = [*learners]  # dropdown choices, in the insertion order above\n",
"\n",
"# Model served until the user selects another one.\n",
"active_name = \"resnet-18\"\n",
"active_model = learners[active_name]\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "632cbc1b-73b5-4992-8956-d4ae40f6b80b",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
" \n",
"# Loaded Learners keyed by .pkl path, so each model file is unpickled only once.\n",
"# NOTE(review): every model the user has selected stays in memory.\n",
"_learner_cache = {}\n",
"\n",
"def _get_learner(path):\n",
"    \"\"\"Return the Learner stored at `path`, loading it on first use.\"\"\"\n",
"    # load_learner unpickles the file: only point it at trusted, locally shipped models.\n",
"    if path not in _learner_cache:\n",
"        _learner_cache[path] = load_learner(path)\n",
"    return _learner_cache[path]\n",
"\n",
"def classify_image(img):\n",
"    \"\"\"Classify `img` with the currently active model.\n",
"\n",
"    Returns a dict mapping each class in the model's vocab to its probability (float),\n",
"    or None when no image was supplied (e.g. Classify clicked on an empty input).\n",
"    \"\"\"\n",
"    if img is None:\n",
"        return None\n",
"    learn = _get_learner(active_model)\n",
"    _, _, probs = learn.predict(img)\n",
"    return dict(zip(learn.dls.vocab, map(float, probs)))\n",
"\n",
"def select_model(model_name):\n",
"    \"\"\"Make `model_name` the active model and return its upper-cased name for display.\n",
"\n",
"    Names not in `models` (including None) fall back to \"resnet-18\".\n",
"    \"\"\"\n",
"    # Without `global`, these assignments would only bind locals, and the selection\n",
"    # would never reach classify_image / update_matrix / update_losses.\n",
"    global active_name, active_model\n",
"    if model_name not in models:\n",
"        model_name = \"resnet-18\"\n",
"    active_name = model_name\n",
"    active_model = learners[active_name]\n",
"    return model_name.upper()\n",
"\n",
"def _artifact_path(suffix, model_name=None):\n",
"    \"\"\"Return 'models/{name}-{suffix}.png', where name is `model_name` (default: the\n",
"    active model) with its first '-' removed, e.g. 'resnet-50' becomes 'resnet50'.\"\"\"\n",
"    name = active_name if model_name is None else model_name\n",
"    return \"models/\" + name.replace('-', '', 1) + \"-\" + suffix + \".png\"\n",
"\n",
"def update_matrix(model_name=None):\n",
"    \"\"\"Path of the confusion-matrix image for `model_name` (default: the active model).\"\"\"\n",
"    return _artifact_path(\"confusion-matrix\", model_name)\n",
"\n",
"def update_losses(model_name=None):\n",
"    \"\"\"Path of the top-losses image for `model_name` (default: the active model).\"\"\"\n",
"    return _artifact_path(\"top-losses\", model_name)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9b5f1cc6-5173-475a-9365-0cab11db2d03",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'african leopard': 0.00045245178625918925, 'cheetah': 0.9994743466377258, 'clouded leopard': 3.061432778395101e-07, 'cougar': 8.726581654627807e-06, 'jaguar': 4.878858817392029e-05, 'lion': 1.4129628652881365e-05, 'snow leopard': 1.2738197483486147e-06, 'tiger': 1.1983513736879559e-08}\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'african leopard': 8.918660228118824e-07, 'cheetah': 3.004239079729132e-09, 'clouded leopard': 1.0275688282490592e-06, 'cougar': 1.8215871477877954e-08, 'jaguar': 0.9999979734420776, 'lion': 7.327587425720594e-10, 'snow leopard': 1.3988608316140017e-07, 'tiger': 4.418302523845341e-08}\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'african leopard': 1.279351291572084e-08, 'cheetah': 3.040315732505405e-08, 'clouded leopard': 4.387358387702989e-08, 'cougar': 1.2642824458453106e-06, 'jaguar': 3.0061545430726255e-07, 'lion': 2.5054502472698914e-08, 'snow leopard': 4.821659516096588e-08, 'tiger': 0.9999983310699463}\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'african leopard': 2.2317146886052797e-06, 'cheetah': 6.153353297122521e-06, 'clouded leopard': 3.5761433991865488e-06, 'cougar': 0.9940788745880127, 'jaguar': 7.271950153153739e-08, 'lion': 0.005906379781663418, 'snow leopard': 1.0360908220263809e-07, 'tiger': 2.569006483099656e-06}\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'african leopard': 7.383512135028525e-10, 'cheetah': 1.6924343526625307e-06, 'clouded leopard': 3.8847122740826023e-10, 'cougar': 1.4941306858418102e-08, 'jaguar': 3.277633942033731e-09, 'lion': 0.9999983310699463, 'snow leopard': 4.2623696572263725e-08, 'tiger': 5.7686470711360016e-08}\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'african leopard': 0.11080536246299744, 'cheetah': 0.00025237080990336835, 'clouded leopard': 0.0003655211767181754, 'cougar': 1.1126862773380708e-05, 'jaguar': 0.8603838086128235, 'lion': 8.311066630994901e-05, 'snow leopard': 0.028046416118741035, 'tiger': 5.234780110185966e-05}\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'african leopard': 5.991949336703328e-08, 'cheetah': 1.2888077272066312e-08, 'clouded leopard': 0.9999984502792358, 'cougar': 7.355600928349304e-07, 'jaguar': 5.131531679580803e-07, 'lion': 5.543293823961903e-09, 'snow leopard': 3.404375448212704e-08, 'tiger': 2.0324510785485472e-07}\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'african leopard': 2.2017589799361303e-05, 'cheetah': 9.802879503695294e-05, 'clouded leopard': 0.0109814228489995, 'cougar': 1.8166520021623e-06, 'jaguar': 5.0095695769414306e-06, 'lion': 5.28784084963263e-06, 'snow leopard': 0.988881528377533, 'tiger': 4.889693173026899e-06}\n"
]
}
],
"source": [
"# Smoke-test the classifier on the bundled example images (this cell is not exported).\n",
"example_images = [ 'cheetah.jpg', 'jaguar.jpg', 'tiger.jpg', 'cougar.jpg', 'lion.jpg', 'african leopard.jpg', 'clouded leopard.jpg', 'snow leopard.jpg' ]\n",
"\n",
"for path in example_images:\n",
"    probs = classify_image(PILImage.create(path))\n",
"    best = max(probs, key=probs.get)\n",
"    # Label each line with its image so every prediction can be matched to its input.\n",
"    print(f\"{path}: {best} ({probs[best]:.4f})\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a48e7483-c04b-4048-a1ae-34a8c7986a57",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on local URL: http://127.0.0.1:7860\n",
"\n",
"To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#| export\n",
"example_images = [ 'cheetah.jpg', 'jaguar.jpg', 'tiger.jpg', 'cougar.jpg', 'lion.jpg', 'african leopard.jpg', 'clouded leopard.jpg', 'snow leopard.jpg', 'hidden.png', 'hidden2.png' ]\n",
"\n",
"def _on_model_change(model_name):\n",
"    \"\"\"Activate `model_name` and return (label, confusion-matrix path, top-losses path).\n",
"\n",
"    One listener runs the three steps in order, so the image paths are computed after\n",
"    select_model has switched the active model; separate listeners on the same event\n",
"    may not run in that order.\n",
"    \"\"\"\n",
"    label = select_model(model_name)\n",
"    return label, update_matrix(), update_losses()\n",
"\n",
"demo = gr.Blocks()\n",
"with demo:\n",
"    gr.Markdown(f\"# {title}\")\n",
"    gr.Markdown(description)\n",
"    with gr.Column(variant=\"panel\"):\n",
"        image = gr.Image(label=\"Pick an image\")\n",
"        model = gr.Dropdown(label=\"Select a model\", choices=models, value=active_name)\n",
"        with gr.Row(equal_height=True):\n",
"            btnClassify = gr.Button(\"Classify\")\n",
"            btnClear = gr.Button(\"Clear\")\n",
"    with gr.Column(variant=\"panel\"):\n",
"        selected = gr.Textbox(label=\"Active Model\")\n",
"        with gr.Row(equal_height=True):\n",
"            matrix = gr.Image(type='filepath', label=\"Confusion Matrix\")\n",
"            losses = gr.Image(type='filepath', label=\"Top Losses\")\n",
"        result = gr.Label(label=\"Result\")\n",
"\n",
"    img_gallery = gr.Examples(examples=example_images, inputs=image)\n",
"\n",
"    # Register all event listeners\n",
"    model_outputs = [selected, matrix, losses]\n",
"    model.change(fn=_on_model_change, inputs=model, outputs=model_outputs)\n",
"    # Populate the model label and images on page load instead of leaving them blank.\n",
"    demo.load(fn=_on_model_change, inputs=model, outputs=model_outputs)\n",
"    btnClassify.click(fn=classify_image, inputs=image, outputs=result)\n",
"    # Reset both the picked image and the last prediction; the components to clear\n",
"    # must be listed in `outputs` for the returned values to be applied.\n",
"    btnClear.click(fn=lambda: (None, None), inputs=None, outputs=[image, result])\n",
"\n",
"# Guarded so importing the exported app.py does not start a server; in a notebook kernel\n",
"# and under `python app.py`, __name__ is \"__main__\", so the app still launches there.\n",
"if __name__ == \"__main__\":\n",
"    demo.launch(debug=True, inline=False)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cab071f9-7c3b-4b35-a0d1-3687731ffce5",
"metadata": {},
"outputs": [],
"source": [
"from nbdev.export import nb_export\n",
"\n",
"# Write this notebook's `#| export` cells to a module under ./ (named by `#| default_exp app`).\n",
"nb_export('app.ipynb', './')\n",
"print('Export successful')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95f0e7ec-edd2-4afa-a68f-7da8b85b1f61",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}