Jensin committed
Commit 745adc6 · verified · 1 Parent(s): f52b5b1

Update app.py

Files changed (1): app.py (+48, -37)
app.py CHANGED
@@ -1,4 +1,5 @@
-# app.py ──────────────────────────────────────────────────────────────────────
+# app.py — Robust Character Generator Space
+
 from datasets import load_dataset
 import gradio as gr, json, os, random, torch, spaces
 from diffusers import FluxPipeline, AutoencoderKL
@@ -7,10 +8,10 @@ from live_preview_helpers import (
     flux_pipe_call_that_returns_an_iterable_of_images as flux_iter,
 )
 
-# ─────────────────────────── 1. Device ─────────────────────────────────────
+# 1. Device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# ────────────────────── 2. FLUX image pipeline ─────────────────────────────
+# 2. FLUX image pipeline
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
 ).to(device)
@@ -19,7 +20,7 @@ good_vae = AutoencoderKL.from_pretrained(
 ).to(device)
 pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_iter.__get__(pipe)
 
-# ───────────────────────── 3. LLM client (robust) ──────────────────────────
+# 3. LLM client (robust)
 LLM_SPACES = [
     "https://huggingfaceh4-zephyr-chat.hf.space",
     "meta-llama/Llama-3.3-70B-Instruct",
@@ -27,14 +28,11 @@ LLM_SPACES = [
 ]
 
 def first_live_space(space_ids: list[str]) -> Client:
-    """
-    Return the first Space whose /chat endpoint answers a 1-token echo.
-    """
     for sid in space_ids:
         try:
             print(f"[info] probing {sid}")
             c = Client(sid, hf_token=os.getenv("HF_TOKEN"))
-            _ = c.predict("ping", 8, api_name="/chat")  # simple health check
+            _ = c.predict("ping", 8, api_name="/chat")
             print(f"[info] using {sid}")
             return c
         except Exception as e:
@@ -42,18 +40,12 @@ def first_live_space(space_ids: list[str]) -> Client:
     raise RuntimeError("No live chat Space found!")
 
 llm_client = first_live_space(LLM_SPACES)
-CHAT_API = "/chat"  # universal endpoint for TGI-style Spaces
+CHAT_API = "/chat"
 
 def call_llm(prompt: str,
              max_tokens: int = 256,
              temperature: float = 0.6,
              top_p: float = 0.9) -> str:
-    """
-    Send a single-message chat to the Space. Extra sliders in the remote UI must
-    be supplied in positional order after the prompt, so we match Zephyr/Gemma:
-    [prompt, max_tokens, temperature, top_p, repeat_penalty, presence_penalty]
-    We pass only the first four; the Space will fill the rest with defaults.
-    """
     try:
         return llm_client.predict(
             prompt, max_tokens, temperature, top_p, api_name=CHAT_API
@@ -62,26 +54,21 @@
         print(f"[error] LLM failure → {exc}")
         return "…"
 
-# ──────────────────────── 4. Persona dataset ──────────────────────────────
+# 4. Persona dataset
 ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")
 
 def random_persona() -> str:
     return ds[random.randint(0, len(ds) - 1)]["persona"]
 
-# ─────────────────────────── 5. Text prompts ───────────────────────────────
+# 5. Text prompts
 PROMPT_TEMPLATE = """Generate a character with this persona description:
-
 {persona_description}
-
 In a world with this description:
-
 {world_description}
-
 Write the character in JSON with keys:
 name, background, appearance, personality, skills_and_abilities,
 goals, conflicts, backstory, current_situation,
 spoken_lines (list of strings).
-
 Respond with JSON only (no markdown)."""
 
 WORLD_PROMPT = (
@@ -89,12 +76,25 @@ WORLD_PROMPT = (
     "Respond with the description only."
 )
 
-# ───────────────────────── 6. Helper functions ─────────────────────────────
+# 6. Helper functions
 def random_world() -> str:
     return call_llm(WORLD_PROMPT, max_tokens=120)
 
+def safe_json_parse(raw):
+    """Try to parse JSON; log and return None on failure."""
+    try:
+        return json.loads(raw)
+    except Exception as e:
+        print(f"[ERROR] JSON parsing failed: {e}")
+        print(f"[DEBUG] Raw output: {raw[:1000]}")
+        return None
+
 @spaces.GPU(duration=75)
 def infer_flux(character_json):
+    # Defensive: if not a dict or missing appearance, bail out
+    if not isinstance(character_json, dict) or "appearance" not in character_json:
+        print("[ERROR] No valid appearance to generate image.")
+        return None
     for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
         prompt=character_json["appearance"],
         guidance_scale=3.5,
@@ -109,6 +109,7 @@ def infer_flux(character_json):
 
 def generate_character(world_desc: str, persona_desc: str,
                        progress=gr.Progress(track_tqdm=True)):
+    # First attempt
    raw = call_llm(
         PROMPT_TEMPLATE.format(
             persona_description=persona_desc,
@@ -116,25 +117,35 @@
         ),
         max_tokens=1024,
     )
-    try:
-        return json.loads(raw)
-    except json.JSONDecodeError:
-        # retry once if the model didn't return valid JSON
-        raw = call_llm(
-            PROMPT_TEMPLATE.format(
-                persona_description=persona_desc,
-                world_description=world_desc,
-            ),
-            max_tokens=1024,
-        )
-        return json.loads(raw)
+    character = safe_json_parse(raw)
+    if character:
+        return character
 
-# ─────────────────────────── 7. Gradio UI ──────────────────────────────────
+    # Retry once
+    raw2 = call_llm(
+        PROMPT_TEMPLATE.format(
+            persona_description=persona_desc,
+            world_description=world_desc,
+        ),
+        max_tokens=1024,
+    )
+    character2 = safe_json_parse(raw2)
+    if character2:
+        return character2
+
+    # If both fail, return the error and raw outputs for debugging
+    return {
+        "error": "LLM did not return valid JSON after 2 attempts.",
+        "first_raw": raw,
+        "second_raw": raw2,
+        "tip": "Check your LLM prompt and output. Try regenerating.",
+    }
+
+# 7. Gradio UI
 DESCRIPTION = """
 * Generates a JSON character sheet from a world + persona.
 * Appearance images via **FLUX-dev**; story text via Zephyr-chat or Gemma fallback.
 * Personas sampled from **FinePersonas-Lite**.
-
 Tip → Shuffle the world then persona for rapid inspiration.
 """
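A side note on the monkey-patching line kept in section 2: `flux_iter.__get__(pipe)` uses Python's descriptor protocol to turn a plain function into a method bound to `pipe`, so the helper receives the pipeline as `self`. A minimal sketch of the same trick on a toy class (all names here are illustrative):

```python
class Pipeline:
    scale = 2

def double_scale(self):                  # plain function defined outside the class
    return self.scale * 2

p = Pipeline()
p.double = double_scale.__get__(p)       # bind it so `self` becomes `p`
assert p.double() == 4
```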
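The docstring removed from `call_llm` documented why its arguments are positional: TGI-style chat Spaces (Zephyr/Gemma) order their `/chat` inputs as `[prompt, max_tokens, temperature, top_p, repeat_penalty, presence_penalty]`, only the first four are passed, and the Space fills the rest from its own defaults. A minimal standalone sketch of that convention with `gradio_client` (the prompt string and slider values are illustrative only):

```python
from gradio_client import Client

# Probe the same Space the app uses; an HF token is only needed for gated Spaces.
client = Client("https://huggingfaceh4-zephyr-chat.hf.space")

# Positional order after the prompt: max_tokens, temperature, top_p.
# The remaining sliders (repeat_penalty, presence_penalty) fall back to
# the Space's defaults when omitted.
reply = client.predict(
    "Say 'pong'.",   # prompt
    8,               # max_tokens
    0.6,             # temperature
    0.9,             # top_p
    api_name="/chat",
)
print(reply)
```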
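The retry exists because `json.loads` fails whenever the model decorates its reply, for example with a chatty preamble or markdown fences, despite the "JSON only" instruction. A quick illustration of how `safe_json_parse` behaves on such replies (both strings are hypothetical model outputs):

```python
clean  = '{"name": "Ava", "appearance": "tall, silver-haired"}'
chatty = 'Sure! Here is your character: {"name": "Ava"}'

assert safe_json_parse(clean) == {"name": "Ava", "appearance": "tall, silver-haired"}
assert safe_json_parse(chatty) is None   # json.loads raises; the error is logged
```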
151