# Running on Zero
```python
# app.py
import spaces  # import first so ZeroGPU can take over CUDA initialisation

from datasets import load_dataset
import gradio as gr
from gradio_client import Client
import json, os, random, torch
from diffusers import FluxPipeline, AutoencoderKL
from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images

# ─────────────────────────────── 1. Device ────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─────────────────────── 2. Image / FLUX pipeline ─────────────────────────
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to(device)
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
).to(device)
pipe.flux_pipe_call_that_returns_an_iterable_of_images = (
    flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
)
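# `__get__` binds the standalone helper to this pipeline instance, so the
# helper sees `pipe` as `self` (the usual attach-a-method-at-runtime trick).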

# ───────────────────────── 3. LLM (Zephyr-chat) ───────────────────────────
llm_client = Client("HuggingFaceH4/zephyr-chat")  # public Space
# view_api() only prints a summary by default; request a dict to read the
# endpoint name programmatically.
CHAT_API = next(iter(
    llm_client.view_api(print_info=False, return_format="dict")["named_endpoints"]
))  # e.g. "/chat"

def call_llm(
    user_prompt: str,
    system_prompt: str = "You are Zephyr, a helpful and creative assistant.",
    history: list | None = None,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_tokens: int = 1024,
) -> str:
    """
    Robust wrapper around the Zephyr chat Space.
    Falls back to '...' on any error so the Gradio UI never crashes.
    """
    history = history or []
    try:
        # Zephyr-chat expects: prompt, system_prompt, history, temperature,
        # top_p, max_new_tokens
        result = llm_client.predict(
            user_prompt,
            system_prompt,
            history,
            temperature,
            top_p,
            max_tokens,
            api_name=CHAT_API,
        )
        # Some Spaces return a plain string, others the old (text, history) tuple.
        return result.strip() if isinstance(result, str) else result[1][0][-1].strip()
    except Exception as e:
        print(f"[LLM error] {e}")
        return "..."

# ───────────────────────── 4. Persona dataset ─────────────────────────────
ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")

def get_random_persona_description() -> str:
    idx = random.randint(0, len(ds) - 1)
    return ds[idx]["persona"]
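# Each FinePersonas-Lite row exposes a free-text "persona" field; the sampler
# above just picks one uniformly at random.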

# ─────────────────────────── 5. Prompts ───────────────────────────────────
prompt_template = """Generate a character with this persona description:
{persona_description}
In a world with this description:
{world_description}
Write the character in JSON format with these keys:
name, background, appearance, personality, skills_and_abilities, goals,
conflicts, backstory, current_situation, spoken_lines (list of strings).
Respond with **only** the JSON (no markdown, no fencing)."""

world_description_prompt = (
    "Invent a short, unique and vivid world description. "
    "Respond with the description only."
)

# ─────────────────────── 6. Gradio helper funcs ───────────────────────────
def get_random_world_description() -> str:
    return call_llm(world_description_prompt)

@spaces.GPU  # ZeroGPU: attach a GPU only while this function runs
def infer_flux(character_json):
    """Stream intermediate images while FLUX denoises."""
    for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=character_json["appearance"],
        guidance_scale=3.5,
        num_inference_steps=28,
        width=1024,
        height=1024,
        generator=torch.Generator("cpu").manual_seed(0),  # fixed seed -> reproducible image
        output_type="pil",
        good_vae=good_vae,
    ):
        yield image

def generate_character(world_description: str,
                       persona_description: str,
                       progress=gr.Progress(track_tqdm=True)):
    raw = call_llm(
        prompt_template.format(
            persona_description=persona_description,
            world_description=world_description,
        ),
        max_tokens=1024,
    )
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # One retry if the LLM returned malformed JSON
        raw = call_llm(
            prompt_template.format(
                persona_description=persona_description,
                world_description=world_description,
            ),
            max_tokens=1024,
        )
        return json.loads(raw)
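# NOTE: if the retry is also malformed, json.loads raises and the click
# handler errors out; see the hardening sketch after this listing.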

# ───────────────────────────── 7. UI ──────────────────────────────────────
app_description = """
- Generates a character profile (JSON) from a world + persona description.
- **Appearance** images come from [FLUX-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
- **Back-stories** come from [Zephyr-7B-β](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta).
- Personas are sampled from [FinePersonas-Lite](https://huggingface.co/datasets/MohamedRashad/FinePersonas-Lite).

Tip: write or randomise a world, then spin the persona box to see how the same
world shapes different heroes.
"""

with gr.Blocks(title="Character Generator", theme="Nymbo/Nymbo_Theme") as demo:
    gr.Markdown("<h1 style='text-align:center'>🧙‍♂️ Character Generator</h1>")
    gr.Markdown(app_description.strip())

    with gr.Row():
        world_description = gr.Textbox(label="World Description", lines=10, scale=4)
        persona_description = gr.Textbox(
            label="Persona Description",
            value=get_random_persona_description(),
            lines=10,
            scale=1,
        )
    with gr.Row():
        random_world_btn = gr.Button("🎲 Random World", variant="secondary")
        submit_btn = gr.Button("✨ Generate Character", variant="primary", scale=5)
        random_persona_btn = gr.Button("🎲 Random Persona", variant="secondary")
    with gr.Row():
        character_image = gr.Image(label="Character Image")
        character_json = gr.JSON(label="Character Description")

    # Hooks: .then() chains the events, so the image render starts only after
    # the character JSON has landed in its output component.
    submit_btn.click(
        generate_character,
        inputs=[world_description, persona_description],
        outputs=[character_json],
    ).then(
        infer_flux,
        inputs=[character_json],
        outputs=[character_image],
    )
    random_world_btn.click(
        get_random_world_description,
        outputs=[world_description],
    )
    random_persona_btn.click(
        get_random_persona_description,
        outputs=[persona_description],
    )

demo.queue().launch(share=False)
```
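
If Zephyr ignores the "no fencing" instruction and wraps its reply in markdown fences, both parse attempts in `generate_character` fail. A minimal hardening sketch, not part of the app above (`parse_character_json` is a hypothetical helper): pull the outermost `{...}` span out of the reply before parsing.

```python
import json, re

def parse_character_json(raw: str) -> dict:
    """Best-effort parse: tolerate markdown fences or stray prose around the JSON."""
    match = re.search(r"\{.*\}", raw, flags=re.DOTALL)  # greedy: outermost {...} span
    return json.loads(match.group(0) if match else raw)
```

Swapping `json.loads(raw)` for `parse_character_json(raw)` in `generate_character` keeps the one-retry behaviour while surviving fenced replies.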