yeftakun commited on
Commit
d9ce9e6
·
verified ·
1 Parent(s): 9a00b26

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import ViTImageProcessor, AutoModelForImageClassification
3
+ from PIL import Image
4
+ import requests
5
+ from io import BytesIO
6
+
7
+ # Load the model and processor
8
+ processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
9
+ model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')
10
+
11
+ # Define prediction function
12
+ def predict_image(image):
13
+ try:
14
+ # Process the image and make prediction
15
+ inputs = processor(images=image, return_tensors="pt")
16
+ outputs = model(**inputs)
17
+ logits = outputs.logits
18
+
19
+ # Get predicted class
20
+ predicted_class_idx = logits.argmax(-1).item()
21
+ predicted_label = model.config.id2label[predicted_class_idx]
22
+
23
+ return predicted_label
24
+ except Exception as e:
25
+ return str(e)
26
+
27
+ # Streamlit app
28
+ st.title("NSFW Image Classifier")
29
+
30
+ # Upload image file
31
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
32
+ if uploaded_file is not None:
33
+ image = Image.open(uploaded_file)
34
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
35
+ st.write("")
36
+ st.write("Classifying...")
37
+
38
+ # Predict and display result
39
+ prediction = predict_image(image)
40
+ st.write(f"Predicted Class: {prediction}")