Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,978 Bytes
0a6a466 f2d6ac6 0a6a466 f2d6ac6 0a6a466 f2d6ac6 0a6a466 f2d6ac6 0a6a466 c92c8f7 9c89ed0 c92c8f7 9c89ed0 c92c8f7 0a6a466 c92c8f7 0a6a466 f52b5b1 0a6a466 9c89ed0 0a6a466 9c89ed0 0a6a466 9c89ed0 0a6a466 9c89ed0 0a6a466 9c89ed0 0a6a466 9c89ed0 0a6a466 c92c8f7 0a6a466 c92c8f7 0a6a466 9c89ed0 f2d6ac6 0a6a466 f2d6ac6 9c89ed0 f2d6ac6 0a6a466 9c89ed0 f2d6ac6 c92c8f7 f2d6ac6 9c89ed0 c92c8f7 9c89ed0 c92c8f7 9c89ed0 c92c8f7 0a6a466 9c89ed0 0a6a466 f2d6ac6 b32bd6a 7173d5d 9c89ed0 7173d5d f2d6ac6 9c89ed0 f2d6ac6 9c89ed0 c92c8f7 9c89ed0 7173d5d c92c8f7 7173d5d c92c8f7 0a6a466 c92c8f7 9c89ed0 c92c8f7 f2d6ac6 0a6a466 9c89ed0 0a6a466 7173d5d 0a6a466 fbcd34b c92c8f7 0a6a466 9c89ed0 c92c8f7 9c89ed0 f2d6ac6 c92c8f7 9c89ed0 c92c8f7 9c89ed0 c92c8f7 9c89ed0 c92c8f7 9c89ed0 f2d6ac6 c92c8f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
# app.py ──────────────────────────────────────────────────────────────────────
from datasets import load_dataset
import gradio as gr, json, os, random, torch, spaces
from diffusers import FluxPipeline, AutoencoderKL
from gradio_client import Client
from live_preview_helpers import (
flux_pipe_call_that_returns_an_iterable_of_images as flux_iter,
)
# ──────────────────────────── 1. Device ─────────────────────────────────────
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# ─────────────────────── 2. FLUX image pipeline ─────────────────────────────
# Both models come from the same repo and are loaded in bfloat16 onto `device`.
_FLUX_REPO = "black-forest-labs/FLUX.1-dev"
_FLUX_DTYPE = torch.bfloat16

pipe = FluxPipeline.from_pretrained(_FLUX_REPO, torch_dtype=_FLUX_DTYPE).to(device)

# Separately loaded VAE from the repo's "vae" subfolder; it is handed to the
# streaming helper as `good_vae` when images are generated.
good_vae = AutoencoderKL.from_pretrained(
    _FLUX_REPO, subfolder="vae", torch_dtype=_FLUX_DTYPE
).to(device)

# `__get__` turns the plain helper function into a method bound to `pipe`,
# so `pipe.flux_pipe_call_that_returns_an_iterable_of_images(...)` receives
# `pipe` as its first argument.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_iter.__get__(pipe)
# ───────────────────────── 3. LLM client (robust) ───────────────────────────
# Candidate chat backends, probed in order by first_live_space(). Entries are
# handed straight to gradio_client.Client, which accepts a full Space URL or
# an "owner/name" id.
# NOTE(review): "meta-llama/Llama-3.3-70B-Instruct" looks like a model repo id
# rather than a Space id — confirm it resolves to a Space exposing /chat.
LLM_SPACES = [
    "https://huggingfaceh4-zephyr-chat.hf.space",
    "meta-llama/Llama-3.3-70B-Instruct",
    "huggingface-projects/gemma-2-9b-it",
]


