# app.py
from datasets import load_dataset
import gradio as gr
from gradio_client import Client
import json, random, torch
from diffusers import FluxPipeline, AutoencoderKL
from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
import spaces
# ─────────────────────────────── 1. Device ────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ─────────────────────── 2. Image / FLUX pipeline ─────────────────────────
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to(device)
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
).to(device)
pipe.flux_pipe_call_that_returns_an_iterable_of_images = (
    flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
)
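# NB: `__get__` above binds the standalone helper as an instance method, so
# each call receives `pipe` as `self` and can use its components directly.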
# ───────────────────────── 3. LLM (Zephyr-chat) ───────────────────────────
llm_client = Client("HuggingFaceH4/zephyr-chat")  # public Space
# view_api() only prints by default; request a dict and take the first named endpoint.
CHAT_API = next(iter(
    llm_client.view_api(print_info=False, return_format="dict")["named_endpoints"]
))  # e.g. "/chat"
def call_llm(
    user_prompt: str,
    system_prompt: str = "You are Zephyr, a helpful and creative assistant.",
    history: list | None = None,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_tokens: int = 1024,
) -> str:
    """
    Robust wrapper around the Zephyr chat Space.
    Falls back to '...' on any error so the Gradio UI never crashes.
    """
    history = history or []
    try:
        # Zephyr-chat expects: prompt, system_prompt, history, temperature, top_p, max_new_tokens
        result = llm_client.predict(
            user_prompt,
            system_prompt,
            history,
            temperature,
            top_p,
            max_tokens,
            api_name=CHAT_API,
        )
        # Some Spaces return a plain string, others return the old tuple format.
        return result.strip() if isinstance(result, str) else result[1][0][-1].strip()
    except Exception as e:
        print(f"[LLM error] {e}")
        return "..."
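# Hypothetical usage (illustrative only, not run at import time):
#   call_llm("Name a fantasy tavern.")      -> e.g. "The Gilded Griffin"
#   call_llm("...", temperature=1.2)        -> more varied completions
# Any network or API failure returns "..." instead of raising.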
# ───────────────────────── 4. Persona dataset ────────────────────────────
ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")
def get_random_persona_description() -> str:
    idx = random.randint(0, len(ds) - 1)
    return ds[idx]["persona"]
# ─────────────────────────── 5. Prompts ─────────────────────────────────
prompt_template = """Generate a character with this persona description:
{persona_description}
In a world with this description:
{world_description}
Write the character in JSON format with these keys:
name, background, appearance, personality, skills_and_abilities, goals,
conflicts, backstory, current_situation, spoken_lines (list of strings).
Respond with **only** the JSON (no markdown, no fencing)."""
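# Illustrative shape of the response the template asks for (placeholder values):
# {
#   "name": "…", "background": "…", "appearance": "…",
#   "personality": "…", "skills_and_abilities": "…", "goals": "…",
#   "conflicts": "…", "backstory": "…", "current_situation": "…",
#   "spoken_lines": ["…", "…"]
# }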
world_description_prompt = (
"Invent a short, unique and vivid world description. "
"Respond with the description only."
)
# ─────────────────────── 6. Gradio helper funcs ─────────────────────────
def get_random_world_description() -> str:
    return call_llm(world_description_prompt)
@spaces.GPU(duration=75)
def infer_flux(character_json):
    """Stream intermediate images while FLUX denoises."""
    for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=character_json["appearance"],
        guidance_scale=3.5,
        num_inference_steps=28,
        width=1024,
        height=1024,
        generator=torch.Generator("cpu").manual_seed(0),  # fixed seed: same appearance text -> same image
        output_type="pil",
        good_vae=good_vae,
    ):
        yield image
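# Gradio treats generator functions as streams: each yielded PIL image
# replaces the previous one in the output Image component, so users watch
# the denoising progress live.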
def generate_character(world_description: str,
                       persona_description: str,
                       progress=gr.Progress(track_tqdm=True)):
    raw = call_llm(
        prompt_template.format(
            persona_description=persona_description,
            world_description=world_description,
        ),
        max_tokens=1024,
    )
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # One retry if the LLM response wasn't valid JSON.
        raw = call_llm(
            prompt_template.format(
                persona_description=persona_description,
                world_description=world_description,
            ),
            max_tokens=1024,
        )
        return json.loads(raw)
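# Note: if the retry also returns invalid JSON, json.loads raises and Gradio
# shows the error in the UI; a stub-dict fallback could be added here if needed.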
# ───────────────────────────── 7. UI ─────────────────────────────────────
app_description = """
- Generates a character profile (JSON) from a world + persona description.
- **Appearance** images come from [FLUX-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
- **Back-stories** come from [Zephyr-7B-β](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta).
- Personas are sampled from [FinePersonas-Lite](https://huggingface.co/datasets/MohamedRashad/FinePersonas-Lite).
Tip → Write or randomise a world, then spin the persona box to see how the same
world shapes different heroes.
"""
with gr.Blocks(title="Character Generator", theme="Nymbo/Nymbo_Theme") as demo:
    gr.Markdown("<h1 style='text-align:center'>🧙‍♀️ Character Generator</h1>")
    gr.Markdown(app_description.strip())

    with gr.Row():
        world_description = gr.Textbox(label="World Description", lines=10, scale=4)
        persona_description = gr.Textbox(
            label="Persona Description",
            value=get_random_persona_description(),
            lines=10,
            scale=1,
        )

    with gr.Row():
        random_world_btn = gr.Button("🔄 Random World", variant="secondary")
        submit_btn = gr.Button("✨ Generate Character", variant="primary", scale=5)
        random_persona_btn = gr.Button("🔄 Random Persona", variant="secondary")

    with gr.Row():
        character_image = gr.Image(label="Character Image")
        character_json = gr.JSON(label="Character Description")
    # Hooks
    submit_btn.click(
        generate_character,
        inputs=[world_description, persona_description],
        outputs=[character_json],
    ).then(
        infer_flux,
        inputs=[character_json],
        outputs=[character_image],
    )
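    # .then() runs only after generate_character has returned, so infer_flux
    # always receives the freshly generated JSON rather than a stale value.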
    random_world_btn.click(
        get_random_world_description,
        outputs=[world_description],
    )
    random_persona_btn.click(
        get_random_persona_description,
        outputs=[persona_description],
    )
demo.queue().launch(share=False)