Jensin's picture
Update app.py
745adc6 verified
raw
history blame
5.9 kB
# app.py — Robust Character Generator Space
from datasets import load_dataset
import gradio as gr, json, os, random, torch, spaces
from diffusers import FluxPipeline, AutoencoderKL
from gradio_client import Client
from live_preview_helpers import (
flux_pipe_call_that_returns_an_iterable_of_images as flux_iter,
)
# 1. Device: use CUDA when the runtime exposes it, otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# 2. FLUX image pipeline (bfloat16 weights), moved onto `device`.
_FLUX_REPO = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(_FLUX_REPO, torch_dtype=torch.bfloat16)
pipe = pipe.to(device)

# Standalone VAE from the same repo; passed to the streaming helper as `good_vae`.
good_vae = AutoencoderKL.from_pretrained(
    _FLUX_REPO, subfolder="vae", torch_dtype=torch.bfloat16
)
good_vae = good_vae.to(device)

# Bind the helper function to `pipe` (function.__get__(obj) -> bound method) so it
# can be called as pipe.flux_pipe_call_that_returns_an_iterable_of_images(...).
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_iter.__get__(pipe)
# 3. LLM client (robust)
# Candidate chat backends, tried in order by first_live_space(); the first one
# that answers the probe request is used for every LLM call.
LLM_SPACES = [
    "https://huggingfaceh4-zephyr-chat.hf.space",
    # NOTE(review): this looks like a model repo id rather than a Space id; if so,
    # Client() presumably fails for it and the probe simply skips it — confirm.
    "meta-llama/Llama-3.3-70B-Instruct",
    "huggingface-projects/gemma-2-9b-it",
]


def first_live_space(space_ids: list[str], *, api_name: str = "/chat",
                     probe_args: tuple = ("ping", 8)) -> "Client":
    """Return a gradio Client for the first Space that answers a probe request.

    Each id is tried in order: a Client is built (authenticated with the
    HF_TOKEN environment variable when it is set) and ``probe_args`` are sent
    to ``api_name``. Any exception marks that Space as unusable and the next
    candidate is tried.

    Args:
        space_ids: Space ids or URLs, in order of preference.
        api_name: Endpoint to probe (keyword-only; defaults to "/chat").
        probe_args: Positional inputs for the probe call (keyword-only).

    Returns:
        The first Client whose probe call succeeded.

    Raises:
        RuntimeError: if no candidate is usable. The last probe failure is
            chained as ``__cause__`` so the real reason is not lost.
    """
    hf_token = os.getenv("HF_TOKEN")
    last_error = None
    for sid in space_ids:
        print(f"[info] probing {sid}")
        try:
            client = Client(sid, hf_token=hf_token)
            client.predict(*probe_args, api_name=api_name)
        except Exception as e:  # best-effort probe: any failure means "skip it"
            print(f"[warn] {sid} unusable β†’ {e}")
            last_error = e
            continue
        print(f"[info] using {sid}")
        return client
    raise RuntimeError("No live chat Space found!") from last_error
# Endpoint name used by call_llm() for every chat request.
CHAT_API = "/chat"
# Resolve the chat client once, at import time; first_live_space raises
# RuntimeError when none of the candidates responds.
llm_client = first_live_space(LLM_SPACES)
def call_llm(prompt: str,
             max_tokens: int = 256,
             temperature: float = 0.6,
             top_p: float = 0.9) -> str:
    """Send one prompt to the chat Space and return its stripped reply.

    The sampling settings are forwarded positionally to the CHAT_API endpoint.
    This is best-effort: any failure (network, endpoint, non-string reply) is
    logged and the placeholder "…" is returned instead of raising.
    """
    try:
        reply = llm_client.predict(
            prompt, max_tokens, temperature, top_p, api_name=CHAT_API
        )
        # Kept inside the try so a reply without .strip() also falls back.
        return reply.strip()
    except Exception as exc:
        print(f"[error] LLM failure β†’ {exc}")
        return "…"
