Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,753 Bytes
07a421e 23dca80 07a421e d2d56e8 b5b4791 4ef16a2 ab57896 a960bc2 b5b4791 7785249 c66e22e 85913ad 07a421e 9388d07 b5b4791 7785249 7ca8bcd b5b4791 7785249 ab57896 4ef16a2 e1df300 01333b3 ab57896 4ef16a2 9388d07 a59bcf0 07a421e a960bc2 b5b4791 a960bc2 b5b4791 a960bc2 b5b4791 a960bc2 4715e52 a960bc2 07a421e 23dca80 07a421e 4715e52 07a421e 9388d07 01333b3 9388d07 07a421e 7785249 4715e52 7785249 07a421e 7ca8bcd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import tempfile

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline
from transformers import pipeline
# Single CUDA device shared by the image generator and the NSFW classifier.
device = torch.device("cuda")

# Base FLUX.1-dev text-to-image pipeline, loaded in bfloat16 and moved to the GPU once.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe.to(device)

# Image classifier used to screen generated images for NSFW content.
image_classifier = pipeline(
    "image-classification",
    model="Falconsai/nsfw_image_detection",
    device=device,
)
# NOTE: prompt-text NSFW classification (eliasalbouzidi/distilbert-nsfw-text-classifier)
# is currently disabled; only generated images are screened.

# Minimum "nsfw" score above which a generated image is rejected.
NSFW_THRESHOLD = 0.3
# Style key -> (prompt trigger prefix, LoRA repository id, LoRA weight file name).
_STYLE_LORAS = {
    "shou_xin": ("shou_xin, ", "Datou1111/shou_xin", "shou_xin.safetensors"),
    "sketched": (
        "sketched style, ",
        "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch",
        "FLUX-dev-lora-children-simple-sketch.safetensors",
    ),
    "sketch_paint": (
        "Sketch paint, ",
        "strangerzonehf/Sketch-Paint",
        "Sketch-Paint.safetensors",
    ),
    "sketch_sized": (
        "Sketch Sized, ",
        "strangerzonehf/Flux-Sketch-Sized-LoRA",
        "Flux-Sketch-Sized-LoRA.safetensors",
    ),
}
# Style used when the requested style is not recognised.
_DEFAULT_STYLE = "shou_xin"
# The UI dropdown offers "Sketch-Paint" and "Flux-Sketch-Sized-LoRA", which are not
# keys of _STYLE_LORAS; map them so they select their own LoRA instead of the fallback.
_STYLE_ALIASES = {
    "Sketch-Paint": "sketch_paint",
    "Flux-Sketch-Sized-LoRA": "sketch_sized",
}
# Prefix prepended to every prompt, whatever the style.
_BASE_PREFIX = "sketched style, "
# LoRA strength applied for every style.
_LORA_SCALE = 1.5
_NSFW_MESSAGE = "Inappropriate content detected. Please try another prompt."


@spaces.GPU
def generate_sketch(prompt, style, num_inference_steps, guidance_scale):
    """Generate a sketch with FLUX.1-dev plus a style LoRA, then NSFW-screen it.

    Args:
        prompt: User text prompt.
        style: A key of ``_STYLE_LORAS`` or one of the dropdown labels in
            ``_STYLE_ALIASES``; any other value falls back to ``_DEFAULT_STYLE``.
        num_inference_steps: Number of denoising steps (coerced to ``int``).
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        ``(image_path, None)`` on success, where ``image_path`` is a PNG file
        unique to this call, or ``(None, message)`` when the classifier scores
        the image as "nsfw" above ``NSFW_THRESHOLD``.
    """
    # NOTE: the prompt text itself is not classified; only the output image is.
    print(prompt)

    style_key = _STYLE_ALIASES.get(style, style)
    trigger, repo_id, weight_name = _STYLE_LORAS.get(
        style_key, _STYLE_LORAS[_DEFAULT_STYLE]
    )
    prompt = trigger + prompt
    # The 'sketched' trigger is the same as the base prefix; don't repeat it.
    if not prompt.startswith(_BASE_PREFIX):
        prompt = _BASE_PREFIX + prompt

    try:
        pipe.load_lora_weights(repo_id, weight_name=weight_name)
        # Apply the LoRA at inference time instead of fusing it: fusing edits the
        # base weights in place, and without undoing it every request would stack
        # another LoRA on top of the previous ones.
        image = pipe(
            prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            joint_attention_kwargs={"scale": _LORA_SCALE},
        ).images[0]
    finally:
        # Always drop the adapter so the next request starts from the clean base model.
        pipe.unload_lora_weights()

    image_classification = image_classifier(image)
    print(image_classification)
    if any(
        result["label"] == "nsfw" and result["score"] > NSFW_THRESHOLD
        for result in image_classification
    ):
        return None, _NSFW_MESSAGE

    # A per-request file: a fixed file name would let concurrent requests
    # overwrite (and serve) each other's images.
    with tempfile.NamedTemporaryFile(
        prefix="generated_sketch_", suffix=".png", delete=False
    ) as tmp:
        image.save(tmp, format="PNG")
    return tmp.name, None
# Gradio UI: a prompt box, a style picker and two sliders feed generate_sketch;
# outputs are the generated image and an optional message (e.g. NSFW rejection).
_style_dropdown = gr.Dropdown(
    ["sketched", "shou_xin", "Sketch-Paint", "Flux-Sketch-Sized-LoRA"],
    value="sketched",
    label="Style",
)
_steps_slider = gr.Slider(5, 50, value=24, step=1, label="Number of Inference Steps")
_guidance_slider = gr.Slider(1.0, 10.0, value=3.5, step=0.1, label="Guidance Scale")

interface = gr.Interface(
    fn=generate_sketch,
    inputs=["text", _style_dropdown, _steps_slider, _guidance_slider],
    outputs=[
        gr.Image(label="Generated Sketch"),
        gr.Textbox(label="Message"),
    ],
    title="Kids Sketch Generator",
    description="Enter a text prompt and generate a fun sketch for kids with customizable inference steps and guidance scale.",
)
# Start the Gradio server for the interface defined above.
# (A stray trailing "|" made this line a syntax error; it has been removed.)
interface.launch()