Spaces:
Runtime error
Runtime error
File size: 2,768 Bytes
f466dd9 a9b8939 9e09422 cd715cb b5ad13a 1adc78a b5ad13a a9b8939 6d1d03a c301a62 8a555c8 c301a62 c0cd59a f466dd9 c0cd59a a97be70 c301a62 8a555c8 1d0b035 8a555c8 c301a62 f466dd9 b5ad13a 8a0f059 f466dd9 c0cd59a 8a0f059 1d0b035 1adc78a b5ad13a 8a0f059 1adc78a b5ad13a ca1d41c f466dd9 d05fa5e 1adc78a d05fa5e c301a62 f466dd9 a9b8939 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image
from io import BytesIO
from generate_propmts import generate_prompt
from concurrent.futures import ThreadPoolExecutor
import json
# Load the model once at import time so every request reuses the same
# pipeline instance (reloading weights per call would be prohibitively slow).
# NOTE(review): no torch_dtype or .to("cuda") here — presumably this runs on
# CPU or the Space's default device; confirm against deployment config.
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
def truncate_prompt(prompt, max_length=77):
    """Truncate *prompt* to at most *max_length* whitespace-separated words.

    Fixes over the previous version: the debug prints used to sit after the
    early return (so they never ran on the truncated path) and one of them
    printed the character count under a "tokens" label; both removed.

    NOTE(review): the model's 77 limit is in *tokenizer tokens*, not words,
    so splitting on whitespace is only an approximation — a 77-word prompt
    can still exceed the CLIP token limit. TODO: confirm by tokenizing with
    the pipeline's own tokenizer instead.

    Args:
        prompt: The prompt string to bound.
        max_length: Maximum number of words to keep (default 77).

    Returns:
        The prompt unchanged if within the limit, otherwise the first
        *max_length* words rejoined with single spaces.
    """
    words = prompt.split()
    if len(words) > max_length:
        return ' '.join(words[:max_length])
    return prompt
def generate_image(text, sentence_mapping, character_dict, selected_style):
    """Generate one image for *text* and return it as JPEG-encoded bytes.

    Builds a prompt via generate_prompt, truncates it to the word limit,
    and runs the sdxl-turbo pipeline.

    Args:
        text: Combined sentence text for one paragraph.
        sentence_mapping: Full paragraph->sentences mapping (forwarded to
            generate_prompt for context).
        character_dict: Character descriptions (forwarded to generate_prompt).
        selected_style: Style name chosen in the UI.

    Returns:
        bytes: JPEG image data, or None on any failure. Errors are printed
        rather than propagated so that one bad paragraph does not abort the
        whole batch submitted by inference().
    """
    try:
        prompt, _ = generate_prompt(text, sentence_mapping, character_dict, selected_style)
        print(f"Generated prompt: {prompt}")
        prompt = truncate_prompt(prompt)
        print(f"truncate_prompt: {prompt}")
        # num_inference_steps=1 / guidance_scale=0.0 are the recommended
        # settings for sdxl-turbo (a distilled, guidance-free model).
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output: {output}")
        # Guard clause: a specific exception type instead of bare Exception;
        # it is still caught by the handler below, so callers see no change.
        if not output.images:
            raise RuntimeError("No images returned by the model.")
        image = output.images[0]
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return buffered.getvalue()
    except Exception as e:
        # Deliberate best-effort boundary: log and signal failure with None.
        print(f"Error generating image: {e}")
        return None
def inference(sentence_mapping, character_dict, selected_style):
    """Fan out image generation across paragraphs on a thread pool.

    Each paragraph's sentences are joined into one string and handed to
    generate_image; results are gathered into a dict keyed by paragraph
    number (a failed paragraph maps to None).

    Args:
        sentence_mapping: dict of paragraph number -> list of sentences.
        character_dict: character descriptions, forwarded per paragraph.
        selected_style: style name from the UI dropdown.

    Returns:
        dict: paragraph number -> JPEG bytes (or None on failure).
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
    with ThreadPoolExecutor() as pool:
        pending = {
            para_num: pool.submit(
                generate_image,
                " ".join(sentences),
                sentence_mapping,
                character_dict,
                selected_style,
            )
            for para_num, sentences in sentence_mapping.items()
        }
        # Blocking on result() inside the with-block; the executor's exit
        # then joins already-finished workers.
        return {para_num: task.result() for para_num, task in pending.items()}
# Gradio app wiring: two JSON payloads plus a style dropdown feed
# inference(); the output is rendered as JSON (paragraph number -> image
# bytes, per inference's return value).
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["Style 1", "Style 2", "Style 3"], label="Selected Style")
    ],
    outputs="json"
)
# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    gradio_interface.launch()
|