from datasets import load_dataset
import gradio as gr
from gradio_client import Client
import json, os, random, torch, spaces
from diffusers import FluxPipeline, AutoencoderKL
from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
# ───────────────────────────── 1. Device ────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ─────────────────────── 2. Image / FLUX pipeline ───────────────────────────
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
).to(device)
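# Bind the live-preview helper to this pipeline instance via the descriptor
# protocol: `__get__(pipe)` turns the free function into a bound method, so
# the call below receives `pipe` as `self` without subclassing FluxPipeline.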
pipe.flux_pipe_call_that_returns_an_iterable_of_images = (
    flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
)
# ───────────────────────── 3. LLM client (robust) ───────────────────────────
def _first_working_client(candidates: list[str]) -> Client:
    """
    Try a list of Space URLs / repo ids and return the first whose config
    can be fetched and parsed as JSON.
    """
    for src in candidates:
        try:
            print(f"[info] Trying LLM Space: {src}")
            c = Client(src, hf_token=os.getenv("HF_TOKEN"))  # token optional
            # If this succeeds, the config was parsed as JSON.
            c.view_api()
            print(f"[info] Selected LLM Space: {src}")
            return c
        except Exception as e:
            print(f"[warn] {src} not usable → {e}")
    raise RuntimeError("No usable LLM Space found!")
LLM_CANDIDATES = [
    "https://huggingfaceh4-zephyr-chat.hf.space",  # direct URL
    "HuggingFaceH4/zephyr-chat",                   # repo slug
    "huggingface-projects/gemma-2-9b-it",          # fallback Space
]
llm_client = _first_working_client(LLM_CANDIDATES)
# view_api(return_format="dict") exposes the endpoints; take the first named one.
_api_info = llm_client.view_api(return_format="dict", print_info=False)
CHAT_API = next(iter(_api_info["named_endpoints"]))
def call_llm(
    user_prompt: str,
    system_prompt: str = "You are a helpful creative assistant.",
    history: list | None = None,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_tokens: int = 1024,
) -> str:
    """
    Unified chat wrapper – works for both Zephyr and Gemma Spaces.
    """
    history = history or []
    try:
        result = llm_client.predict(
            user_prompt,
            system_prompt,
            history,
            temperature,
            top_p,
            max_tokens,
            api_name=CHAT_API,
        )
        # Some Spaces return a plain string, others a (…, history) tuple.
        if isinstance(result, str):
            return result.strip()
        # The reply is the assistant half of the last turn in the history.
        return result[1][-1][-1].strip()
    except Exception as e:
        print(f"[error] LLM call failed → {e}")
        return "…"
# ───────────────────────── 4. Persona dataset ───────────────────────────────
ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")
def random_persona() -> str:
    return ds[random.randint(0, len(ds) - 1)]["persona"]
# ─────────────────────────── 5. Prompt templates ───────────────────────────
PROMPT_TEMPLATE = """Generate a character with this persona description:
{persona_description}
In a world with this description:
{world_description}
Write the character in JSON with keys:
name, background, appearance, personality, skills_and_abilities,
goals, conflicts, backstory, current_situation,
spoken_lines (list of strings).
Respond with JSON only (no markdown)."""
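# Illustrative shape of the reply the template asks for (values made up):
#   {"name": "...", "background": "...", "appearance": "...",
#    "personality": "...", "skills_and_abilities": "...", "goals": "...",
#    "conflicts": "...", "backstory": "...", "current_situation": "...",
#    "spoken_lines": ["...", "..."]}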
WORLD_PROMPT = (
    "Invent a short, unique and vivid world description. "
    "Respond with the description only."
)
# ─────────────────────── 6. Helper functions ───────────────────────────────
def random_world() -> str:
    return call_llm(WORLD_PROMPT)
@spaces.GPU(duration=75)
def infer_flux(character_json):
    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=character_json["appearance"],
        guidance_scale=3.5,
        num_inference_steps=28,
        width=1024,
        height=1024,
        generator=torch.Generator("cpu").manual_seed(0),
        output_type="pil",
        good_vae=good_vae,
    ):
        yield img
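# Because infer_flux is a generator, Gradio streams each yielded PIL image to
# the gr.Image output, showing intermediate denoising previews before the
# final 1024×1024 render.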
def generate_character(world_desc: str, persona_desc: str,
                       progress=gr.Progress(track_tqdm=True)):
    raw = call_llm(
        PROMPT_TEMPLATE.format(
            persona_description=persona_desc,
            world_description=world_desc,
        ),
        max_tokens=1024,
    )
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # One retry
        raw = call_llm(
            PROMPT_TEMPLATE.format(
                persona_description=persona_desc,
                world_description=world_desc,
            ),
            max_tokens=1024,
        )
        return json.loads(raw)
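# A stricter parse (hypothetical helper) could strip markdown fences first,
# since chat models sometimes wrap JSON in ```json blocks despite the prompt:
#   def _parse_json(raw: str) -> dict:
#       raw = raw.strip().removeprefix("```json").removesuffix("```")
#       return json.loads(raw.strip())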
# ───────────────────────────── 7. UI ────────────────────────────────────────
DESCRIPTION = """
* Generates a character sheet (JSON) from a world + persona.
* Appearance images via **FLUX-dev**; narrative via **Zephyr-chat** (or Gemma fallback).
* Personas come from **FinePersonas-Lite**.
Tip → Spin the world, then shuffle personas to see very different heroes.
"""
with gr.Blocks(title="Character Generator", theme="Nymbo/Nymbo_Theme") as demo:
    gr.Markdown("<h1 style='text-align:center'>🧚‍♀️ Character Generator</h1>")
    gr.Markdown(DESCRIPTION.strip())
    with gr.Row():
        world_tb = gr.Textbox(label="World Description", lines=10, scale=4)
        persona_tb = gr.Textbox(
            label="Persona Description", value=random_persona(), lines=10, scale=1
        )
    with gr.Row():
        btn_world = gr.Button("🔄 Random World", variant="secondary")
        btn_generate = gr.Button("✨ Generate Character", variant="primary", scale=5)
        btn_persona = gr.Button("🔄 Random Persona", variant="secondary")
    with gr.Row():
        img_out = gr.Image(label="Character Image")
        json_out = gr.JSON(label="Character Description")
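    # .then() chains infer_flux after generate_character has filled json_out,
    # so the JSON sheet appears first and the portrait streams in afterwards.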
    btn_generate.click(
        generate_character, [world_tb, persona_tb], [json_out]
    ).then(
        infer_flux, [json_out], [img_out]
    )
    btn_world.click(random_world, outputs=[world_tb])
    btn_persona.click(random_persona, outputs=[persona_tb])

demo.queue().launch(share=False)