File size: 2,490 Bytes
07a421e
 
23dca80
07a421e
 
d2d56e8
b5b4791
 
 
07a421e
 
 
 
 
a960bc2
 
b5b4791
0a4d4a6
c66e22e
7efeba9
07a421e
 
 
 
b5b4791
 
0a4d4a6
b5b4791
 
 
 
 
a59bcf0
07a421e
 
 
a960bc2
 
 
b5b4791
a960bc2
b5b4791
a960bc2
 
b5b4791
a960bc2
 
 
07a421e
23dca80
07a421e
 
 
 
 
 
 
 
 
 
 
b5b4791
07a421e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import tempfile

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline
from transformers import pipeline


# Single device object shared by the diffusion pipeline and both classifiers.
device = torch.device("cuda")

# Model identifiers, named once so they are easy to find and swap.
BASE_MODEL_ID = "black-forest-labs/FLUX.1-dev"
LORA_REPO_ID = "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch"
LORA_WEIGHT_NAME = "FLUX-dev-lora-children-simple-sketch.safetensors"
LORA_SCALE = 1.5
IMAGE_NSFW_MODEL_ID = "Falconsai/nsfw_image_detection"
TEXT_NSFW_MODEL_ID = "eliasalbouzidi/distilbert-nsfw-text-classifier"

# Load the base FLUX pipeline in bfloat16 and apply the children's-sketch LoRA.
pipe = FluxPipeline.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.bfloat16)
pipe.load_lora_weights(LORA_REPO_ID, weight_name=LORA_WEIGHT_NAME)
# Fuse the LoRA into the pipeline with a 1.5x scale.
pipe.fuse_lora(lora_scale=LORA_SCALE)
# Use the same `device` as the classifiers instead of a separate "cuda" literal.
pipe.to(device)

# NSFW classifiers: one screens the user's prompt, the other the generated image.
image_classifier = pipeline(
    "image-classification", model=IMAGE_NSFW_MODEL_ID, device=device
)
text_classifier = pipeline(
    "text-classification", model=TEXT_NSFW_MODEL_ID, device=device
)

# A classifier result labelled "nsfw" with a score above this value blocks the request.
NSFW_THRESHOLD = 0.5

def _is_nsfw(results):
    """Return True if any classifier result is labelled "nsfw" above the threshold.

    ``results`` is the list of ``{"label": ..., "score": ...}`` dicts produced by
    a transformers classification pipeline. A result counts as a hit when its
    label is exactly ``"nsfw"`` and its score exceeds ``NSFW_THRESHOLD``.
    """
    return any(
        result["label"] == "nsfw" and result["score"] > NSFW_THRESHOLD
        for result in results
    )


# Define the function to generate the sketch
@spaces.GPU
def generate_sketch(prompt, num_inference_steps, guidance_scale):
    """Generate a sketch-style image for ``prompt``, screening input and output.

    Args:
        prompt: User text; prefixed with ``"sketched style, "`` before generation.
        num_inference_steps: Number of inference steps for the pipeline. Coerced
            to ``int`` in case the UI slider delivers a float.
        guidance_scale: Guidance scale forwarded unchanged to the pipeline.

    Returns:
        Path to a PNG file containing the generated image, or a message string
        if either the prompt or the generated image is classified as NSFW.
    """
    # Screen the prompt first so flagged requests never reach the diffusion model.
    # truncation=True makes prompts longer than the classifier's maximum input
    # length get classified (on their prefix) instead of raising an error.
    text_classification = text_classifier(prompt, truncation=True)
    print(text_classification)
    if _is_nsfw(text_classification):
        return "Inappropriate prompt detected. Please try another prompt."

    image = pipe(
        "sketched style, " + prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
    ).images[0]

    # Screen the generated image too: a benign-looking prompt can still
    # produce flagged output.
    image_classification = image_classifier(image)
    print(image_classification)
    if _is_nsfw(image_classification):
        return "Inappropriate content detected. Please try another prompt."

    # Save to a unique temp file. A fixed filename would let concurrent requests
    # overwrite each other's result before it is served. The file is
    # deliberately not deleted here: presumably the Gradio output reads it
    # after this function returns.
    with tempfile.NamedTemporaryFile(
        suffix=".png", prefix="sketch_", delete=False
    ) as tmp:
        image_path = tmp.name
    image.save(image_path)
    return image_path

# Build the Gradio UI: a free-text prompt plus two sliders, whose values are
# passed positionally to generate_sketch(prompt, num_inference_steps, guidance_scale).
_steps_slider = gr.Slider(
    5, 50, value=24, step=1, label="Number of Inference Steps"
)
_guidance_slider = gr.Slider(
    1.0, 10.0, value=3.5, step=0.1, label="Guidance Scale"
)

interface = gr.Interface(
    fn=generate_sketch,
    inputs=["text", _steps_slider, _guidance_slider],
    outputs="auto",
    title="Kids Sketch Generator",
    description=(
        "Enter a text prompt and generate a fun sketch for kids with "
        "customizable inference steps and guidance scale."
    ),
)

# Start serving the app.
interface.launch()