Jensin committed on
Commit
9c89ed0
·
verified ·
1 Parent(s): c92c8f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -85
app.py CHANGED
@@ -1,48 +1,63 @@
1
- # app.py
2
  from datasets import load_dataset
3
  import gradio as gr
4
  from gradio_client import Client
5
- import json, os, random, torch
6
  from diffusers import FluxPipeline, AutoencoderKL
7
  from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
8
- import spaces
9
 
10
- # ─────────────────────────────── 1. Device ────────────────────────────────
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
- # ─────────────────────── 2. Image / FLUX pipeline ─────────────────────────
14
  pipe = FluxPipeline.from_pretrained(
15
- "black-forest-labs/FLUX.1-dev",
16
- torch_dtype=torch.bfloat16
17
  ).to(device)
18
  good_vae = AutoencoderKL.from_pretrained(
19
- "black-forest-labs/FLUX.1-dev",
20
- subfolder="vae",
21
- torch_dtype=torch.bfloat16
22
  ).to(device)
23
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = (
24
  flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
25
  )
26
 
27
- # ───────────────────────── 3. LLM (Zephyr-chat) ───────────────────────────
28
- llm_client = Client("HuggingFaceH4/zephyr-chat") # public Space
29
- CHAT_API = llm_client.view_api()[0]["api_name"] # e.g. "/chat"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def call_llm(
32
  user_prompt: str,
33
- system_prompt: str = "You are Zephyr, a helpful and creative assistant.",
34
  history: list | None = None,
35
  temperature: float = 0.7,
36
  top_p: float = 0.9,
37
  max_tokens: int = 1024,
38
  ) -> str:
39
  """
40
- Robust wrapper around the Zephyr chat Space.
41
- Falls back to '...' on any error so the Gradio UI never crashes.
42
  """
43
  history = history or []
44
  try:
45
- # Zephyr-chat expects: prompt, system_prompt, history, temperature, top_p, max_new_tokens
46
  result = llm_client.predict(
47
  user_prompt,
48
  system_prompt,
@@ -52,21 +67,22 @@ def call_llm(
52
  max_tokens,
53
  api_name=CHAT_API,
54
  )
55
- # Some Spaces return a plain string, others return the old tuple format.
56
- return result.strip() if isinstance(result, str) else result[1][0][-1].strip()
 
 
57
  except Exception as e:
58
- print(f"[LLM error] {e}")
59
- return "..."
60
 
61
- # ───────────────────────── 4. Persona dataset ────────────────────────────
62
  ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")
63
 
64
- def get_random_persona_description() -> str:
65
- idx = random.randint(0, len(ds) - 1)
66
- return ds[idx]["persona"]
67
 
68
- # ─────────────────────────── 5. Prompts ─────────────────────────────────
69
- prompt_template = """Generate a character with this persona description:
70
 
71
  {persona_description}
72
 
@@ -74,25 +90,25 @@ In a world with this description:
74
 
75
  {world_description}
76
 
77
- Write the character in JSON format with these keys:
78
- name, background, appearance, personality, skills_and_abilities, goals,
79
- conflicts, backstory, current_situation, spoken_lines (list of strings).
 
80
 
81
- Respond with **only** the JSON (no markdown, no fencing)."""
82
 
83
- world_description_prompt = (
84
  "Invent a short, unique and vivid world description. "
85
  "Respond with the description only."
86
  )
87
 
88
- # ─────────────────────── 6. Gradio helper funcs ─────────────────────────
89
- def get_random_world_description() -> str:
90
- return call_llm(world_description_prompt)
91
 
92
  @spaces.GPU(duration=75)
93
  def infer_flux(character_json):
94
- """Stream intermediate images while FLUX denoises."""
95
- for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
96
  prompt=character_json["appearance"],
97
  guidance_scale=3.5,
98
  num_inference_steps=28,
@@ -102,82 +118,66 @@ def infer_flux(character_json):
102
  output_type="pil",
103
  good_vae=good_vae,
104
  ):
105
- yield image
106
 
107
- def generate_character(world_description: str,
108
- persona_description: str,
109
  progress=gr.Progress(track_tqdm=True)):
110
  raw = call_llm(
111
- prompt_template.format(
112
- persona_description=persona_description,
113
- world_description=world_description,
114
  ),
115
  max_tokens=1024,
116
  )
117
  try:
118
  return json.loads(raw)
119
  except json.JSONDecodeError:
120
- # One retry if the LLM hallucinated
121
  raw = call_llm(
122
- prompt_template.format(
123
- persona_description=persona_description,
124
- world_description=world_description,
125
  ),
126
  max_tokens=1024,
127
  )
