# app.py
from datasets import load_dataset
import gradio as gr
from gradio_client import Client
import json, random, torch
from diffusers import FluxPipeline, AutoencoderKL
from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
import spaces

# ─────────────────────────────── 1. Device ────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─────────────────────── 2. Image / FLUX pipeline ─────────────────────────
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
).to(device)
# Bind the streaming helper to the pipeline instance so it can be called
# like a method.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = (
    flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
)

# ───────────────────────── 3. LLM (Zephyr-chat) ───────────────────────────
llm_client = Client("HuggingFaceH4/zephyr-chat")  # public Space
# view_api(return_format="dict") describes the Space's endpoints; the keys of
# "named_endpoints" are the api names, so take the first one, e.g. "/chat".
CHAT_API = list(llm_client.view_api(return_format="dict")["named_endpoints"])[0]


def call_llm(
    user_prompt: str,
    system_prompt: str = "You are Zephyr, a helpful and creative assistant.",
    history: list | None = None,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_tokens: int = 1024,
) -> str:
    """
    Robust wrapper around the Zephyr chat Space.
    Falls back to '...' on any error so the Gradio UI never crashes.
    """
    history = history or []
    try:
        # Zephyr-chat expects: prompt, system_prompt, history, temperature,
        # top_p, max_new_tokens
        result = llm_client.predict(
            user_prompt,
            system_prompt,
            history,
            temperature,
            top_p,
            max_tokens,
            api_name=CHAT_API,
        )
        # Some Spaces return a plain string, others the old tuple format.
        return result.strip() if isinstance(result, str) else result[1][0][-1].strip()
    except Exception as e:
        print(f"[LLM error] {e}")
        return "..."


# ───────────────────────── 4. Persona dataset ────────────────────────────
ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")


def get_random_persona_description() -> str:
    idx = random.randint(0, len(ds) - 1)
    return ds[idx]["persona"]


# ─────────────────────────── 5. Prompts ─────────────────────────────────
prompt_template = """Generate a character with this persona description:
{persona_description}

In a world with this description:
{world_description}

Write the character in JSON format with these keys: name, background,
appearance, personality, skills_and_abilities, goals, conflicts, backstory,
current_situation, spoken_lines (list of strings).

Respond with **only** the JSON (no markdown, no fencing)."""

world_description_prompt = (
    "Invent a short, unique and vivid world description. "
    "Respond with the description only."
)
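# Illustrative smoke test (assumption: the Zephyr Space is reachable and
# CHAT_API resolved to a valid endpoint). Kept commented out so importing
# this module stays side-effect free:
#
#   print(call_llm(world_description_prompt))  # -> one short world blurb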
# ─────────────────────── 6. Gradio helper funcs ─────────────────────────
def get_random_world_description() -> str:
    return call_llm(world_description_prompt)


@spaces.GPU(duration=75)
def infer_flux(character_json):
    """Stream intermediate images while FLUX denoises."""
    for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=character_json["appearance"],
        guidance_scale=3.5,
        num_inference_steps=28,
        width=1024,
        height=1024,
        generator=torch.Generator("cpu").manual_seed(0),  # fixed seed for reproducibility
        output_type="pil",
        good_vae=good_vae,
    ):
        yield image


def generate_character(
    world_description: str,
    persona_description: str,
    progress=gr.Progress(track_tqdm=True),
):
    raw = call_llm(
        prompt_template.format(
            persona_description=persona_description,
            world_description=world_description,
        ),
        max_tokens=1024,
    )
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # One retry if the LLM returned malformed JSON; a second failure
        # propagates and surfaces as a Gradio error.
        raw = call_llm(
            prompt_template.format(
                persona_description=persona_description,
                world_description=world_description,
            ),
            max_tokens=1024,
        )
        return json.loads(raw)
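# Illustrative shape of a successful generate_character() result (keys come
# from prompt_template above; the values here are invented examples):
# {
#   "name": "Yara of the Salt Cliffs",
#   "appearance": "Wind-burned skin, a cloak of gull feathers, ...",
#   ...
#   "spoken_lines": ["The tide keeps what it takes."]
# }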
# ───────────────────────────── 7. UI ─────────────────────────────────────
app_description = """
- Generates a character profile (JSON) from a world + persona description.
- **Appearance** images come from [FLUX-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
- **Back-stories** come from [Zephyr-7B-β](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta).
- Personas are sampled from [FinePersonas-Lite](https://huggingface.co/datasets/MohamedRashad/FinePersonas-Lite).

Tip → Write or randomise a world, then spin the persona box to see how the
same world shapes different heroes.
"""

with gr.Blocks(title="Character Generator", theme="Nymbo/Nymbo_Theme") as demo:
    gr.Markdown("<center><h1>🧙‍♀️ Character Generator</h1></center>")
    gr.Markdown(app_description.strip())

    with gr.Row():
        world_description = gr.Textbox(label="World Description", lines=10, scale=4)
        persona_description = gr.Textbox(
            label="Persona Description",
            value=get_random_persona_description(),
            lines=10,
            scale=1,
        )

    with gr.Row():
        random_world_btn = gr.Button("🔄 Random World", variant="secondary")
        submit_btn = gr.Button("✨ Generate Character", variant="primary", scale=5)
        random_persona_btn = gr.Button("🔄 Random Persona", variant="secondary")

    with gr.Row():
        character_image = gr.Image(label="Character Image")
        character_json = gr.JSON(label="Character Description")

    # Hooks
    submit_btn.click(
        generate_character,
        inputs=[world_description, persona_description],
        outputs=[character_json],
    ).then(
        infer_flux,
        inputs=[character_json],
        outputs=[character_image],
    )
    random_world_btn.click(
        get_random_world_description,
        outputs=[world_description],
    )
    random_persona_btn.click(
        get_random_persona_description,
        outputs=[persona_description],
    )

demo.queue().launch(share=False)