# app.py
from datasets import load_dataset
import gradio as gr
from gradio_client import Client
import json, os, random, torch
from diffusers import FluxPipeline, AutoencoderKL
from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
import spaces
# ─────────────────────────────── 1. Device ────────────────────────────────
# Prefer CUDA when available; all models below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ─────────────────────── 2. Image / FLUX pipeline ─────────────────────────
# FLUX.1-dev text-to-image pipeline, loaded eagerly at import time in
# bfloat16 to reduce GPU memory use.
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16
).to(device)
# Standalone VAE from the same checkpoint; passed to the live-preview helper
# (used as `good_vae=` in infer_flux) for decoding.
good_vae = AutoencoderKL.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="vae",
torch_dtype=torch.bfloat16
).to(device)
# Bind the imported helper as a bound method on the pipeline instance via the
# descriptor protocol, so it can be called as pipe.flux_pipe_call_...(...).
pipe.flux_pipe_call_that_returns_an_iterable_of_images = (
flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
)
# ───────────────────────── 3. LLM (Zephyr-chat) ───────────────────────────
llm_client = Client("HuggingFaceH4/zephyr-chat") # public Space
# NOTE(review): gradio_client's view_api() prints endpoint info and returns
# None unless a return_format argument is given — confirm this subscripting
# works against the installed gradio_client version.
CHAT_API = llm_client.view_api()[0]["api_name"] # e.g. "/chat"
def call_llm(
    user_prompt: str,
    system_prompt: str = "You are Zephyr, a helpful and creative assistant.",
    history: list | None = None,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_tokens: int = 1024,
) -> str:
    """
    Robust wrapper around the Zephyr chat Space.
    Falls back to '...' on any error so the Gradio UI never crashes.
    """
    if history is None:
        history = []
    # Everything — the remote call AND the result unpacking — stays inside
    # the try so that an unexpected response shape also degrades to '...'.
    try:
        # Positional argument order expected by Zephyr-chat:
        # prompt, system_prompt, history, temperature, top_p, max_new_tokens
        response = llm_client.predict(
            user_prompt,
            system_prompt,
            history,
            temperature,
            top_p,
            max_tokens,
            api_name=CHAT_API,
        )
        if isinstance(response, str):
            # Newer Spaces hand back the reply as a plain string.
            return response.strip()
        # Legacy tuple format: the reply lives in the history structure.
        return response[1][0][-1].strip()
    except Exception as e:
        print(f"[LLM error] {e}")
        return "..."
# ───────────────────────── 4. Persona dataset ────────────────────────────
# Loaded eagerly at import time; each row has (at least) a "persona" field.
ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")
def get_random_persona_description() -> str:
    """Return the persona text of one uniformly random dataset row."""
    return ds[random.randrange(len(ds))]["persona"]
# ─────────────────────────── 5. Prompts ──────────────────────────────────
# Character-generation template; {persona_description} and
# {world_description} are filled in by generate_character().
prompt_template = """Generate a character with this persona description:
{persona_description}
In a world with this description:
{world_description}
Write the character in JSON format with these keys:
name, background, appearance, personality, skills_and_abilities, goals,
conflicts, backstory, current_situation, spoken_lines (list of strings).
Respond with **only** the JSON (no markdown, no fencing)."""
# Prompt behind the "Random World" button (see get_random_world_description).
world_description_prompt = (
"Invent a short, unique and vivid world description. "
"Respond with the description only."
)
# ─────────────────────── 6. Gradio helper funcs ──────────────────────────
def get_random_world_description() -> str:
    """Ask the LLM to invent a fresh world description (see prompt above)."""
    return call_llm(world_description_prompt)
@spaces.GPU(duration=75)
def infer_flux(character_json):
    """Stream intermediate PIL images while FLUX denoises the character's
    appearance prompt (fixed seed 0, 1024x1024, 28 steps)."""
    image_stream = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=character_json["appearance"],
        guidance_scale=3.5,
        num_inference_steps=28,
        width=1024,
        height=1024,
        generator=torch.Generator("cpu").manual_seed(0),
        output_type="pil",
        good_vae=good_vae,
    )
    yield from image_stream
