"""AI Art Detector.

Gradio app that classifies an uploaded image as AI-generated or human-made
using a ResNet-18 checkpoint fine-tuned for AI-art detection.
"""

import gradio as gr
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    ImageClassificationPipeline,
)

MODEL_NAME = "artfan123/resnet-18-finetuned-ai-art"

# Load the fine-tuned checkpoint once at module import so every request
# reuses the same pipeline instead of reloading weights.
feature_extractor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
image_pipe = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)


def classify_image(image):
    """Classify *image* and return scores in the format ``gr.Label`` expects.

    Args:
        image: A ``PIL.Image`` supplied by the Gradio ``Image`` component.

    Returns:
        dict mapping each predicted label to its confidence score.
    """
    return {pred["label"]: pred["score"] for pred in image_pipe(image)}


title = "AI Art Detector"
description = (
    "A deep learning model that detects whether an image is AI generated "
    "or human made. Upload image or use the example images below."
)
examples = [["50.jpg"], ["344.jpg"], ["24.jpg"], ["339.jpg"], ["105.jpg"]]

demo = gr.Interface(
    fn=classify_image,
    # gr.inputs.*/gr.outputs.* were removed in Gradio 3; components are
    # now referenced directly from the top-level namespace.
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=2),
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    # .queue() replaces the removed enable_queue=True Interface argument.
    demo.queue().launch(debug=True)