def first_live_space(space_ids: list[str], api_name: str = "/chat") -> Client:
    """
    Return a client for the first Space whose chat endpoint answers a probe.

    Candidates are tried in order. Each one is sent a short "ping" request on
    *api_name* with 8 as the second positional input (presumably max_tokens,
    matching the order call_llm uses). The first Space that answers without
    raising wins.

    Args:
        space_ids: Space URLs or ids to try, in priority order.
        api_name: Endpoint used for the health check (default "/chat").

    Returns:
        A connected ``gradio_client.Client``.

    Raises:
        RuntimeError: if every candidate fails. The last failure is chained
            as ``__cause__`` so the root cause is not lost.
    """
    last_error = None
    for sid in space_ids:
        try:
            print(f"[info] probing {sid}")
            client = Client(sid, hf_token=os.getenv("HF_TOKEN"))
            client.predict("ping", 8, api_name=api_name)  # cheap health check
        except Exception as e:  # any failure just means "try the next one"
            print(f"[warn] {sid} unusable — {e}")
            last_error = e
            continue
        print(f"[info] using {sid}")
        return client
    raise RuntimeError("No live chat Space found!") from last_error
llm_client = first_live_space(LLM_SPACES)
CHAT_API = "/chat"  # endpoint name used for every candidate Space


def call_llm(prompt: str,
             max_tokens: int = 256,
             temperature: float = 0.6,
             top_p: float = 0.9) -> str:
    """
    Send a single-message chat request to the selected Space and return the reply.

    Generation settings are passed positionally after the prompt, in the order
    [prompt, max_tokens, temperature, top_p].
    NOTE(review): this assumes the remote /chat endpoint takes its inputs in
    that order and supplies defaults for any further inputs (e.g. penalty
    sliders). Confirm this for each Space in LLM_SPACES.

    Args:
        prompt: The user message.
        max_tokens: Token budget forwarded to the Space.
        temperature: Sampling temperature forwarded to the Space.
        top_p: Nucleus-sampling threshold forwarded to the Space.

    Returns:
        The reply with surrounding whitespace stripped. On any failure
        (network, remote error, or a non-string reply without ``.strip``) the
        error is printed and the placeholder "…" is returned instead of raising.
    """
    try:
        return llm_client.predict(
            prompt, max_tokens, temperature, top_p, api_name=CHAT_API
        ).strip()
    except Exception as exc:
        print(f"[error] LLM failure — {exc}")
        return "…"
# ───────────────────────── 4. Persona dataset ───────────────────────────────
ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")


def random_persona() -> str:
    """Return the "persona" field of a uniformly random row of ``ds``."""
    pick = random.randrange(len(ds))  # same draw as randint(0, len(ds) - 1)
    return ds[pick]["persona"]
# ──────────────────────────── 5. Text prompts ───────────────────────────────
# Character-sheet prompt. Fill it with
# .format(persona_description=..., world_description=...). The JSON keys it
# lists are the ones the rest of the app reads (e.g. "appearance").
PROMPT_TEMPLATE = (
    "Generate a character with this persona description:\n"
    "{persona_description}\n"
    "In a world with this description:\n"
    "{world_description}\n"
    "Write the character in JSON with keys:\n"
    "name, background, appearance, personality, skills_and_abilities,\n"
    "goals, conflicts, backstory, current_situation,\n"
    "spoken_lines (list of strings).\n"
    "Respond with JSON only (no markdown)."
)

# Stand-alone prompt asking the LLM for a fresh world description.
WORLD_PROMPT = " ".join((
    "Invent a short, unique and vivid world description.",
    "Respond with the description only.",
))
# ───────────────────────── 6. Helper functions ──────────────────────────────
def random_world() -> str:
    """Ask the LLM for a new world description, requesting at most 120 tokens."""
    world_text = call_llm(WORLD_PROMPT, max_tokens=120)
    return world_text
@spaces.GPU(duration=75)
def infer_flux(character_json):
    """
    Stream FLUX-dev images rendered from a character's ``appearance`` text.

    Args:
        character_json: The character sheet shown in the ``gr.JSON`` output
            (the dict returned by ``generate_character``). Only its
            ``"appearance"`` entry is used, as the image prompt. ``None``, a
            non-dict value, or a missing/empty appearance is logged and
            skipped instead of crashing.

    Yields:
        Every PIL image produced by
        ``pipe.flux_pipe_call_that_returns_an_iterable_of_images``. Presumably
        these are progressive previews ending with the final image (per the
        helper's name). Verify this in live_preview_helpers.
    """
    appearance = (
        character_json.get("appearance")
        if isinstance(character_json, dict)
        else None
    )
    if not appearance:
        print("[warn] infer_flux: no character appearance to render; skipping")
        return
    if not isinstance(appearance, str):
        # LLMs sometimes return a nested object/list here; FLUX needs text.
        appearance = json.dumps(appearance, ensure_ascii=False)

    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=appearance,
        guidance_scale=3.5,
        num_inference_steps=28,
        width=1024,
        height=1024,
        # Fixed seed: identical appearance text always yields the same image.
        generator=torch.Generator("cpu").manual_seed(0),
        output_type="pil",
        good_vae=good_vae,
    ):
        yield img
