import gradio as gr
import numpy as np
import random
import spaces
from diffusers import AutoPipelineForImage2Image, DiffusionPipeline, DPMSolverSDEScheduler
import torch
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
from PIL import Image
import cv2
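# Assumed pip dependencies (inferred from the imports above; versions not pinned here):
# gradio, spaces, diffusers, torch, huggingface_hub, ultralytics, opencv-python, pillow, numpy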

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"
adetailer_model_id = "Bingsu/adetailer"  # repo hosting ADetailer detection weights

# Load the YOLO model for face detection
yolo_model_path = hf_hub_download(adetailer_model_id, "face_yolov8n.pt")
yolo_model = YOLO(yolo_model_path)
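# The detector returns Results whose .boxes carry xyxy pixel coordinates,
# which correct_anime_face() below crops and refines.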

# Use half precision on GPU to reduce memory use; fall back to float32 on CPU
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
# DPMSolverSDEScheduler implements DPM++ SDE itself, so only Karras sigmas need
# enabling (algorithm_type/solver_order are DPMSolverMultistepScheduler options)
pipe.scheduler = DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe = pipe.to(device)

# Reuse the already-loaded weights in an img2img pipeline for the face-refinement pass
img2img_pipe = AutoPipelineForImage2Image.from_pipe(pipe)

MAX_SEED = np.iinfo(np.int32).max  # largest 32-bit seed accepted by the UI slider
MAX_IMAGE_SIZE = 1024  # SDXL checkpoints are trained at 1024x1024

def correct_anime_face(image):
    """Detect faces with YOLO and redraw each one with an img2img pass."""
    # Convert to OpenCV's BGR format
    img = np.array(image)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # Detect faces
    results = yolo_model(img)

    for detection in results[0].boxes:
        x1, y1, x2, y2 = map(int, detection.xyxy[0].tolist())
        if x2 <= x1 or y2 <= y1:
            continue  # skip degenerate boxes

        # Crop the face region and convert back to RGB for the pipeline
        face = img[y1:y2, x1:x2]
        face_pil = Image.fromarray(cv2.cvtColor(face, cv2.COLOR_BGR2RGB))

        # Remember the crop size, then upscale to a size SDXL handles well
        # (dimensions must be divisible by 8)
        crop_w, crop_h = face_pil.size
        face_pil = face_pil.resize((768, 768), Image.LANCZOS)

        # Prompt for the face-refinement pass
        prompt = "Enhance this anime character's face, fix eyes and make features more vivid."

        # Redraw the face with img2img; a moderate strength keeps the original
        # composition while sharpening details
        corrected_face = img2img_pipe(prompt=prompt, image=face_pil, strength=0.4).images[0]

        # Scale back to the crop size, convert RGB -> BGR, and paste it back
        corrected_face = corrected_face.resize((crop_w, crop_h), Image.LANCZOS)
        img[y1:y2, x1:x2] = cv2.cvtColor(np.array(corrected_face), cv2.COLOR_RGB2BGR)

    # Convert back to PIL
    final_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return final_image

# On Hugging Face ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration
# of each call; it is a no-op on regular hardware
@spaces.GPU
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # A CPU generator keeps seeded results reproducible across devices
    generator = torch.Generator().manual_seed(seed)
    
    image = pipe(
        prompt=prompt, 
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale, 
        num_inference_steps=num_inference_steps, 
        width=width, 
        height=height,
        generator=generator
    ).images[0] 
    
    # Correct anime face in the generated image
    corrected_image = correct_anime_face(image)
    
    return corrected_image, seed

examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Text-to-Image Gradio Template")
        
        with gr.Row():
            prompt = gr.Text(label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False)
            run_button = gr.Button("Run", scale=0)
        
        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt", visible=False)
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            
            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                
            with gr.Row():
                # CFG around 7 and ~28 steps suit a standard (non-turbo) SDXL checkpoint
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=7.0)
                num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=28)
        
        gr.Examples(examples=examples, inputs=[prompt])
    
    # Trigger inference from the Run button or by submitting the prompt box
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed],
    )

# queue() serializes requests so concurrent users don't contend for the GPU
demo.queue().launch()