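# Flask API for sketch-to-image generation: a single POST /generate endpoint
# takes a text prompt and a Base64-encoded sketch and returns a Base64-encoded
# PNG produced by an SDXL ControlNet pipeline loaded from the Hugging Face Hub.
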
from flask import Flask, request, jsonify
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetPipeline,
    AutoencoderKL,
    EulerAncestralDiscreteScheduler,
)
from PIL import Image
import torch
import base64
import os
from io import BytesIO
from huggingface_hub import login

# Authenticate with the Hugging Face Hub (expects HF_TOKEN in the environment)
login(os.environ["HF_TOKEN"])

# Initialize Flask app
app = Flask(__name__)

# Load Hugging Face pipeline components.
# fp16 weights are intended for GPU inference; fall back to fp32 on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

model_id = "fyp1/sketchToImage"
controlnet = ControlNetModel.from_pretrained(model_id, subfolder="controlnet", torch_dtype=dtype)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")

# Initialize the Stable Diffusion XL ControlNet pipeline
# (SDXL pipelines have no safety_checker component, so none is passed)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=dtype,
).to(device)

@app.route("/generate", methods=["POST"])
def generate_image():
    data = request.get_json(silent=True) or {}

    # Extract the prompt, negative prompt, and Base64-encoded sketch
    prompt = data.get("prompt", "A default prompt")
    negative_prompt = data.get("negative_prompt", "low quality, blurry, bad details")
    sketch_base64 = data.get("sketch")

    if not sketch_base64:
        return jsonify({"error": "Sketch image is required."}), 400

    try:
        # Decode and preprocess the sketch image
        sketch_bytes = base64.b64decode(sketch_base64)
        sketch_image = Image.open(BytesIO(sketch_bytes)).convert("L")  # Convert to grayscale
        sketch_image = sketch_image.resize((1024, 1024))

        # Generate the image using the pipeline
        with torch.no_grad():
            images = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=sketch_image,
                controlnet_conditioning_scale=1.0,
                width=1024,
                height=1024,
                num_inference_steps=30,
            ).images

        # Convert output image to Base64
        buffered = BytesIO()
        images[0].save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return jsonify({"image": image_base64})
    except Exception as e:
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
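
# Example client call (a minimal sketch, not part of the server; assumes the
# `requests` package and a server reachable at http://localhost:7860):
#
#   import base64, requests
#
#   with open("sketch.png", "rb") as f:
#       sketch_b64 = base64.b64encode(f.read()).decode("utf-8")
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"prompt": "a watercolor cottage by a lake", "sketch": sketch_b64},
#   )
#   image_b64 = resp.json()["image"]  # Base64-encoded PNG of the generated image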