import os
from io import BytesIO
from PIL import Image
from diffusers import AutoPipelineForText2Image
import gradio as gr
import base64

# Load the model once at import time so every request reuses the same pipeline.
print("Loading the Stable Diffusion model...")
try:
    model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
    print("Model loaded successfully.")
except Exception as e:
    # Best-effort: keep the app importable; generate_image() reports the
    # missing model on every request instead of crashing at startup.
    print(f"Error loading model: {e}")
    model = None


def _encode_image_base64(image):
    """Serialize a PIL image as JPEG and return it as a base64 UTF-8 string."""
    # JPEG cannot store alpha or palette modes (e.g. RGBA, P); saving those
    # raises, so normalize to RGB first. "RGB" and "L" are saved as-is.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def generate_image(prompt):
    """Generate a single image for *prompt* and return it base64-encoded.

    Args:
        prompt: text prompt passed to the loaded text-to-image pipeline.

    Returns:
        A ``(img_str, error)`` tuple: ``(base64_jpeg_string, None)`` on success,
        or ``(None, error_message)`` on any failure. Errors are returned rather
        than raised so one failing paragraph does not abort ``inference``.
    """
    try:
        if model is None:
            raise ValueError("Model not loaded properly.")
        print(f"Generating image with prompt: {prompt}")
        # Single sampling step with guidance disabled (guidance_scale=0.0),
        # the settings used for sdxl-turbo.
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output: {output}")
        if output is None:
            raise ValueError("Model returned None")
        images = getattr(output, "images", None)
        if not images:
            print("No images found in model output")
            raise ValueError("No images found in model output")
        print("Image generated")
        img_str = _encode_image_base64(images[0])
        # Log only the size: printing the full base64 payload floods the console.
        print(f"Encoded image: {len(img_str)} base64 characters")
        return img_str, None
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        return None, str(e)


def inference(sentence_mapping, character_dict, selected_style):
    """Generate one illustration per paragraph of *sentence_mapping*.

    Args:
        sentence_mapping: JSON object mapping a paragraph key to its sentences
            (a list of strings; a single string is also accepted as-is).
        character_dict: required (must not be None) but not otherwise used yet.
            TODO(review): presumably intended to enrich prompts; confirm.
        selected_style: style name interpolated into every prompt.

    Returns:
        A dict mapping each paragraph key to a base64 JPEG string, or to
        ``"Error: <message>"`` when that paragraph failed. Returns
        ``{"error": <message>}`` when the inputs are invalid or an unexpected
        exception occurs.
    """
    try:
        print(f"Received sentence_mapping: {sentence_mapping}")
        print(f"Received character_dict: {character_dict}")
        print(f"Received selected_style: {selected_style}")
        if sentence_mapping is None or character_dict is None or selected_style is None:
            return {"error": "One or more inputs are None"}
        if not isinstance(sentence_mapping, dict):
            return {
                "error": "sentence_mapping must be a JSON object mapping paragraphs to sentences"
            }

        images = {}
        for paragraph_number, sentences in sentence_mapping.items():
            # " ".join over a bare string would insert spaces between its
            # characters, so a string value is used unchanged.
            if isinstance(sentences, str):
                combined_sentence = sentences
            else:
                # str() keeps non-string JSON values (numbers etc.) from
                # raising TypeError and aborting the whole batch.
                combined_sentence = " ".join(str(s) for s in sentences)
            prompt = f"Make an illustration in {selected_style} style from: {combined_sentence}"
            img_str, error = generate_image(prompt)
            images[paragraph_number] = f"Error: {error}" if error else img_str
        return images
    except Exception as e:
        return {"error": str(e)}


gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ],
    outputs="json",
)

if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()