# Spaces: Runtime error
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

import gradio as gr
from diffusers import AutoPipelineForText2Image
from PIL import Image
class Scheduler:
    """Minimal scheduler that tracks a per-thread step index.

    The step counter lives in a ``threading.local`` so that each thread
    calling :meth:`step_process` advances its own counter.

    Args:
        sigmas: Optional indexable sequence of values looked up by step
            index in :meth:`step_process`. Defaults to an empty list, in
            which case every step takes the ``IndexError`` path.
    """

    def __init__(self, sigmas=None):
        self._step = threading.local()
        # Stored as given (not copied) so any indexable sequence works.
        self.sigmas = [] if sigmas is None else sigmas

    def _init_step_index(self):
        """Start the calling thread's step counter at zero."""
        self._step.step = 0

    @property
    def step_index(self):
        """Return the calling thread's step index, initialising it to 0 if unset."""
        # Attributes of a threading.local set in one thread do not exist in
        # other threads, so read with a default instead of assuming the
        # attribute was created in __init__.
        if getattr(self._step, "step", None) is None:
            self._init_step_index()
        return self._step.step

    @step_index.setter
    def step_index(self, value):
        """Set the calling thread's step index to ``value``."""
        self._step.step = value

    def step_process(self, noise_pred, t, latents, **extra_step_kwargs):
        """Advance the calling thread's step index by one and return ``latents``.

        Args:
            noise_pred: Unused by the current placeholder implementation.
            t: Unused by the current placeholder implementation.
            latents: Returned unchanged.
            **extra_step_kwargs: Accepted and ignored.

        Returns:
            ``latents``, unchanged. If ``sigmas`` has no entry at
            ``step_index + 1`` the error is printed together with a
            traceback and the step index is left untouched.
        """
        try:
            # Looked up before advancing so a failed lookup does not move the
            # counter; the value is reserved for the real update step.
            sigma_to = self.sigmas[self.step_index + 1]
        except IndexError as e:
            print(f"Index error during step processing: {e}")
            traceback.print_exc()
            return latents
        self.step_index += 1
        # TODO: apply the real update here, e.g.
        # latents = process_latents(noise_pred, t, latents, sigma_to, **extra_step_kwargs)
        return latents
# Stand-in model class used for demonstration purposes
class MockModel:
    """Callable placeholder that drives a Scheduler and returns a dummy image."""

    def __init__(self):
        self.scheduler = Scheduler()

    def __call__(self, prompt, num_inference_steps, guidance_scale):
        """Run ``num_inference_steps`` scheduler steps and return a dummy result.

        ``prompt`` and ``guidance_scale`` are accepted but not used by the
        mock. The result is a dict whose ``"images"`` entry is a one-element
        list holding a new 512x512 RGB image.
        """
        latents = None
        for timestep in range(num_inference_steps):
            # No real noise prediction yet, so None is passed in its place.
            latents = self.scheduler.step_process(None, timestep, latents)

        placeholder = Image.new("RGB", (512, 512))  # dummy image for now
        return {"images": [placeholder]}
# Create the module-level model that generate_image() calls. This is the
# MockModel stand-in, not a real pipeline: AutoPipelineForText2Image is
# imported but not used in the visible code.
model = MockModel()
def generate_image(prompt):
    """Run the module-level ``model`` on ``prompt`` and return its first image.

    The model is called with 3 inference steps and a guidance scale of 0.0.
    Any exception, including the "no images" case raised here, is printed
    with a traceback and turned into a ``None`` return value.
    """
    try:
        output = model(prompt=prompt, num_inference_steps=3, guidance_scale=0.0)
        print(f"Model output: {output}")

        # Guard clause: reject anything that is not a non-empty list of images.
        candidates = output['images']
        if not isinstance(candidates, list) or len(candidates) == 0:
            raise Exception("No images returned by the model.")
        return candidates[0]
    except Exception as e:
        print(f"Error generating image: {e}")
        traceback.print_exc()
        return None
def inference(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph and return them in paragraph order.

    Args:
        sentence_mapping: Mapping of paragraph number to an iterable of
            sentence strings. ``None`` or an empty mapping yields ``[]``.
        character_dict: Passed through to ``generate_prompt`` unchanged.
        selected_style: Passed through to ``generate_prompt`` unchanged.

    Returns:
        A list of the images produced by ``generate_image``, ordered like
        the iteration order of ``sentence_mapping``. Prompts whose
        generation failed (``None`` result or raised exception) are skipped.

    NOTE(review): ``generate_prompt`` is not defined in the visible source;
    it must be provided elsewhere in the module.
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')

    images = []
    # Nothing to do when the input is missing or empty.
    if not sentence_mapping:
        return images

    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append(prompt)
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(generate_image, prompt) for prompt in prompts]
        # Collect in submission order rather than completion order so the
        # returned images line up with the paragraphs; the work itself
        # still runs concurrently.
        for future in futures:
            try:
                image = future.result()
            except Exception as e:
                print(f"Error processing prompt: {e}")
                traceback.print_exc()
                continue
            # generate_image signals failure with None.
            if image is not None:
                images.append(image)
    return images
# Gradio UI definition: calls inference() with two JSON inputs (sentence
# mapping and character dict) plus a style dropdown, and shows the returned
# images in a gallery.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs=gr.Gallery(label="Generated Images")
)

# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    gradio_interface.launch()