rainbow_media_x / app.py
import gradio as gr
import numpy as np
import random
import spaces
from diffusers import DiffusionPipeline, DPMSolverSDEScheduler
import torch
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
from PIL import Image
import cv2
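# Run on the GPU when one is available; otherwise fall back to CPU.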
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"
adetailer_model_id = "Bingsu/adetailer" # Your ADetailer model
# Load the YOLO model for face detection
yolo_model_path = hf_hub_download(adetailer_model_id, "face_yolov8n.pt")
yolo_model = YOLO(yolo_model_path)
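# Use half precision on GPU to cut memory use; CPU inference needs full precision.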
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe.scheduler = DPMSolverSDEScheduler.from_config(
    pipe.scheduler.config,
    algorithm_type="dpmsolver++",
    solver_order=2,
    use_karras_sigmas=True,
)
pipe = pipe.to(device)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
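# ADetailer-style post-processing: detect faces with YOLO, crop each face,
# re-render it with the diffusion pipeline, and paste the result back in place.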
def correct_anime_face(image):
    # Convert the PIL image to OpenCV's BGR format
    img = np.array(image)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # Detect faces
    results = yolo_model(img)

    for detection in results[0].boxes:
        x1, y1, x2, y2 = map(int, detection.xyxy[0].tolist())

        # Crop the face region
        face = img[y1:y2, x1:x2]
        face_pil = Image.fromarray(cv2.cvtColor(face, cv2.COLOR_BGR2RGB))

        # Prompt for the correction model
        prompt = "Enhance this anime character's face, fix eyes and make features more vivid."

        # Process the face with the anime correction model
        # (passing `image=` assumes an img2img-capable pipeline; replace with your correction model)
        corrected_face = pipe(prompt=prompt, image=face_pil).images[0]

        # Resize the corrected face back to the crop size, convert to BGR,
        # and place it back into the original image
        corrected_face = corrected_face.resize((x2 - x1, y2 - y1))
        img[y1:y2, x1:x2] = cv2.cvtColor(np.array(corrected_face), cv2.COLOR_RGB2BGR)

    # Convert back to PIL
    final_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return final_image
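# @spaces.GPU requests a GPU from Hugging Face Spaces (ZeroGPU) for the duration of the call.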
@spaces.GPU
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    # Correct anime faces in the generated image
    corrected_image = correct_anime_face(image)
    return corrected_image, seed
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Text-to-Image Gradio Template")

        with gr.Row():
            prompt = gr.Text(label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False)
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt", visible=False)
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)

            with gr.Row():
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=0.0)
                num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=2)

        gr.Examples(examples=examples, inputs=[prompt])

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed],
    )

demo.queue().launch()