import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image
from io import BytesIO
from generate_propmts import generate_prompt
from concurrent.futures import ThreadPoolExecutor
import json
import base64
# Load the pipeline once at import time so it is shared across requests
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
# Move the pipeline to the GPU when one is available; fall back to CPU otherwise
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
def generate_image(text, sentence_mapping, character_dict, selected_style):
    try:
        prompt, _ = generate_prompt(text, sentence_mapping, character_dict, selected_style)
        print(f"Generated prompt: {prompt}")
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output: {output}")
        image = output.images[0]
        # Serialize the PIL image into an in-memory PNG buffer
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        image_bytes = buffered.getvalue()
        # Base64-encode the bytes so they can be returned through the JSON output
        return base64.b64encode(image_bytes).decode("utf-8")
    except Exception as e:
        print(f"Error generating image: {e}")
        return None
def inference(sentence_mapping, character_dict, selected_style):
    images = {}
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
    # `sentence_mapping` is a dictionary where keys are paragraph numbers and values are lists of sentences
    grouped_sentences = sentence_mapping
    with ThreadPoolExecutor() as executor:
        futures = {}
        for paragraph_number, sentences in grouped_sentences.items():
            combined_sentence = " ".join(sentences)
            futures[paragraph_number] = executor.submit(
                generate_image, combined_sentence, sentence_mapping, character_dict, selected_style
            )
        # Collect the results once all paragraphs have been submitted
        for paragraph_number, future in futures.items():
            images[paragraph_number] = future.result()
    return images
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["Style 1", "Style 2", "Style 3"], label="Selected Style"),
    ],
    outputs="json",
)
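# Example payload (illustrative only; the exact schema is defined by generate_prompt in
# generate_propmts.py, and the character_dict shape below is an assumption):
#   sentence_mapping -> {"1": ["First sentence.", "Second sentence."], "2": ["Another paragraph."]}
#   character_dict   -> {"Alice": "a young explorer in a red coat"}
#   selected_style   -> "Style 1"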
if __name__ == "__main__":
    gradio_interface.launch()