Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,632 Bytes
07a421e 23dca80 07a421e d2d56e8 b5b4791 07a421e a960bc2 b5b4791 0a4d4a6 c66e22e 7efeba9 07a421e b5b4791 dd3728f 7ca8bcd b5b4791 7ca8bcd b5b4791 a59bcf0 07a421e a960bc2 b5b4791 a960bc2 b5b4791 a960bc2 b5b4791 a960bc2 7ca8bcd a960bc2 07a421e 23dca80 07a421e 7ca8bcd 07a421e 7ca8bcd 07a421e 7ca8bcd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import torch
from diffusers import FluxPipeline
from transformers import pipeline
import gradio as gr
import spaces
# Accelerator used by the generation and moderation pipelines below.
device = torch.device("cuda")

# Base FLUX.1-dev text-to-image pipeline, loaded in bfloat16.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
# Attach the children's simple-sketch LoRA and fuse it into the base
# weights at a scale of 1.5, then move the pipeline onto the GPU.
pipe.load_lora_weights(
    "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch",
    weight_name="FLUX-dev-lora-children-simple-sketch.safetensors",
)
pipe.fuse_lora(lora_scale=1.5)
pipe.to(device)

# Moderation models: one screens generated images, the other screens prompts.
image_classifier = pipeline(
    "image-classification",
    model="Falconsai/nsfw_image_detection",
    device=device,
)
text_classifier = pipeline(
    "text-classification",
    model="eliasalbouzidi/distilbert-nsfw-text-classifier",
    device=device,
)

# A classifier result labelled "nsfw" with a score above this value causes
# the request to be rejected.
NSFW_THRESHOLD = 0.5
# Sketch generation with NSFW screening of both the prompt and the result.


def _is_flagged_nsfw(classification):
    """Return True if any classifier result is "nsfw" above NSFW_THRESHOLD.

    ``classification`` is the list of result dicts (with "label" and "score"
    keys) produced by the transformers classification pipelines used below.
    """
    return any(
        result["label"] == "nsfw" and result["score"] > NSFW_THRESHOLD
        for result in classification
    )


def _rejection(message):
    """Build (image, message) updates that hide the image and show ``message``.

    ``visible=True`` is set explicitly on the message: a previous successful
    request hid it, and updating only its value would leave it hidden.
    """
    return gr.update(visible=False), gr.update(value=message, visible=True)


@spaces.GPU
def generate_sketch(prompt, num_inference_steps, guidance_scale):
    """Generate a children's-sketch image for ``prompt``, screening in and out.

    The prompt is checked with ``text_classifier`` before generation and the
    generated image with ``image_classifier`` afterwards; if either check
    flags NSFW content, the image output is hidden and a rejection message
    is shown instead.

    Args:
        prompt: User text; prefixed with "sketched style, " before generation.
        num_inference_steps: Denoising steps (slider value, coerced to int).
        guidance_scale: Guidance scale for the pipeline (coerced to float).

    Returns:
        A pair of Gradio updates for the (image, message) outputs. Both
        updates set ``visible`` explicitly so that a component hidden by an
        earlier request is shown again when it is needed.
    """
    text_classification = text_classifier(prompt)
    print(text_classification)
    if _is_flagged_nsfw(text_classification):
        return _rejection("Inappropriate prompt detected. Please try another prompt.")

    image = pipe(
        "sketched style, " + prompt,
        # Slider values may arrive as floats; the step count must be an int.
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
    ).images[0]

    image_classification = image_classifier(image)
    print(image_classification)
    if _is_flagged_nsfw(image_classification):
        return _rejection("Inappropriate content detected. Please try another prompt.")

    image_path = "generated_sketch.png"
    image.save(image_path)
    # Re-show the image (an earlier rejection may have hidden it) and hide
    # any leftover rejection message.
    return gr.update(value=image_path, visible=True), gr.update(visible=False)
# Gradio UI: a prompt box plus sliders for the two sampling controls,
# producing an image and a status/rejection message.
prompt_input = "text"  # Gradio string shorthand for a plain text box
steps_slider = gr.Slider(
    5, 50, value=24, step=1, label="Number of Inference Steps"
)
guidance_slider = gr.Slider(
    1.0, 10.0, value=3.5, step=0.1, label="Guidance Scale"
)

interface = gr.Interface(
    fn=generate_sketch,
    inputs=[prompt_input, steps_slider, guidance_slider],
    outputs=[gr.Image(), gr.Text()],
    title="Kids Sketch Generator",
    description=(
        "Enter a text prompt and generate a fun sketch for kids with "
        "customizable inference steps and guidance scale."
    ),
)

# Start the web server.
interface.launch()