128
  return json.loads(raw)
129
 
130
- # ───────────────────────────── 7. UI ─────────────────────────────────────
131
- app_description = """
132
- - Generates a character profile (JSON) from a world + persona description.
133
- - **Appearance** images come from [FLUX-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
134
- - **Back-stories** come from [Zephyr-7B-Ξ²](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta).
135
- - Personas are sampled from [FinePersonas-Lite](https://huggingface.co/datasets/MohamedRashad/FinePersonas-Lite).
136
 
137
- Tip β†’ Write or randomise a world, then spin the persona box to see how the same
138
- world shapes different heroes.
139
  """
140
 
141
  with gr.Blocks(title="Character Generator", theme="Nymbo/Nymbo_Theme") as demo:
142
- gr.Markdown("<h1 style='text-align:center'>πŸ§™β€β™€οΈ Character Generator</h1>")
143
- gr.Markdown(app_description.strip())
144
 
145
  with gr.Row():
146
- world_description = gr.Textbox(label="World Description", lines=10, scale=4)
147
- persona_description = gr.Textbox(
148
- label="Persona Description",
149
- value=get_random_persona_description(),
150
- lines=10,
151
- scale=1,
152
  )
153
 
154
  with gr.Row():
155
- random_world_btn = gr.Button("πŸ”„ Random World", variant="secondary")
156
- submit_btn = gr.Button("✨ Generate Character", variant="primary", scale=5)
157
- random_persona_btn = gr.Button("πŸ”„ Random Persona", variant="secondary")
158
 
159
  with gr.Row():
160
- character_image = gr.Image(label="Character Image")
161
- character_json = gr.JSON(label="Character Description")
162
-
163
- # Hooks
164
- submit_btn.click(
165
- generate_character,
166
- inputs=[world_description, persona_description],
167
- outputs=[character_json],
168
  ).then(
169
- infer_flux,
170
- inputs=[character_json],
171
- outputs=[character_image],
172
  )
173
 
174
- random_world_btn.click(
175
- get_random_world_description,
176
- outputs=[world_description],
177
- )
178
- random_persona_btn.click(
179
- get_random_persona_description,
180
- outputs=[persona_description],
181
- )
182
 
183
  demo.queue().launch(share=False)
 
 
 
1
  from datasets import load_dataset
2
  import gradio as gr
3
  from gradio_client import Client
4
+ import json, os, random, torch, spaces
5
  from diffusers import FluxPipeline, AutoencoderKL
6
  from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
 
7
 
8
+ # ───────────────────────────── 1. Device ────────────────────────────────────
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
+ # ─────────────────────── 2. Image / FLUX pipeline ───────────────────────────
12
  pipe = FluxPipeline.from_pretrained(
13
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
 
14
  ).to(device)
15
  good_vae = AutoencoderKL.from_pretrained(
16
+ "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
 
 
17
  ).to(device)
18
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = (
19
  flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
20
  )
21
 
22
+ # ───────────────────────── 3. LLM client (robust) ───────────────────────────
23
+ def _first_working_client(candidates: list[str]) -> Client:
24
+ """
25
+ Try a list of Space URLs / repo-ids, return the first that gives a JSON config.
26
+ """
27
+ for src in candidates:
28
+ try:
29
+ print(f"[info] Trying LLM Space: {src}")
30
+ c = Client(src, hf_token=os.getenv("HF_TOKEN")) # token optional
31
+ # If this passes, the config was parsed as JSON
32
+ c.view_api()
33
+ print(f"[info] Selected LLM Space: {src}")
34
+ return c
35
+ except Exception as e:
36
+ print(f"[warn] {src} not usable β†’ {e}")
37
+ raise RuntimeError("No usable LLM Space found!")
38
+
39
+ LLM_CANDIDATES = [
40
+ "https://huggingfaceh4-zephyr-chat.hf.space", # direct URL
41
+ "HuggingFaceH4/zephyr-chat", # repo slug
42
+ "huggingface-projects/gemma-2-9b-it", # fallback Space
43
+ ]
44
+
45
+ llm_client = _first_working_client(LLM_CANDIDATES)
46
+ CHAT_API = llm_client.view_api()[0]["api_name"] # safest way to get endpoint
47
 
48
  def call_llm(
49
  user_prompt: str,
50
+ system_prompt: str = "You are a helpful creative assistant.",
51
  history: list | None = None,
52
  temperature: float = 0.7,
53
  top_p: float = 0.9,
54
  max_tokens: int = 1024,
55
  ) -> str:
56
  """
57
+ Unified chat wrapper – works for both Zephyr and Gemma Spaces.
 
58
  """
