satya-demo-v0 / app.py
mrsarthakgupta's picture
Update app.py
85d6695 verified
raw
history blame contribute delete
No virus
1.31 kB
import gradio as gr
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, CLIPVisionModel
import torch
from PIL import Image
import json
import numpy as np
# Load the pre-trained CLIP vision model, its feature extractor, and the
# linear-head weights that classify_image projects the pooled embedding onto.
model_name = "mrsarthakgupta/openclipvisionmodel"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = CLIPVisionModel.from_pretrained(model_name)
# Read the head weights inside a context manager so the file handle is closed
# deterministically instead of waiting for garbage collection.
with open('linear_params.json', 'r', encoding='utf-8') as _params_file:
    # Weights are stored under the "0" key. NOTE(review): the exact shape is not
    # visible here; classify_image squeezes it, so presumably (hidden,),
    # (hidden, 1) or (1, hidden) — confirm against the JSON file.
    matrix = torch.Tensor(np.array(json.load(_params_file)["0"]))
def classify_image(image):
    """Estimate how likely *image* is to be AI generated.

    The image is preprocessed by the module-level ``feature_extractor``,
    embedded by the CLIPVisionModel ``model``, projected onto the linear-head
    weights ``matrix``, shifted by a fixed offset and squashed with a sigmoid.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``; ``None``
            when the user submits without uploading anything.

    Returns:
        A single-entry dict mapping the display label to a Python float in
        [0, 1], which is the value format ``gr.Label`` consumes.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of a processor traceback.
        raise gr.Error("Please upload an image first.")

    # Uploads may be RGBA / palette / greyscale (common for PNGs); normalise to
    # 3 channels so preprocessing does not depend on the processor's config.
    inputs = feature_extractor(images=image.convert("RGB"), return_tensors="pt")

    # Fixed offset added to the logit; hard-coded rather than read from
    # linear_params.json. NOTE(review): presumably the head's bias — confirm.
    bias = 0.9791784

    with torch.no_grad():
        outputs = model(**inputs)
        # Batch of one: squeezing both operands is intended to make the
        # product a single logit (assumes matrix squeezes to a vector).
        logit = torch.matmul(outputs.pooler_output.squeeze(), matrix.squeeze())
        # .item() converts the 1-element tensor to a plain float for gr.Label.
        probability = torch.sigmoid(logit + bias).item()

    return {'Our confidence of this image being AI generated': probability}
# Create the Gradio interface: one PIL image in, one labelled confidence out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    # classify_image returns a dict with a single label, so show just that one.
    outputs=gr.Label(num_top_classes=1),
    title="Detect AI generated image",
    description=(
        "Upload an image to find the probability of it being AI generated. "
        "While we're certainly not perfect, we're working really hard to get "
        "a more accurate classifier!"
    ),
)
# Launch the app. NOTE(review): this runs whenever the module is executed or
# imported; if the host always runs it as a script, consider guarding it with
# `if __name__ == "__main__":` so importing the module has no side effects.
iface.launch()