Neel kamal sahu commited on
Commit
8d424e9
1 Parent(s): 1cd3419

trying to add bars in output

Browse files
Files changed (2) hide show
  1. app.py +6 -11
  2. my_model.h5 +2 -2
app.py CHANGED
@@ -6,7 +6,6 @@ from tensorflow.keras.models import load_model
6
 
7
  # Load the pre-trained model
8
  model = load_model('my_model.h5')
9
- CATEGORIES = ("NSFW", "SFW")
10
  def classify_image(img):
11
  # Preprocess the input image
12
  img = image.img_to_array(img)
@@ -14,15 +13,11 @@ def classify_image(img):
14
  img /= 255.0
15
 
16
  # Use the model to make a prediction
17
- prediction = model.predict(img)
18
- print(prediction)
19
  # Map the predicted class to a label
20
- if prediction[0][0] >= 0.5:
21
- label = "SFW"
22
- else:
23
- label = "NSFW"
24
-
25
- return label
26
 
27
  def classify_url(url):
28
  # Load the image from the URL
@@ -35,10 +30,10 @@ def classify_url(url):
35
  #inputs = gr.inputs.Image(shape=(224, 224, 3))
36
 
37
  # Define the GRADIO output interface
38
- #outputs = gr.outputs.Textbox(lines=1, default="SFW")
39
 
40
  # Define the GRADIO app
41
- app = gr.Interface(classify_image, gr.Image(shape=(224, 224)), outputs=gr.Label(label="Type of Image"), allow_flagging="never", title="NSFW/SFW Classifier")
42
 
43
  # Start the GRADIO app
44
  app.launch()
 
6
 
7
  # Load the pre-trained model
8
  model = load_model('my_model.h5')
 
9
  def classify_image(img):
10
  # Preprocess the input image
11
  img = image.img_to_array(img)
 
13
  img /= 255.0
14
 
15
  # Use the model to make a prediction
16
+ prediction = model.predict(img)[0]
17
+ #print(prediction)
18
  # Map the predicted class to a label
19
+ dic = {'SFW': np.round(prediction[1],2), 'NSFW': np.round(prediction[0],2)}
20
+ return dic#{'SFW': prediction[0][1], 'NSFW': prediction[0][0]}
 
 
 
 
21
 
22
  def classify_url(url):
23
  # Load the image from the URL
 
30
  #inputs = gr.inputs.Image(shape=(224, 224, 3))
31
 
32
  # Define the GRADIO output interface
33
+
34
 
35
  # Define the GRADIO app
36
+ app = gr.Interface(classify_image, gr.Image(shape=(224, 224)), outputs="label", allow_flagging="never", title="NSFW/SFW Classifier")
37
 
38
  # Start the GRADIO app
39
  app.launch()
my_model.h5 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:72f973a8a272937a04f5ea32eb48539a61338694d310a7abe805e89de29f4ff4
3
- size 94864144
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff20523d0b47e5980145186785b93d4ebb2fa858759ab9c7bb452476955b1663
3
+ size 189463232