# text2image_1 / app.py
# Last commit ca1d41c (verified), "Update app.py", by RanM; file size 1.54 kB.
import base64
import concurrent.futures
from io import BytesIO

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image

# Module file is generate_propmts.py (sic); import it by module name, not filename.
from generate_propmts import generate_prompt
# Build the SDXL-Turbo text-to-image pipeline a single time, at import, so that
# every generate_image call reuses this one instance instead of reloading it.
model = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo"
)
def generate_image(text, sentence_mapping, character_dict, selected_style):
    """Render one image for ``text`` and return it as PNG-encoded bytes.

    ``text`` is turned into a diffusion prompt by ``generate_prompt``. The
    prompt is then rendered with the module-level ``model`` pipeline, using
    one inference step and ``guidance_scale=0.0``.

    Args:
        text: Source text to illustrate.
        sentence_mapping: Passed through unchanged to ``generate_prompt``.
        character_dict: Passed through unchanged to ``generate_prompt``.
        selected_style: Passed through unchanged to ``generate_prompt``.

    Returns:
        The image serialized as PNG ``bytes``, or ``None`` if prompt
        generation or rendering raised an exception. The error is printed.
    """
    try:
        # generate_prompt returns a pair; only the first element (the prompt)
        # is used here.
        prompt, _ = generate_prompt(text, sentence_mapping, character_dict, selected_style)
        image = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0).images[0]
        # The image must be written into the buffer before reading it back;
        # otherwise getvalue() returns empty bytes.
        # NOTE(review): assumes the pipeline's default PIL output (has .save).
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return buffered.getvalue()
    except Exception as e:
        # Best-effort per image: one failed render should not abort the whole
        # batch. Make the failure visible instead of silently returning None.
        print(f"Error generating image: {e}")
        return None
def inference(prompt, grouped_sentences=None, sentence_mapping=None,
              general_descriptions=None, selected_style=None):
    """Generate one image per paragraph group, concurrently.

    Args:
        prompt: Text from the Gradio input box. It is used as the only
            paragraph when ``grouped_sentences`` is not supplied.
        grouped_sentences: Optional mapping of paragraph number -> list of
            sentence strings. Each group's sentences are joined with single
            spaces and rendered as one image.
        sentence_mapping: Passed through to ``generate_image``.
        general_descriptions: Passed through to ``generate_image`` as its
            ``character_dict`` argument.
        selected_style: Passed through to ``generate_image``.
            NOTE(review): these three were previously referenced as undefined
            globals. Whether ``generate_prompt`` accepts ``None`` for them is
            not visible here; TODO confirm.

    Returns:
        dict mapping each paragraph number to the base64 ``str`` encoding of
        the bytes returned by ``generate_image``, or ``None`` where generation
        produced no image.
    """
    if grouped_sentences is None:
        # Single-prompt case (the Gradio UI): render the input as one paragraph.
        # NOTE(review): key 1 assumes 1-based paragraph numbering; confirm.
        grouped_sentences = {1: [prompt]}

    # Debug output: show what is about to be rendered.
    print(f"Received grouped_sentences: {grouped_sentences}")

    # NOTE(review): all workers call the same module-level ``model`` pipeline.
    # Whether concurrent calls on it are safe is not established here; confirm.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {
            paragraph_number: executor.submit(
                generate_image,
                " ".join(sentences),
                sentence_mapping,
                general_descriptions,
                selected_style,
            )
            for paragraph_number, sentences in grouped_sentences.items()
        }

    # Leaving the with-block waited for every task. Collect the actual results
    # (not the Future objects) and base64-encode them for a text output.
    images = {}
    for paragraph_number, future in futures.items():
        image_bytes = future.result()
        images[paragraph_number] = (
            None if image_bytes is None
            else base64.b64encode(image_bytes).decode("utf-8")
        )
    return images
# Gradio UI: a single text box in and a single text box out. Whatever
# inference returns is rendered as a string, which is intended to carry
# base64-encoded image data.
gradio_interface = gr.Interface(
    inputs="text",
    outputs="text",
    fn=inference,
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    gradio_interface.launch()