def generate_character(world_description: str,
                       persona_description: str,
                       progress=gr.Progress(track_tqdm=True)):
    """Ask the LLM for a character profile and parse it as JSON.

    Parameters
    ----------
    world_description : str
        Free-text description of the world the character inhabits.
    persona_description : str
        Free-text persona seed for the character.
    progress : gr.Progress
        Gradio progress tracker (enables tqdm-based progress in the UI).

    Returns
    -------
    dict
        The parsed character description.

    Raises
    ------
    json.JSONDecodeError
        If the LLM fails to produce valid JSON on both attempts.
    """
    # Build the prompt once instead of duplicating it per attempt.
    prompt = prompt_template.format(
        persona_description=persona_description,
        world_description=world_description,
    )
    last_error = None
    for _attempt in range(2):  # one retry if the LLM hallucinated
        raw = call_llm(prompt, max_tokens=1024)
        try:
            return json.loads(_strip_code_fence(raw))
        except json.JSONDecodeError as err:
            last_error = err
    raise last_error


def _strip_code_fence(text: str) -> str:
    """Drop a surrounding ``` / ```json fence that the LLM may emit despite
    the prompt's instructions; return the text unchanged otherwise."""
    text = text.strip()
    if text.startswith("```"):
        # Remove the opening fence line (e.g. "```" or "```json").
        text = text.split("\n", 1)[1] if "\n" in text else ""
        stripped = text.rstrip()
        if stripped.endswith("```"):
            text = stripped[:-3]
    return text
# ───────────────────────────── 7. UI ─────────────────────────────────────
# Markdown shown at the top of the page (rendered by gr.Markdown below).
# Fixed: mis-encoded characters ("Ξ²" → "β", stray "β" → "—") from an
# earlier encoding round-trip.
app_description = """
- Generates a character profile (JSON) from a world + persona description.
- **Appearance** images come from [FLUX-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
- **Back-stories** come from [Zephyr-7B-β](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta).
- Personas are sampled from [FinePersonas-Lite](https://huggingface.co/datasets/MohamedRashad/FinePersonas-Lite).
Tip — Write or randomise a world, then spin the persona box to see how the same
world shapes different heroes.
"""
# Layout: two text inputs (world + persona), three buttons, and an
# image + JSON output pair. Event wiring is at the bottom.
with gr.Blocks(title="Character Generator", theme="Nymbo/Nymbo_Theme") as demo:
    # NOTE(review): the header and button labels below contain mojibake
    # (e.g. "π§ββοΈ") — presumably emoji lost in an encoding round-trip.
    # Left byte-identical here; restore the intended emoji separately.
    gr.Markdown("<h1 style='text-align:center'>π§ββοΈ Character Generator</h1>")
    gr.Markdown(app_description.strip())
    with gr.Row():
        world_description = gr.Textbox(label="World Description", lines=10, scale=4)
        # Persona box is pre-filled with a random dataset row at startup.
        persona_description = gr.Textbox(
            label="Persona Description",
            value=get_random_persona_description(),
            lines=10,
            scale=1,
        )
    with gr.Row():
        random_world_btn = gr.Button("π Random World", variant="secondary")
        submit_btn = gr.Button("β¨ Generate Character", variant="primary", scale=5)
        random_persona_btn = gr.Button("π Random Persona", variant="secondary")
    with gr.Row():
        character_image = gr.Image(label="Character Image")
        character_json = gr.JSON(label="Character Description")
    # Hooks
    # Generate the JSON profile first, then stream the FLUX image from it.
    submit_btn.click(
        generate_character,
        inputs=[world_description, persona_description],
        outputs=[character_json],
    ).then(
        infer_flux,
        inputs=[character_json],
        outputs=[character_image],
    )
    # "Random World" asks the LLM; "Random Persona" samples the dataset.
    random_world_btn.click(
        get_random_world_description,
        outputs=[world_description],
    )
    random_persona_btn.click(
        get_random_persona_description,
        outputs=[persona_description],
    )
# queue() enables the generator-based streaming output of infer_flux.
demo.queue().launch(share=False)