{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "from PIL import Image, ImageFilter\n",
    "import numpy as np\n",
    "from transformers import pipeline\n",
    "import gradio as gr\n",
    "import os\n",
    "\n",
    "model = pipeline(\"image-segmentation\", model=\"facebook/detr-resnet-50-panoptic\")\n",
    "\n",
    "# Width (pixels) every input image is rescaled to before segmentation and blurring.\n",
    "TARGET_WIDTH = 1000\n",
    "\n",
    "# Output of the most recent image_objects() call; blurr_object() indexes into it.\n",
    "# NOTE(review): module-level state is shared by every Gradio session, so concurrent\n",
    "# users can overwrite each other's detections -- consider gr.State if that matters.\n",
    "pred = []\n",
    "\n",
    "def img_resize(image, width=TARGET_WIDTH):\n",
    "    \"\"\"Resize a PIL image to `width` pixels wide, preserving its aspect ratio.\n",
    "\n",
    "    Args:\n",
    "        image: PIL.Image to resize.\n",
    "        width: target width in pixels (defaults to TARGET_WIDTH).\n",
    "\n",
    "    Returns:\n",
    "        A new PIL.Image of size (width, height); height is scaled by the same\n",
    "        factor and truncated to an int, but never drops below 1 pixel.\n",
    "    \"\"\"\n",
    "    scale = width / float(image.size[0])\n",
    "    height = max(1, int(image.size[1] * scale))\n",
    "    return image.resize((width, height))\n",
    "\n",
    "def image_objects(image):\n",
    "    \"\"\"Segment `image` and offer its detected objects as dropdown choices.\n",
    "\n",
    "    The raw pipeline output is stored in the global `pred` so that\n",
    "    blurr_object() can later look up the selected object's mask.\n",
    "\n",
    "    Returns:\n",
    "        A gr.Dropdown update whose choices are '<index>_<label>' strings,\n",
    "        one per detected segment, in pipeline output order.\n",
    "    \"\"\"\n",
    "    global pred\n",
    "    image = img_resize(image)\n",
    "    pred = model(image)\n",
    "    pred_object_list = [f\"{i}_{x['label']}\" for i, x in enumerate(pred)]\n",
    "    return gr.Dropdown.update(choices=pred_object_list, interactive=True)\n",
    "\n",
    "def blurr_object(image, object, blur_strength):\n",
    "    \"\"\"Blur everything in `image` except the selected object.\n",
    "\n",
    "    Args:\n",
    "        image: PIL.Image -- the same image previously passed to image_objects().\n",
    "        object: dropdown choice of the form '<index>_<label>' from image_objects().\n",
    "        blur_strength: Gaussian blur radius applied to the background.\n",
    "\n",
    "    Returns:\n",
    "        An RGB PIL.Image whose pixels inside the object's mask come from the\n",
    "        resized input and whose remaining pixels come from its blurred copy.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no object is selected, no matching detection exists, or\n",
    "            the stored mask does not match the resized image.\n",
    "    \"\"\"\n",
    "    if not object:\n",
    "        raise ValueError('Select an object to keep in focus first.')\n",
    "    object_number = int(object.split('_')[0])\n",
    "    if not 0 <= object_number < len(pred):\n",
    "        raise ValueError('No detection for the selected object; run object detection on this image first.')\n",
    "\n",
    "    # Force 3 channels: grayscale (2-D) or RGBA arrays would break the per-pixel select below.\n",
    "    image = img_resize(image).convert('RGB')\n",
    "    image_array = np.asarray(image)\n",
    "\n",
    "    # Assumes 0..255 mask values with 255 marking the object (TODO confirm); threshold at mid-range.\n",
    "    mask = np.asarray(pred[object_number]['mask']) > 127\n",
    "    if mask.shape != image_array.shape[:2]:\n",
    "        raise ValueError('The stored mask does not match this image; run object detection on it again.')\n",
    "\n",
    "    blurred_array = np.asarray(image.filter(ImageFilter.GaussianBlur(radius=blur_strength)))\n",
    "\n",
    "    # Keep original pixels inside the mask and blurred pixels everywhere else.\n",
    "    output_array = np.where(mask[:, :, np.newaxis], image_array, blurred_array)\n",
    "    return Image.fromarray(output_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL: http://127.0.0.1:7862/\n",
      "\n",
"To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "