File size: 2,781 Bytes
c513221
5e2c7ed
6b1b953
1819f08
5e2c7ed
eb48f29
422af54
86743ba
5c3986b
 
6b1b953
51e61c2
6b1b953
 
 
 
5d9bf5a
422af54
eb48f29
6b1b953
 
 
422af54
6b1b953
422af54
6b1b953
6292767
422af54
6292767
 
422af54
eb48f29
 
 
 
422af54
 
eb48f29
422af54
 
eb48f29
422af54
 
3b7350e
422af54
109adde
422af54
 
 
109adde
422af54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
081cd9c
690f094
422af54
5c3986b
 
 
 
 
bdf16c0
422af54
bdf16c0
f466dd9
5c3986b
630a72e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
from io import BytesIO
from PIL import Image
from transformers import AutoPipelineForText2Image
import gradio as gr
from generate_prompts import generate_prompt
import base64

# Load the pipeline a single time at import so every request reuses it.
# Any failure is reported and leaves `model` as None; generate_image() then
# refuses to run ("Model not loaded properly.") instead of the app dying here.
# NOTE(review): AutoPipelineForText2Image appears to come from the `diffusers`
# package rather than `transformers` — verify the module-level import.
print("Loading the Stable Diffusion model...")
model = None
try:
    model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
except Exception as e:
    print(f"Error loading model: {e}")
else:
    print("Model loaded successfully.")

def generate_image(prompt, num_inference_steps=1, guidance_scale=0.0, image_format="JPEG"):
    """Render ``prompt`` with the module-level pipeline and return it base64-encoded.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Denoising steps passed to the pipeline (default 1).
        guidance_scale: Guidance scale passed to the pipeline (default 0.0).
        image_format: Pillow format name used to encode the image
            (default ``"JPEG"``).

    Returns:
        ``(img_str, None)`` on success, where ``img_str`` is the encoded image
        as base64 UTF-8 text; ``(None, message)`` on any failure. ``message``
        is never empty, so callers may test it for truthiness.
    """
    try:
        if model is None:
            raise ValueError("Model not loaded properly.")

        print(f"Generating image with prompt: {prompt}")
        output = model(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        print(f"Model output: {output}")

        if output is None:
            raise ValueError("Model returned None")

        images = getattr(output, "images", None)
        if not images:
            print("No images found in model output")
            raise ValueError("No images found in model output")

        print("Image generated")
        image = images[0]
        # JPEG has no alpha/palette support; Pillow refuses to save these
        # modes as JPEG, so flatten to RGB first.
        if image_format.upper() == "JPEG" and image.mode in ("RGBA", "LA", "P"):
            image = image.convert("RGB")
        with BytesIO() as buffered:
            image.save(buffered, format=image_format)
            image_bytes = buffered.getvalue()
        return base64.b64encode(image_bytes).decode("utf-8"), None
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        # Some exceptions carry no message; fall back to the type name so the
        # error is never an empty (falsy) string that callers would miss.
        return None, str(e) or type(e).__name__

def inference(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph of a story.

    Args:
        sentence_mapping: Mapping of paragraph identifier -> that paragraph's
            sentences (a list of strings, or a single string). The whole
            mapping is also forwarded to ``generate_prompt`` as context.
        character_dict: Character information forwarded to ``generate_prompt``
            (its structure is defined by that helper).
        selected_style: Art-style name forwarded to ``generate_prompt``.

    Returns:
        A dict keyed by paragraph identifier whose values are the base64 image
        strings produced by ``generate_image``, or ``"Error: <message>"`` for a
        paragraph whose image failed. If any input is None, or an unexpected
        exception escapes, a single ``{"error": <message>}`` dict is returned.
    """
    try:
        print(f"Received sentence_mapping: {sentence_mapping}")
        print(f"Received character_dict: {character_dict}")
        print(f"Received selected_style: {selected_style}")

        if sentence_mapping is None or character_dict is None or selected_style is None:
            return {"error": "One or more inputs are None"}

        images = {}
        for paragraph_number, sentences in sentence_mapping.items():
            # A bare string would otherwise be joined character by character
            # ("abc" -> "a b c"); only join genuine sequences of sentences.
            if isinstance(sentences, str):
                combined_sentence = sentences
            else:
                combined_sentence = " ".join(sentences)
            prompt = generate_prompt(
                combined_sentence, sentence_mapping, character_dict, selected_style
            )
            img_str, error = generate_image(prompt)
            # Compare against None, not truthiness: an empty error string still
            # means failure and must not be recorded as a (None) image.
            if error is not None:
                images[paragraph_number] = f"Error: {error}"
            else:
                images[paragraph_number] = img_str
        return images
    except Exception as e:
        # Top-level boundary of the Gradio handler: report instead of crashing.
        return {"error": str(e)}

# Web UI: two JSON editors plus a style picker feed inference(); the dict it
# returns (per-paragraph images or errors) is displayed as JSON.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(
            choices=["oil painting", "sketch", "watercolor"],
            label="Selected Style",
        ),
    ],
    outputs="json",
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()