File size: 3,530 Bytes
e3fc553
a9b8939
e0ec116
c1282a1
 
b5ad13a
e3fc553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d1d03a
e0ec116
f466dd9
e3fc553
1d0b035
e0ec116
6449f8f
e3fc553
 
c301a62
6449f8f
c301a62
f466dd9
b5ad13a
c1282a1
8a0f059
f466dd9
c0cd59a
6035350
1d0b035
e0ec116
 
 
 
6449f8f
6035350
e0ec116
b5ad13a
 
6035350
e0ec116
6035350
e0ec116
 
 
6035350
e0ec116
6035350
c1282a1
b5ad13a
6035350
f466dd9
 
 
d05fa5e
1adc78a
 
e0ec116
d05fa5e
6035350
f466dd9
 
 
a9b8939
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import threading
from diffusers import AutoPipelineForText2Image
from concurrent.futures import ThreadPoolExecutor, as_completed
from PIL import Image
import traceback

class Scheduler:
    """Minimal diffusion-scheduler stand-in whose per-run state is thread-local.

    The step counter and the sigma schedule are stored on a ``threading.local``
    so one shared instance can be driven by several threads at once (e.g. pool
    worker threads) without the threads overwriting each other's progress.
    """

    def __init__(self):
        self._step = threading.local()
        # This only initialises state for the constructing thread; other threads
        # start with an empty namespace, which the accessors below tolerate.
        self._step.step = None

    def _init_step_index(self):
        self._step.step = 0

    def set_timesteps(self, num_inference_steps):
        """Start a new run in the calling thread.

        Builds a linearly decreasing sigma schedule from 1.0 to 0.0 with
        ``num_inference_steps + 1`` entries (so ``sigmas[step_index + 1]`` is
        valid on the last step) and resets this thread's step index. A
        non-positive step count yields the single-entry schedule ``[1.0]``.
        """
        steps = max(int(num_inference_steps), 0)
        if steps:
            self._step.sigmas = [1.0 - i / steps for i in range(steps + 1)]
        else:
            self._step.sigmas = [1.0]
        self._step.step = None

    @property
    def sigmas(self):
        """Sigma schedule of the calling thread (empty until one is set)."""
        return getattr(self._step, "sigmas", [])

    @sigmas.setter
    def sigmas(self, value):
        self._step.sigmas = value

    @property
    def step_index(self):
        # getattr: threads other than the constructor have no ``step`` attribute
        # yet, so a plain attribute read would raise AttributeError there.
        if getattr(self._step, "step", None) is None:
            self._init_step_index()
        return self._step.step

    @step_index.setter
    def step_index(self, value):
        self._step.step = value

    def step_process(self, noise_pred, t, latents, **extra_step_kwargs):
        """Advance the calling thread's schedule by one step.

        Returns *latents* unchanged (the real update is still a placeholder).
        If the schedule is exhausted or was never set, the IndexError is printed
        with a traceback and the step index is left unchanged.
        """
        try:
            sigma_to = self.sigmas[self.step_index + 1]  # noqa: F841 - used by the TODO below
        except IndexError as e:
            print(f"Index error during step processing: {e}")
            traceback.print_exc()
            return latents
        self.step_index += 1
        # TODO: real update, e.g.
        # latents = process_latents(noise_pred, t, latents, sigma_to, **extra_step_kwargs)
        return latents

# Mocking a model class for demonstration purposes
class MockModel:
    """Stand-in for a text-to-image pipeline: runs the scheduler loop and returns a blank image."""

    def __init__(self):
        self.scheduler = Scheduler()

    def __call__(self, prompt, num_inference_steps, guidance_scale):
        # prompt / guidance_scale are accepted for call compatibility but unused by the mock.
        # Prepare a fresh per-thread schedule and step index. Without this step the
        # scheduler had no sigmas at all, and the index carried over between calls.
        self.scheduler.set_timesteps(num_inference_steps)
        latents = None
        for t in range(num_inference_steps):
            noise_pred = None  # Replace with actual noise prediction
            latents = self.scheduler.step_process(noise_pred, t, latents)
        return {"images": [Image.new("RGB", (512, 512))]}  # Return a dummy image for now

# Module-level model instance shared by every generate_image call, including
# calls made concurrently from inference's worker threads. This is the
# MockModel stand-in, not a real pipeline.
# TODO(review): presumably meant to become AutoPipelineForText2Image (imported
# above but unused) -- confirm before relying on real outputs.
model = MockModel()

def generate_image(prompt):
    """Run the shared model on *prompt* and return its first image.

    Best-effort: any failure, including a result with no images, is printed
    together with a traceback and ``None`` is returned instead of raising.
    """
    try:
        result = model(prompt=prompt, num_inference_steps=3, guidance_scale=0.0)
        print(f"Model output: {result}")

        produced = result['images']
        has_images = isinstance(produced, list) and len(produced) > 0
        if not has_images:
            raise Exception("No images returned by the model.")
        return produced[0]
    except Exception as err:
        print(f"Error generating image: {err}")
        traceback.print_exc()
        return None

def inference(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph of *sentence_mapping*.

    Args:
        sentence_mapping: mapping of paragraph number -> iterable of sentence
            strings (presumably the payload of the "Sentence Mapping" JSON
            input -- confirm). ``None`` or empty yields no images.
        character_dict: forwarded unchanged to ``generate_prompt``.
        selected_style: forwarded unchanged to ``generate_prompt``.

    Returns:
        List of generated images in the mapping's iteration (paragraph) order.
        Paragraphs whose generation failed or returned ``None`` are omitted.
    """
    images = []
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
    # A cleared JSON input can arrive as None; there is nothing to generate.
    if not sentence_mapping:
        return images

    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append(prompt)
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(generate_image, prompt) for prompt in prompts]

        # Collect in submission order (not as_completed) so the gallery matches
        # paragraph order instead of whichever generation finished first.
        for future in futures:
            try:
                image = future.result()
            except Exception as e:
                print(f"Error processing prompt: {e}")
                traceback.print_exc()
                continue
            # generate_image signals failure with None; don't rely on image truthiness.
            if image is not None:
                images.append(image)

    return images

# Gradio UI: two JSON inputs and a style dropdown are passed positionally to
# ``inference``; its returned list of images is shown in a gallery.
# NOTE(review): ``gr`` is not imported anywhere in the visible source --
# presumably ``import gradio as gr`` is missing; confirm, otherwise this line
# raises NameError when the module is imported.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs=gr.Gallery(label="Generated Images")
)

# Launch the interface only when run as a script, not when imported.
if __name__ == "__main__":
    gradio_interface.launch()