Annas Dev commited on
Commit
cf69c91
·
1 Parent(s): d4d8937
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ venv/
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import random
4
+ from src.model import simlarity_model as model
5
+ from src.similarity.similarity import Similarity
6
+
7
+ similarity = Similarity()
8
+ models = similarity.get_models()
9
+
10
+ def check(img_main, img_1, img_2, model_idx):
11
+ images = [
12
+ (random.choice(
13
+ [
14
+ "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
15
+ "https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
16
+ "https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
17
+ "https://images.unsplash.com/photo-1546456073-92b9f0a8d413?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
18
+ "https://images.unsplash.com/photo-1601412436009-d964bd02edbc?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=464&q=80",
19
+ ]
20
+ ), f"label {i}" if i != 0 else "label" * 50)
21
+ for i in range(3)
22
+ ]
23
+ similarity.check_similarity([img_main, img_1, img_2], models[model_idx])
24
+ return images
25
+
26
+ # def greet(name):
27
+ # return "Hello " + name + "!!"
28
+
29
+ # iface = gr.Interface(fn=greet, inputs="text", outputs="text")
30
+ # iface.launch()
31
+
32
+ with gr.Blocks() as demo:
33
+ gr.Markdown('Checking Image Similarity')
34
+ img_main = gr.Text(label='Main Image', placeholder='https://myimage.jpg')
35
+
36
+ gr.Markdown('Images to check')
37
+ img_1 = gr.Text(label='1st Image', placeholder='https://myimage_1.jpg')
38
+ img_2 = gr.Text(label='2nd Image', placeholder='https://myimage_2.jpg')
39
+
40
+ gr.Markdown('Choose the model')
41
+ model = gr.Dropdown([m.name for m in models], label='Model', type='index')
42
+
43
+ gallery = gr.Gallery(
44
+ label="Generated images", show_label=True, elem_id="gallery"
45
+ ).style(grid=[2], height="auto")
46
+
47
+ submit_btn = gr.Button('Check Similarity')
48
+ submit_btn.click(fn=check,inputs=[img_main, img_1, img_2, model], outputs=gallery)
49
+
50
+ demo.launch()
src/model/similarity_interface.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ class SimilarityInterface:
2
+ def extract_feature(img):
3
+ return []
src/model/simlarity_model.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from .similarity_interface import SimilarityInterface
3
+
4
+ @dataclass
5
+ class SimilarityModel:
6
+ name: str
7
+ image_size: int
8
+ model_cls: SimilarityInterface
src/similarity/model_implements/mobilenet_v3.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ class ModelnetV3():
2
+ def extract_feature(self, img):
3
+ print('getting with ModelnetV3...')
4
+ return []
src/similarity/similarity.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.model import simlarity_model as model
2
+ from src.util import image as image_util
3
+ from .model_implements.mobilenet_v3 import ModelnetV3
4
+
5
+ class Similarity:
6
+ def get_models(self):
7
+ return [model.SimilarityModel(name= 'mobilenet_v3', image_size= 224, model_cls = ModelnetV3())]
8
+
9
+ def check_similarity(self, img_urls, model):
10
+ # model = self.get_models()[model_idx]
11
+ imgs = []
12
+ for url in img_urls:
13
+ if url == "": continue
14
+ imgs.append(image_util.load_image_url(url, required_size=(model.image_size, model.image_size)))
15
+ model.model_cls.extract_feature(imgs[0])
16
+ return 'oke'
17
+
18
+
src/util/image.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ import requests
4
+
5
+ def load_image_url(url, required_size = (224,224)):
6
+ img = Image.open(requests.get(url, stream=True).raw)
7
+ img = Image.fromarray(np.array(img))
8
+ img = img.resize(required_size)
9
+ img = (np.expand_dims(np.array(img), 0)/255).astype(np.float32)
10
+ return img