File size: 2,765 Bytes
c513221
5e2c7ed
6b1b953
192bf4d
5e2c7ed
4294a68
86743ba
4294a68
5c3986b
6b1b953
51e61c2
6b1b953
 
 
 
5d9bf5a
4294a68
eb48f29
6b1b953
 
 
4294a68
 
 
6b1b953
6292767
4294a68
6292767
 
4294a68
eb48f29
 
4294a68
eb48f29
4294a68
192bf4d
4294a68
eb48f29
4294a68
 
eb48f29
4294a68
 
3b7350e
4294a68
 
 
 
 
 
 
 
422af54
4294a68
 
 
192bf4d
4294a68
 
 
 
 
 
 
 
081cd9c
690f094
4294a68
5c3986b
 
 
 
 
bdf16c0
4294a68
bdf16c0
f466dd9
5c3986b
630a72e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
from io import BytesIO
from PIL import Image
from diffusers import AutoPipelineForText2Image
import gradio as gr
import base64

# Load the model once at the start
# NOTE(review): from_pretrained with no device/dtype args runs on CPU in
# float32 — presumably intentional for a demo; confirm if GPU is available.
print("Loading the Stable Diffusion model...")
try:
    model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
    print("Model loaded successfully.")
except Exception as e:
    # Best-effort startup: keep the app alive so the UI can report the
    # failure per-request instead of crashing at import time.
    print(f"Error loading model: {e}")
    model = None  # Sentinel checked by generate_image() before each request.

def generate_image(prompt):
    """Generate one image for *prompt* and return it as a base64 JPEG string.

    Args:
        prompt: Text prompt forwarded to the sdxl-turbo pipeline.

    Returns:
        tuple: ``(img_str, error)`` where exactly one element is ``None`` —
        ``img_str`` is the base64-encoded JPEG on success, ``error`` is a
        human-readable message on failure. Never raises.
    """
    try:
        if model is None:
            raise ValueError("Model not loaded properly.")

        print(f"Generating image with prompt: {prompt}")
        # sdxl-turbo is distilled for single-step, guidance-free sampling.
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)

        if output is None:
            raise ValueError("Model returned None")

        if hasattr(output, 'images') and output.images:
            image = output.images[0]
            # JPEG cannot encode an alpha channel or palette mode; normalize
            # to RGB so Image.save never raises on non-RGB output.
            if image.mode != "RGB":
                image = image.convert("RGB")
            buffered = BytesIO()
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            # Log only the size — dumping the full base64 payload floods logs.
            print(f"Image generated ({len(img_str)} base64 chars)")
            return img_str, None
        else:
            print(f"No images found in model output")
            raise ValueError("No images found in model output")
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        return None, str(e)

def inference(sentence_mapping, character_dict, selected_style):
    """Render one illustration per paragraph.

    Args:
        sentence_mapping: Mapping of paragraph number -> list of sentences.
        character_dict: Character metadata (currently unused in the prompt).
        selected_style: Art style name interpolated into each prompt.

    Returns:
        dict: paragraph number -> base64 image string, or an "Error: ..."
        string for paragraphs that failed; ``{"error": ...}`` on bad input.
    """
    try:
        print(f"Received sentence_mapping: {sentence_mapping}")
        print(f"Received character_dict: {character_dict}")
        print(f"Received selected_style: {selected_style}")

        # Guard clause: every input is required.
        if any(arg is None for arg in (sentence_mapping, character_dict, selected_style)):
            return {"error": "One or more inputs are None"}

        results = {}
        for para_num, sentence_list in sentence_mapping.items():
            joined = " ".join(sentence_list)
            prompt = f"Make an illustration in {selected_style} style from: {joined}"
            encoded, failure = generate_image(prompt)
            results[para_num] = f"Error: {failure}" if failure else encoded
        return results
    except Exception as e:
        return {"error": str(e)}

# Gradio app: two JSON inputs plus a fixed style dropdown; the JSON output
# maps each paragraph number to a base64 image string (or an error string).
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs="json"
)

# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()