Spaces:
Runtime error
Runtime error
File size: 4,200 Bytes
c513221 109adde 441106f 9da79fd 5e2c7ed 5d9bf5a b85438c 5d9bf5a 86743ba 5d9bf5a 441106f 5d9bf5a 441106f 5d9bf5a 441106f 5d9bf5a 441106f 5d9bf5a 441106f 5d9bf5a 441106f 5d9bf5a 441106f 5d9bf5a 3b7350e 834f7ba 109adde 690f094 109adde 690f094 c7f120b d253f4a 690f094 bdf16c0 834f7ba 5d9bf5a 30c04e8 441106f 5d9bf5a c7f120b cfeca25 c7f120b 690f094 081cd9c 5e2c7ed a04441d 109adde 834f7ba 109adde 081cd9c 690f094 bdf16c0 690f094 bdf16c0 f466dd9 c7f120b c14304d c7f120b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import os
import asyncio
import time
from generate_prompts import generate_prompt
from diffusers import AutoPipelineForText2Image
from io import BytesIO
import gradio as gr
import ray
# Initialise the Ray runtime once at import time so the @ray.remote actor
# class defined below can be instantiated and scheduled.
# NOTE(review): calling ray.init() a second time in the same process raises
# unless ignore_reinit_error=True is passed — confirm this module is only
# imported once (e.g. not re-imported by a reload/dev server).
ray.init()
@ray.remote
class ModelActor:
    """Ray actor that owns one SDXL-Turbo text-to-image pipeline.

    Each actor loads its own copy of the pipeline in ``__init__``, so every
    instance pays the full model-load time and memory cost.
    """

    def __init__(self):
        print("Loading the model...")
        self.model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
        print("Model loaded successfully.")

    def generate_image(self, prompt, prompt_name):
        """Render ``prompt`` and return the first image as JPEG bytes.

        Args:
            prompt: Text prompt passed to the pipeline.
            prompt_name: Label used only in log messages.

        Returns:
            JPEG-encoded bytes of the first generated image, or ``None`` if
            the pipeline call fails, returns no images, or JPEG encoding
            fails. Errors are logged instead of raised, so one bad prompt
            does not fail the caller's whole batch.
        """
        start_time = time.perf_counter()  # monotonic clock for elapsed time
        process_id = os.getpid()

        # Step 1: run the pipeline. Everything that touches the model output
        # stays inside this try so a malformed result is also reported here.
        try:
            print(f"[{process_id}] Generating response for {prompt_name} with prompt: {prompt}")
            # Single-step sampling without classifier-free guidance, the
            # settings the SDXL-Turbo model card documents for this model.
            output = self.model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
            print(f"[{process_id}] Output for {prompt_name}: {output}")
            images = output.images
        except Exception as e:  # actor boundary: report and return None
            print(f"[{process_id}] Error generating image for {prompt_name}: {e}")
            return None

        # Step 2: make sure there is something to encode.
        if not isinstance(images, list) or not images:
            print(f"[{process_id}] Error generating image for {prompt_name}: "
                  f"No images returned by the model.")
            return None

        # Step 3: encode the first image as JPEG.
        try:
            with BytesIO() as buffered:
                images[0].save(buffered, format="JPEG")
                image_bytes = buffered.getvalue()
        except Exception as e:  # actor boundary: report and return None
            print(f"[{process_id}] Error saving image for {prompt_name}: {e}")
            return None

        elapsed = time.perf_counter() - start_time
        print(f"[{process_id}] Image bytes length for {prompt_name}: {len(image_bytes)}")
        print(f"[{process_id}] Time taken for {prompt_name}: {elapsed} seconds")
        return image_bytes