# 4. Persona dataset: the "train" split of FinePersonas-Lite, loaded at import time.
ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")


def random_persona() -> str:
    """Return the "persona" field of one randomly chosen dataset row."""
    row_index = random.randrange(len(ds))
    row = ds[row_index]
    return row["persona"]
# 5. Text prompts
# Character-sheet prompt; fill via
# .format(persona_description=..., world_description=...).
PROMPT_TEMPLATE = "\n".join(
    [
        "Generate a character with this persona description:",
        "{persona_description}",
        "In a world with this description:",
        "{world_description}",
        "Write the character in JSON with keys:",
        "name, background, appearance, personality, skills_and_abilities,",
        "goals, conflicts, backstory, current_situation,",
        "spoken_lines (list of strings).",
        "Respond with JSON only (no markdown).",
    ]
)

# Prompt for the "Random World" button.
WORLD_PROMPT = "Invent a short, unique and vivid world description. Respond with the description only."
# 6. Helper functions
def random_world() -> str:
    """Ask the chat LLM for a short world description (max_tokens=120)."""
    world = call_llm(prompt=WORLD_PROMPT, max_tokens=120)
    return world
def _json_candidates(text):
    """Yield substrings of an LLM reply that may hold its JSON payload, most literal first.

    1. the reply exactly as given;
    2. the body of a Markdown code fence (```json ... ```), when the reply is fenced;
    3. the span from the first "{" to the last "}", for replies with chatter around it.
    """
    yield text
    stripped = text.strip()
    if stripped.startswith("```"):
        # Drop the opening fence line (it may carry a language tag such as "json")...
        _, _, body = stripped.partition("\n")
        body = body.rstrip()
        # ...and the closing fence, when present.
        if body.endswith("```"):
            body = body[:-3]
        yield body
    start, end = stripped.find("{"), stripped.rfind("}")
    if 0 <= start < end:
        yield stripped[start:end + 1]


def safe_json_parse(raw):
    """Try to parse JSON, return None if fail, and log.

    LLMs often ignore "no markdown" and wrap the object in a code fence or add a
    sentence around it, so for string input several candidate substrings are
    tried (see _json_candidates). Non-string input is handed to json.loads
    unchanged (bytes still parse; anything else fails and is logged).

    Returns:
        The decoded JSON value, or None when no candidate parses.
    """
    candidates = _json_candidates(raw) if isinstance(raw, str) else (raw,)
    first_error = None
    for candidate in candidates:
        try:
            return json.loads(candidate)
        # RecursionError covers pathologically deep nesting; stay best-effort.
        except (TypeError, ValueError, RecursionError) as e:
            if first_error is None:
                first_error = e
    print(f"[ERROR] JSON parsing failed: {first_error}")
    # str() so the debug dump cannot itself crash on non-sliceable input (e.g. None).
    print(f"[DEBUG] Raw output: {str(raw)[:1000]}")
    return None
def _appearance_prompt(character_json):
    """Build the FLUX text prompt from a character sheet, or return None if unusable.

    The "appearance" value is used as-is when it is a string. LLMs sometimes
    nest it, so a dict is rendered as "key: value" pairs and a list/tuple has
    its items joined. Returns None when the sheet is not a dict, has no
    "appearance" (or it is None), or the resulting text is blank.
    """
    if not isinstance(character_json, dict):
        return None
    appearance = character_json.get("appearance")
    if appearance is None:
        return None
    if isinstance(appearance, dict):
        text = "; ".join(f"{key}: {value}" for key, value in appearance.items())
    elif isinstance(appearance, (list, tuple)):
        text = "; ".join(str(item) for item in appearance)
    else:
        text = str(appearance)
    return text.strip() or None