59
  history = history or []
60
  try:
 
61
  result = llm_client.predict(
62
  user_prompt,
63
  system_prompt,
 
67
  max_tokens,
68
  api_name=CHAT_API,
69
  )
70
+ # Some Spaces return string, some return (…, history) tuple
71
+ if isinstance(result, str):
72
+ return result.strip()
73
+ return result[1][0][-1].strip()
74
  except Exception as e:
75
+ print(f"[error] LLM call failed β†’ {e}")
76
+ return "…"
77
 
78
+ # ───────────────────────── 4. Persona dataset ───────────────────────────────
79
  ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")
80
 
81
+ def random_persona() -> str:
82
+ return ds[random.randint(0, len(ds) - 1)]["persona"]
 
83
 
84
+ # ─────────────────────────── 5. Prompt templates ───────────────────────────
85
+ PROMPT_TEMPLATE = """Generate a character with this persona description:
86
 
87
  {persona_description}
88
 
 
90
 
91
  {world_description}
92
 
93
+ Write the character in JSON with keys:
94
+ name, background, appearance, personality, skills_and_abilities,
95
+ goals, conflicts, backstory, current_situation,
96
+ spoken_lines (list of strings).
97
 
98
+ Respond with JSON only (no markdown)."""
99
 
100
+ WORLD_PROMPT = (
101
  "Invent a short, unique and vivid world description. "
102
  "Respond with the description only."
103
  )
104
 
105
+ # ─────────────────────── 6. Helper functions ───────────────────────────────
106
+ def random_world() -> str:
107
+ return call_llm(WORLD_PROMPT)
108
 
109
  @spaces.GPU(duration=75)
110
  def infer_flux(character_json):
111
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
 
112
  prompt=character_json["appearance"],
113
  guidance_scale=3.5,
114
  num_inference_steps=28,
 
118
  output_type="pil",
119
  good_vae=good_vae,
120
  ):
121
+ yield img
122
 
123
+ def generate_character(world_desc: str, persona_desc: str,
 
124
  progress=gr.Progress(track_tqdm=True)):
125
  raw = call_llm(
126
+ PROMPT_TEMPLATE.format(
127
+ persona_description=persona_desc,
128
+ world_description=world_desc,
129
  ),
130
  max_tokens=1024,
131
  )
132
  try:
133
  return json.loads(raw)
134
  except json.JSONDecodeError:
135
+ # One retry
136
  raw = call_llm(
137
+ PROMPT_TEMPLATE.format(
138
+ persona_description=persona_desc,
139
+ world_description=world_desc,
140
  ),
141
  max_tokens=1024,
142
  )
143
  return json.loads(raw)
144
 
145
+ # ───────────────────────────── 7. UI ────────────────────────────────────────
146
+ DESCRIPTION = """
147
+ * Generates a character sheet (JSON) from a world + persona.
148
+ * Appearance images via **FLUX-dev**; narrative via **Zephyr-chat** (or Gemma fallback).
149
+ * Personas come from **FinePersonas-Lite**.
 
150
 
151
+ Tip β†’ Spin the world, then shuffle personas to see very different heroes.
 
152
  """
153
 
154
  with gr.Blocks(title="Character Generator", theme="Nymbo/Nymbo_Theme") as demo:
155
+ gr.Markdown("<h1 style='text-align:center'>πŸ§šβ€β™€οΈ Character Generator</h1>")
156
+ gr.Markdown(DESCRIPTION.strip())
157
 
158
  with gr.Row():
159
+ world_tb = gr.Textbox(label="World Description", lines=10, scale=4)
160
+ persona_tb = gr.Textbox(
161
+ label="Persona Description", value=random_persona(), lines=10, scale=1
 
 
 
162
  )
163
 
164
  with gr.Row():
165
+ btn_world = gr.Button("πŸ”„ Random World", variant="secondary")
166
+ btn_generate = gr.Button("✨ Generate Character", variant="primary", scale=5)
167
+ btn_persona = gr.Button("πŸ”„ Random Persona", variant="secondary")
168
 
169
  with gr.Row():
170
+ img_out = gr.Image(label="Character Image")
171
+ json_out = gr.JSON(label="Character Description")
172
+
173
+ btn_generate.click(
174
+ generate_character, [world_tb, persona_tb], [json_out]
 
 
 
175
  ).then(
176
+ infer_flux, [json_out], [img_out]
 
 
177
  )
178
 
179
+ btn_world.click(random_world, outputs=[world_tb])
180
+ btn_persona.click(random_persona, outputs=[persona_tb])
 
 
 
 
 
 
181
 
182
  demo.queue().launch(share=False)
183
+