def _parse_character_json(raw: str):
    """
    Parse an LLM reply into a character dict, or return None if that fails.

    The reply is first parsed as-is. If that fails, the outermost ``{...}``
    span is parsed instead. This tolerates replies wrapped in Markdown code
    fences or surrounded by chatter. Only a JSON *object* counts as success.
    """
    text = raw.strip()
    candidates = [text]
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def generate_character(world_desc: str, persona_desc: str,
                       progress=gr.Progress(track_tqdm=True)):
    """
    Ask the LLM for a character sheet and return it as a dict.

    The prompt is PROMPT_TEMPLATE filled with the world and persona
    descriptions. If the reply is not a JSON object, the same prompt is sent
    once more before giving up.

    Args:
        world_desc: Free-text world description.
        persona_desc: Free-text persona description.
        progress: Gradio progress tracker, injected by Gradio.
            ``track_tqdm=True`` lets Gradio mirror tqdm bars.

    Returns:
        The parsed character sheet (a dict).

    Raises:
        gr.Error: if neither attempt yields a JSON object. The message is
            shown in the UI.
    """
    prompt = PROMPT_TEMPLATE.format(
        persona_description=persona_desc,
        world_description=world_desc,
    )
    attempts = 2  # first request + one retry
    for _ in range(attempts):
        character = _parse_character_json(call_llm(prompt, max_tokens=1024))
        if character is not None:
            return character
    raise gr.Error(
        "The language model did not return a valid JSON character sheet. "
        "Please try again."
    )
# ───────────────────────────── 7. Gradio UI ─────────────────────────────────
DESCRIPTION = """
* Generates a JSON character sheet from a world + persona.
* Appearance images via **FLUX-dev**; story text via Zephyr-chat or Gemma fallback.
* Personas sampled from **FinePersonas-Lite**.
Tip — Shuffle the world then persona for rapid inspiration.
"""

with gr.Blocks(title="Character Generator", theme="Nymbo/Nymbo_Theme") as demo:
    gr.Markdown("<h1 style='text-align:center'>🧙‍♂️ Character Generator</h1>")
    gr.Markdown(DESCRIPTION.strip())

    with gr.Row():
        world_tb = gr.Textbox(label="World Description", lines=10, scale=4)
        # Pass the function itself (not its result) so Gradio draws a fresh
        # persona on every page load, instead of one fixed at startup.
        persona_tb = gr.Textbox(
            label="Persona Description", value=random_persona, lines=10, scale=1
        )

    with gr.Row():
        # NOTE(review): the emoji prefixes on the two "Random" buttons are
        # mis-encoded in this copy and the intended glyphs cannot be
        # recovered. Restore them from the original source.
        btn_world = gr.Button("π Random World", variant="secondary")
        btn_generate = gr.Button("✨ Generate Character", variant="primary", scale=5)
        btn_persona = gr.Button("π Random Persona", variant="secondary")

    with gr.Row():
        img_out = gr.Image(label="Character Image")
        json_out = gr.JSON(label="Character Description")

    # .success (not .then): the image step only runs when a character sheet
    # was actually produced. .then fires even after generate_character fails,
    # which would render from a stale or empty JSON value.
    btn_generate.click(
        generate_character, [world_tb, persona_tb], [json_out]
    ).success(
        infer_flux, [json_out], [img_out]
    )
    btn_world.click(random_world, outputs=[world_tb])
    btn_persona.click(random_persona, outputs=[persona_tb])

demo.queue().launch(share=False)
|