# app.py ──────────────────────────────────────────────────────────────────────
from datasets import load_dataset
import gradio as gr, json, os, random, torch, spaces
from diffusers import FluxPipeline, AutoencoderKL
from gradio_client import Client
from live_preview_helpers import (
    flux_pipe_call_that_returns_an_iterable_of_images as flux_iter,
)
# ─────────────────────────── 1. Device ─────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ────────────────────── 2. FLUX image pipeline ─────────────────────────────
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
).to(device)
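# Bind the live-preview helper onto the pipeline so it is called as a method
# (i.e. it receives `pipe` as `self`).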
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_iter.__get__(pipe)
# ───────────────────────── 3. LLM client (robust) ──────────────────────────
LLM_SPACES = [
    "https://huggingfaceh4-zephyr-chat.hf.space",
    "meta-llama/Llama-3.3-70B-Instruct",
    "huggingface-projects/gemma-2-9b-it",
]
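# Probed in list order: the first endpoint that answers /chat is used; the
# remaining entries serve only as fallbacks.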
def first_live_space(space_ids: list[str]) -> Client:
    """
    Return a Client for the first Space whose /chat endpoint answers a short
    echo probe.
    """
    for sid in space_ids:
        try:
            print(f"[info] probing {sid}")
            c = Client(sid, hf_token=os.getenv("HF_TOKEN"))
            _ = c.predict("ping", 8, api_name="/chat")  # simple health check
            print(f"[info] using {sid}")
            return c
        except Exception as e:
            print(f"[warn] {sid} unusable → {e}")
    raise RuntimeError("No live chat Space found!")
llm_client = first_live_space(LLM_SPACES)
CHAT_API = "/chat" # universal endpoint for TGI-style Spaces
def call_llm(prompt: str,
             max_tokens: int = 256,
             temperature: float = 0.6,
             top_p: float = 0.9) -> str:
    """
    Send a single-message chat to the Space. Extra sliders in the remote UI
    must be supplied positionally after the prompt, matching the Zephyr/Gemma
    layout:

        [prompt, max_tokens, temperature, top_p, repeat_penalty, presence_penalty]

    We pass only the first four; the Space fills the rest with its defaults.
    """
    try:
        return llm_client.predict(
            prompt, max_tokens, temperature, top_p, api_name=CHAT_API
        ).strip()
    except Exception as exc:
        print(f"[error] LLM failure → {exc}")
        return "…"
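# Illustrative call (reply wording depends on whichever Space is live):
#   call_llm("Say hi in one word.", max_tokens=8)  # -> e.g. "Hi!"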
# ──────────────────────── 4. Persona dataset ──────────────────────────────
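# Each row of FinePersonas-Lite carries a free-text "persona" field.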
ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")
def random_persona() -> str:
    return ds[random.randint(0, len(ds) - 1)]["persona"]
# ─────────────────────────── 5. Text prompts ───────────────────────────────
PROMPT_TEMPLATE = """Generate a character with this persona description:
{persona_description}
In a world with this description:
{world_description}
Write the character in JSON with keys:
name, background, appearance, personality, skills_and_abilities,
goals, conflicts, backstory, current_situation,
spoken_lines (list of strings).
Respond with JSON only (no markdown)."""
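# Expected reply shape (keys fixed by the template, values model-generated), e.g.
#   {"name": "...", "appearance": "...", ..., "spoken_lines": ["...", "..."]}
# The "appearance" value doubles as the FLUX image prompt in infer_flux below.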
WORLD_PROMPT = (
    "Invent a short, unique and vivid world description. "
    "Respond with the description only."
)
# ───────────────────────── 6. Helper functions ─────────────────────────────
def random_world() -> str:
    return call_llm(WORLD_PROMPT, max_tokens=120)
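# infer_flux is a generator: Gradio re-renders the output image on every yield,
# giving a live denoising preview. The fixed CPU seed keeps a given appearance
# prompt reproducible across runs.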
@spaces.GPU(duration=75)
def infer_flux(character_json):
    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=character_json["appearance"],
        guidance_scale=3.5,
        num_inference_steps=28,
        width=1024,
        height=1024,
        generator=torch.Generator("cpu").manual_seed(0),
        output_type="pil",
        good_vae=good_vae,
    ):
        yield img
def generate_character(world_desc: str, persona_desc: str,
                       progress=gr.Progress(track_tqdm=True)):
    raw = call_llm(
        PROMPT_TEMPLATE.format(
            persona_description=persona_desc,
            world_description=world_desc,
        ),
        max_tokens=1024,
    )
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # Retry once if the model didn't return valid JSON.
        raw = call_llm(
            PROMPT_TEMPLATE.format(
                persona_description=persona_desc,
                world_description=world_desc,
            ),
            max_tokens=1024,
        )
        return json.loads(raw)
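# If the retry also returns invalid JSON, json.loads raises here and Gradio
# surfaces the error in the UI instead of passing on a half-built character.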
# ─────────────────────────── 7. Gradio UI ──────────────────────────────────
DESCRIPTION = """
* Generates a JSON character sheet from a world + persona.
* Appearance images via **FLUX-dev**; story text via Zephyr-chat or Gemma fallback.
* Personas sampled from **FinePersonas-Lite**.

Tip → Shuffle the world, then the persona, for rapid inspiration.
"""
with gr.Blocks(title="Character Generator", theme="Nymbo/Nymbo_Theme") as demo:
    gr.Markdown("<h1 style='text-align:center'>🧝‍♂️ Character Generator</h1>")
    gr.Markdown(DESCRIPTION.strip())

    with gr.Row():
        world_tb = gr.Textbox(label="World Description", lines=10, scale=4)
        persona_tb = gr.Textbox(
            label="Persona Description", value=random_persona(), lines=10, scale=1
        )

    with gr.Row():
        btn_world = gr.Button("🔄 Random World", variant="secondary")
        btn_generate = gr.Button("✨ Generate Character", variant="primary", scale=5)
        btn_persona = gr.Button("🔄 Random Persona", variant="secondary")

    with gr.Row():
        img_out = gr.Image(label="Character Image")
        json_out = gr.JSON(label="Character Description")

    # Chain: build the JSON sheet first, then stream the portrait from it.
    btn_generate.click(
        generate_character, [world_tb, persona_tb], [json_out]
    ).then(
        infer_flux, [json_out], [img_out]
    )
    btn_world.click(random_world, outputs=[world_tb])
    btn_persona.click(random_persona, outputs=[persona_tb])

demo.queue().launch(share=False)