Spaces:
Runtime error
Runtime error
File size: 2,144 Bytes
3486e1a f466dd9 a9b8939 9e09422 cd715cb b5ad13a a9b8939 6d1d03a c0cd59a f466dd9 c0cd59a a9b8939 8a0f059 3486e1a 8a0f059 3486e1a f466dd9 b5ad13a 8a0f059 f466dd9 c0cd59a 8a0f059 c0cd59a 3486e1a b5ad13a 8a0f059 3486e1a b5ad13a ca1d41c f466dd9 d05fa5e 3486e1a d05fa5e 3486e1a f466dd9 a9b8939 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import json
import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image
import base64
from io import BytesIO
from generate_propmts import generate_prompt
from concurrent.futures import ThreadPoolExecutor
# Hugging Face checkpoint backing every generation request.
_MODEL_ID = "stabilityai/sdxl-turbo"

# Build the text-to-image pipeline a single time at import so that every call
# to generate_image() reuses the same loaded weights instead of reloading them.
# NOTE(review): no device/dtype is selected here and ``torch`` is otherwise
# unused; consider moving the pipeline to CUDA (fp16) when available — confirm
# against the target hardware.
model = AutoPipelineForText2Image.from_pretrained(_MODEL_ID)
def generate_image(text, sentence_mapping, character_dict, selected_style):
    """Render one image for ``text`` and return it as a base64-encoded PNG.

    The prompt is produced by ``generate_prompt`` and fed to the module-level
    ``model`` pipeline with ``num_inference_steps=1`` and
    ``guidance_scale=0.0`` (the single-step settings used for SDXL-Turbo).

    Args:
        text: Text the image should illustrate.
        sentence_mapping: Forwarded unchanged to ``generate_prompt``.
        character_dict: Forwarded unchanged to ``generate_prompt``.
        selected_style: Forwarded unchanged to ``generate_prompt``.

    Returns:
        The first generated image encoded as PNG and then base64 (as a UTF-8
        ``str``), or ``None`` if prompt building or generation raised.
    """
    try:
        prompt, _ = generate_prompt(text, sentence_mapping, character_dict, selected_style)
        # Assumes the pipeline's default output_type ("pil"), so images[0] is a PIL image.
        image = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0).images[0]
        buffered = BytesIO()
        # Encode as a real PNG file. image.tobytes() would give raw, headerless
        # pixel data that no viewer or data: URI can decode back into an image.
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        # Best-effort per paragraph: report the failure and let the caller
        # record None for this entry instead of failing the whole request.
        print(f"Error generating image: {e}")
        return None
def _combine_sentences(sentences):
    """Join one paragraph's sentences into a single space-separated string."""
    if isinstance(sentences, str):
        # A bare string is already the paragraph text; " ".join() on it would
        # insert a space between every character.
        return sentences
    return " ".join(str(sentence) for sentence in sentences)


def inference(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph described by ``sentence_mapping``.

    Args:
        sentence_mapping: JSON text of an object mapping paragraph keys to a
            list of sentences (a single string per paragraph is also accepted).
            The parsed object is also forwarded to ``generate_image``.
        character_dict: Forwarded unchanged to ``generate_image``.
        selected_style: Forwarded unchanged to ``generate_image``.

    Returns:
        A dict mapping each paragraph key to whatever ``generate_image``
        returned for it (``None`` for paragraphs that failed), or a dict with a
        single ``"error"`` key when ``sentence_mapping`` is not a JSON object.
    """
    print(f'sentence_mapping:{sentence_mapping}, character_dict:{character_dict}, selected_style:{selected_style}')
    try:
        grouped_sentences = json.loads(sentence_mapping)
    except (json.JSONDecodeError, TypeError) as e:
        # TypeError covers a missing (None) value instead of a string.
        print(f"Error parsing JSON: {e}")
        return {"error": "Invalid JSON input for sentence_mapping"}
    if not isinstance(grouped_sentences, dict):
        # Valid JSON but not an object (e.g. a list): there are no paragraph keys.
        print(f"Error parsing JSON: expected an object, got {type(grouped_sentences).__name__}")
        return {"error": "sentence_mapping must be a JSON object"}

    # NOTE(review): character_dict arrives as raw Textbox text (labelled JSON)
    # and is not parsed here; confirm generate_prompt expects a string.

    # Generate sequentially: every call goes through the single shared
    # diffusers pipeline, which keeps per-call scheduler/pipeline state and is
    # not safe to drive from several threads at once.
    images = {}
    for paragraph_number, sentences in grouped_sentences.items():
        images[paragraph_number] = generate_image(
            _combine_sentences(sentences),
            grouped_sentences,
            character_dict,
            selected_style,
        )
    return images
# Input widgets, in the same order as inference()'s positional parameters.
_interface_inputs = [
    gr.Textbox(label="Sentence Mapping (JSON)"),
    gr.Textbox(label="Character Dict (JSON)"),
    gr.Dropdown(["Style 1", "Style 2", "Style 3"], label="Selected Style"),
]

# Web UI wired to inference(); its returned dict is shown in a plain text box.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=_interface_inputs,
    outputs="text",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    gradio_interface.launch()
|