async def queue_api_calls(sentence_mapping, character_dict, selected_style, *, max_actors=20):
    """Generate one image per paragraph by fanning prompts out to Ray actors.

    Args:
        sentence_mapping: Mapping of paragraph number to an iterable of
            sentence strings. Each paragraph's sentences are joined with a
            single space before being turned into a prompt.
        character_dict: Passed through unchanged to ``generate_prompt``.
        selected_style: Passed through unchanged to ``generate_prompt``.
        max_actors: Upper bound on how many ``ModelActor`` instances this
            call starts (default 20). Prompts are assigned round-robin.

    Returns:
        Dict mapping each paragraph number to the value returned by
        ``ModelActor.generate_image`` for it: JPEG bytes, or ``None`` where
        generation failed. Empty when ``sentence_mapping`` is empty.

    Raises:
        ValueError: If ``max_actors`` is less than 1.
    """
    if max_actors < 1:
        raise ValueError(f"max_actors must be >= 1, got {max_actors}")

    print(f"queue_api_calls invoked with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        print(f"combined_sentence for paragraph {paragraph_number}: {combined_sentence}")
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    if not prompts:
        # Nothing to render, so do not start any actors.
        return {}

    # NOTE(review): every call starts fresh actors, and each one loads its
    # own copy of the model in ModelActor.__init__. Large batches multiply
    # memory use; a long-lived shared actor pool may be preferable.
    num_actors = min(len(prompts), max_actors)
    model_actors = [ModelActor.remote() for _ in range(num_actors)]
    try:
        tasks = [
            model_actors[i % num_actors].generate_image.remote(prompt, f"Prompt {paragraph_number}")
            for i, (paragraph_number, prompt) in enumerate(prompts)
        ]
        print("Tasks created for image generation.")
        # A single blocking ray.get on the whole list, run in one worker
        # thread, waits for every task without stalling the event loop and
        # returns the results in task order.
        responses = await asyncio.to_thread(ray.get, tasks)
        print("Responses received from image generation tasks.")
    finally:
        # Release the actors (and their model copies) deterministically,
        # including on error or cancellation, instead of waiting for the
        # handles to be garbage-collected.
        for actor in model_actors:
            ray.kill(actor)

    images = {paragraph_number: response for (paragraph_number, _), response in zip(prompts, responses)}
    print(f"Images generated: {images}")
    return images
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Synchronous entry point (used by Gradio) around ``queue_api_calls``.

    Args:
        sentence_mapping: Forwarded to ``queue_api_calls``.
        character_dict: Forwarded to ``queue_api_calls``.
        selected_style: Forwarded to ``queue_api_calls``.

    Returns:
        Whatever ``queue_api_calls`` returns: a dict of paragraph number to
        image bytes (or ``None``).
    """
    print(f"process_prompt called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")

    def _run():
        # asyncio.run creates a fresh loop, runs the coroutine, then closes
        # the loop and shuts down its default executor (used by
        # asyncio.to_thread), so nothing leaks between calls.
        return asyncio.run(queue_api_calls(sentence_mapping, character_dict, selected_style))

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Usual case: no loop is running in this thread.
        print("Event loop created.")
        cmpt_return = _run()
    else:
        # A loop is already running in this thread, and it cannot be blocked
        # on re-entrantly. Run on a helper thread with its own loop instead.
        # This still blocks the calling thread until the work finishes.
        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=1) as pool:
            cmpt_return = pool.submit(_run).result()

    print(f"process_prompt completed with return value: {cmpt_return}")
    return cmpt_return
# Web UI: two JSON inputs (sentence mapping and character dict) plus a style
# picker, all forwarded positionally to process_prompt. The result is
# rendered with a JSON output component.
# NOTE(review): process_prompt's dict values are raw image bytes (or None);
# confirm the JSON output serialises them as intended (e.g. consider
# base64-encoding them, or using an image/gallery output instead).
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(
            choices=["oil painting", "sketch", "watercolor"],
            label="Selected Style",
        ),
    ],
    outputs="json",
)
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # By default launch() blocks the main thread while the server runs, so
    # the statement after it executes only once the server has stopped.
    gradio_interface.launch()
    print("Gradio interface stopped.")
|