File size: 3,184 Bytes
c513221
5e2c7ed
6b1b953
192bf4d
5e2c7ed
4294a68
86743ba
61d26c4
 
 
 
 
 
 
 
 
5d9bf5a
4294a68
61d26c4
eb48f29
6b1b953
 
 
4294a68
c07573a
4294a68
6b1b953
6292767
4294a68
6292767
 
ff2ee2b
eb48f29
 
4294a68
eb48f29
4294a68
ff2ee2b
412c18a
4294a68
eb48f29
4294a68
 
eb48f29
4294a68
 
3b7350e
4294a68
 
b10facf
 
 
4294a68
 
 
422af54
4294a68
 
 
192bf4d
ff2ee2b
4294a68
 
 
 
 
 
 
ff2ee2b
4294a68
081cd9c
690f094
4294a68
5c3986b
 
 
 
 
0ff2104
 
bdf16c0
f466dd9
5c3986b
630a72e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import base64
import os
import threading
from io import BytesIO

import gradio as gr
from diffusers import AutoPipelineForText2Image
from PIL import Image

# Pipeline shared by every request once it has been loaded successfully.
_PIPELINE = None
# Guards the one-time load so concurrent first requests read the weights once.
_LOAD_LOCK = threading.Lock()
# Serialises calls into the shared pipeline (see generate_image).
_GENERATE_LOCK = threading.Lock()


def load_model():
    """Return the SDXL-Turbo text-to-image pipeline, loading it on first use.

    ``generate_image`` calls this for every prompt, so the pipeline is cached at
    module level instead of being re-read from ``stabilityai/sdxl-turbo`` on
    every call. A failed load is not cached: the next call tries again.

    Returns:
        The loaded pipeline, or ``None`` if ``from_pretrained`` raised (the
        error is printed).
    """
    global _PIPELINE
    if _PIPELINE is not None:
        return _PIPELINE
    with _LOAD_LOCK:
        # Another request may have finished loading while we waited for the lock.
        if _PIPELINE is not None:
            return _PIPELINE
        print("Loading the Stable Diffusion model...")
        try:
            model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
        except Exception as e:
            # Best-effort boundary: callers treat None as "model unavailable".
            print(f"Error loading model: {e}")
            return None
        print("Model loaded successfully.")
        _PIPELINE = model
        return model


def generate_image(prompt, *, num_inference_steps=1, guidance_scale=0.0, image_format="JPEG"):
    """Generate one image for *prompt* and return it base64-encoded.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps (default 1, the value this
            app has always used).
        guidance_scale: Guidance scale passed to the pipeline (default 0.0, the
            value this app has always used).
        image_format: Pillow format name used to encode the image (default
            "JPEG").

    Returns:
        ``(img_str, None)`` on success, where ``img_str`` is the base64 text of
        the encoded image. On any failure it returns ``(None, message)``, and
        ``message`` is never empty. Errors are returned rather than raised so
        the caller can record them per prompt.
    """
    model = load_model()
    try:
        if model is None:
            raise ValueError("Model not loaded properly.")

        print(f"Generating image with prompt: {prompt}")
        # The pipeline object is shared between requests (load_model caches it)
        # and diffusers schedulers keep per-run state, so only one thread drives
        # it at a time even though the app admits several concurrent requests.
        with _GENERATE_LOCK:
            output = model(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            )
        print(f"Model output: {output}")

        if output is None:
            raise ValueError("Model returned None")

        images = getattr(output, "images", None)
        if not images:
            print("No images found in model output")
            raise ValueError("No images found in model output")

        print("Image generated successfully")
        image = images[0]
        # JPEG cannot store alpha/palette modes; Pillow refuses to save them.
        if image_format.upper() == "JPEG" and image.mode != "RGB":
            image = image.convert("RGB")
        buffered = BytesIO()
        image.save(buffered, format=image_format)
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded to base64")
        print(f"img_str: {img_str[:100]}...")  # Print a snippet of the base64 string
        return img_str, None
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        # Some exceptions stringify to "" — fall back to the type name so the
        # error value is always truthy for callers that test it.
        return None, str(e) or type(e).__name__

def _join_sentences(sentences):
    """Collapse one paragraph's sentences into a single prompt fragment.

    A plain string is returned unchanged; passing it to ``" ".join`` would put a
    space between every character. Any other iterable has each item converted
    with ``str`` and joined with single spaces.
    """
    if isinstance(sentences, str):
        return sentences
    return " ".join(str(sentence) for sentence in sentences)


def inference(sentence_mapping, character_dict, selected_style):
    """Generate one illustration per paragraph.

    Args:
        sentence_mapping: Mapping of paragraph identifier to that paragraph's
            sentences, given either as a list of strings or as one string.
        character_dict: Required (must not be None) but currently unused.
        selected_style: Style name inserted into every prompt.

    Returns:
        A dict mapping each paragraph identifier to a base64-encoded image
        string, or to ``"Error: <message>"`` if that paragraph failed. If the
        inputs are invalid or an unexpected error occurs, returns
        ``{"error": <message>}`` instead.
    """
    try:
        print(f"Received sentence_mapping: {sentence_mapping}, type: {type(sentence_mapping)}")
        print(f"Received character_dict: {character_dict}, type: {type(character_dict)}")
        print(f"Received selected_style: {selected_style}, type: {type(selected_style)}")

        if sentence_mapping is None or character_dict is None or selected_style is None:
            return {"error": "One or more inputs are None"}
        # gr.JSON can deliver any JSON value; only an object has paragraph keys.
        if not isinstance(sentence_mapping, dict):
            return {
                "error": "sentence_mapping must be a JSON object, "
                f"got {type(sentence_mapping).__name__}"
            }

        # NOTE(review): character_dict is validated above but never used in the
        # prompt — confirm whether character descriptions should be included.
        images = {}
        for paragraph_number, sentences in sentence_mapping.items():
            combined_sentence = _join_sentences(sentences)
            prompt = f"Make an illustration in {selected_style} style from: {combined_sentence}"
            print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
            img_str, error = generate_image(prompt)
            # Check the image too: an exception with an empty message would
            # give a falsy error and previously stored None silently.
            if error or img_str is None:
                images[paragraph_number] = f"Error: {error}"
            else:
                images[paragraph_number] = img_str
        return images
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return {"error": str(e)}

# Style options offered in the dropdown; the chosen one is passed to inference().
_STYLE_CHOICES = ["oil painting", "sketch", "watercolor"]

# Components listed in the same order as inference()'s parameters:
# sentence_mapping, character_dict, selected_style.
_interface_inputs = [
    gr.JSON(label="Sentence Mapping"),
    gr.JSON(label="Character Dict"),
    gr.Dropdown(_STYLE_CHOICES, label="Selected Style"),
]

# Web UI around inference(); the result dict is rendered as JSON.
# concurrency_limit caps how many inference events Gradio runs at once.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=_interface_inputs,
    outputs="json",
    concurrency_limit=3,
)


def _main():
    """Start the Gradio server for the interface above."""
    print("Launching Gradio interface...")
    gradio_interface.launch()


if __name__ == "__main__":
    _main()