@spaces.GPU(duration=75)
def infer_flux(character_json):
    """Stream FLUX images for the character's appearance (generator).

    Yields whatever the pipeline's streaming helper produces (requested with
    output_type="pil"), using 28 steps at 1024x1024 and a fixed CPU seed of 0.
    When no usable appearance is available, yields a single None so the bound
    image output is reset rather than keeping the previous character's picture.
    """
    prompt = _appearance_prompt(character_json)
    if prompt is None:
        print("[ERROR] No valid appearance to generate image.")
        # A bare `return` in a generator yields nothing at all; emit None explicitly.
        yield None
        return
    yield from pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt,
        guidance_scale=3.5,
        num_inference_steps=28,
        width=1024,
        height=1024,
        # Fixed seed: every call starts from the same initial noise.
        generator=torch.Generator("cpu").manual_seed(0),
        output_type="pil",
        good_vae=good_vae,
    )
def generate_character(world_desc: str, persona_desc: str,
                       progress=gr.Progress(track_tqdm=True)):
    """Ask the LLM for a character sheet, retrying once on unusable output.

    Args:
        world_desc: World description inserted into PROMPT_TEMPLATE.
        persona_desc: Persona description inserted into PROMPT_TEMPLATE.
        progress: Gradio progress tracker (track_tqdm=True); not used directly.

    Returns:
        The parsed character sheet (a non-empty dict) from the first attempt
        that yields one; otherwise an error dict carrying both raw replies.
    """
    prompt = PROMPT_TEMPLATE.format(
        persona_description=persona_desc,
        world_description=world_desc,
    )
    raw_replies = []
    for _attempt in range(2):
        raw = call_llm(prompt, max_tokens=1024)
        raw_replies.append(raw)
        character = safe_json_parse(raw)
        # Only a JSON object is a usable sheet: a list/string/number would be
        # shown as the result, and infer_flux needs dict keys like "appearance".
        if isinstance(character, dict) and character:
            return character
    first_raw, second_raw = raw_replies
    # Both attempts failed: surface the raw outputs for debugging.
    return {
        "error": "LLM did not return valid JSON after 2 attempts.",
        "first_raw": first_raw,
        "second_raw": second_raw,
        "tip": "Check your LLM prompt and output. Try regenerating.",
    }
# 7. Gradio UI
# Markdown blurb shown under the page title.
DESCRIPTION = """
* Generates a JSON character sheet from a world + persona.
* Appearance images via **FLUX-dev**; story text via the first reachable chat Space (Zephyr-chat, then Llama-3.3 / Gemma-2 fallbacks).
* Personas sampled from **FinePersonas-Lite**.
Tip → Shuffle the world then persona for rapid inspiration.
"""
with gr.Blocks(title="Character Generator", theme="Nymbo/Nymbo_Theme") as demo:
    # Header: title and feature blurb.
    gr.Markdown("<h1 style='text-align:center'>πŸ§β€β™‚οΈ Character Generator</h1>")
    gr.Markdown(DESCRIPTION.strip())

    # Inputs: world text next to a persona pre-filled with one random sample
    # (drawn once, when the UI is built).
    with gr.Row():
        world_tb = gr.Textbox(label="World Description", lines=10, scale=4)
        persona_tb = gr.Textbox(
            label="Persona Description",
            value=random_persona(),
            lines=10,
            scale=1,
        )

    # Action buttons.
    with gr.Row():
        btn_world = gr.Button("πŸ”„ Random World", variant="secondary")
        btn_generate = gr.Button(
            "✨ Generate Character", variant="primary", scale=5
        )
        btn_persona = gr.Button("πŸ”„ Random Persona", variant="secondary")

    # Outputs: streamed portrait and the JSON character sheet.
    with gr.Row():
        img_out = gr.Image(label="Character Image")
        json_out = gr.JSON(label="Character Description")

    # Wiring: build the sheet first, then feed it to infer_flux for the image.
    generate_event = btn_generate.click(
        fn=generate_character,
        inputs=[world_tb, persona_tb],
        outputs=[json_out],
    )
    generate_event.then(fn=infer_flux, inputs=[json_out], outputs=[img_out])
    btn_world.click(fn=random_world, outputs=[world_tb])
    btn_persona.click(fn=random_persona, outputs=[persona_tb])

demo.queue().launch(share=False)