#!/usr/bin/env python3
"""
what_comes_next.py – Hugging Face Space implementation of **What Comes Next**
A slow, contemplative global guessing game.
🔮 HOW IT WORKS 🔮
• A single Llama-3.1-8B-Instruct model (FP32 on CPU) generates one very long completion
  for a chosen mystical prompt, running continuously in the background for everyone.
• Any visitor sees the same prompt and the Oracle’s current partial response.
• Players may submit *one* of two kinds of guesses:
1. 🧠 **Exact Completion** – the full sentence/paragraph they think the Oracle will
eventually write.
2. 💡 **General Idea** – a short summary of the direction or theme they expect.
• Each guess is recorded immediately (with timestamp, Oracle progress, etc.) to
`data.json` (JSON‑Lines). When the Oracle finally finishes, offline evaluation can
score the guesses against the final text.
The game then moves on to the next prompt and the cycle repeats.
"""
import os
import json
import time
import random
import threading
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
###############################################################################
# Settings #
###############################################################################
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct" # FP32, CPU‑only
PROMPTS_PATH = "oracle_prompts.json" # 100 unfinished lines
STATE_PATH = "current_state.json" # persistent Oracle state
DATA_PATH = "data.json" # JSONL of user guesses
TOKENS_PER_PROMPT = 2048 # stop after N tokens
SECS_BETWEEN_TOKENS = 15                  # pacing: 2048 tokens × 15 s ≈ 8.5 h/prompt, plus CPU latency
TEMPERATURE = 0.8
TOP_P = 0.95
MAX_CONTEXT_TOKENS = 8192
###############################################################################
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
log = logging.getLogger("what-comes-next")
lock = threading.Lock() # global file/variable lock
# --------------------------------------------------------------------------- #
# Helper functions #
# --------------------------------------------------------------------------- #
def _read_json(path: str, default: Any):
try:
        with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except FileNotFoundError:
return default
def _write_json(path: str, obj: Any):
tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as f:
json.dump(obj, f, ensure_ascii=False, indent=2)
os.replace(tmp, path)
def load_prompts() -> list[str]:
if not os.path.exists(PROMPTS_PATH):
raise FileNotFoundError(f"Missing {PROMPTS_PATH}. Please add 100 prompts.")
    with open(PROMPTS_PATH, "r", encoding="utf-8") as f:
return json.load(f)
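# oracle_prompts.json is expected to be a JSON array of prompt strings, e.g.
# ["The last door opened onto", "Beneath a sky of static,", ...]  (illustrative entries)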
prompts = load_prompts()
# --------------------------------------------------------------------------- #
# Model loading (FP32 ‑ CPU) #
# --------------------------------------------------------------------------- #
log.info("Loading Llama‑3.1‑8B‑Instruct in FP32 on CPU (this is *slow*) …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float32,
device_map={"": "cpu"}, # force CPU
)
model.eval()
log.info("Model loaded.")
# --------------------------------------------------------------------------- #
# Oracle generation thread #
# --------------------------------------------------------------------------- #
def init_state() -> Dict[str, Any]:
"""Return existing state or create a new one."""
state = _read_json(STATE_PATH, {})
if state.get("finished", False):
state = {} # finished, start new prompt
if not state:
prompt_idx = random.randrange(len(prompts))
prompt = prompts[prompt_idx]
state = {
"prompt_idx": prompt_idx,
"prompt": prompt,
"generated": "", # Oracle’s text so far (string)
"start_time": time.time(),
"finished": False,
"tokens_done": 0
}
_write_json(STATE_PATH, state)
log.info(f"Starting new Oracle prompt #{prompt_idx}: {prompt[:60]}…")
return state
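# After init_state(), current_state.json holds a dict shaped like (illustrative values):
# {"prompt_idx": 42, "prompt": "...", "generated": "...",
#  "start_time": 1700000000.0, "finished": false, "tokens_done": 17}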
def oracle_loop():
"""Continuously extend the Oracle’s text by one token every SECS_BETWEEN_TOKENS."""
while True:
with lock:
state = init_state()
if state["finished"]:
# Should not happen, but guard anyway
time.sleep(SECS_BETWEEN_TOKENS)
continue
prompt_text = state["prompt"]
generated_text = state["generated"]
tokens_done = state["tokens_done"]
# Build input_ids (prompt + generated so far)
full_input = prompt_text + generated_text
        enc = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS)
        # Generate ONE token; pass the attention mask and a pad token id so
        # transformers does not warn about missing masks/padding
        with torch.no_grad():
            outputs = model.generate(
                enc.input_ids,
                attention_mask=enc.attention_mask,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                pad_token_id=tokenizer.eos_token_id,
            )
next_token_id = outputs[0, -1].unsqueeze(0)
next_token_text = tokenizer.decode(next_token_id, skip_special_tokens=True, clean_up_tokenization_spaces=False)
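        # Caveat: decoding tokens one at a time can mangle multi-byte characters
        # (they surface as U+FFFD); accumulating token ids and re-decoding the whole
        # sequence each step would be the more robust approach.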
with lock:
# Update state
state["generated"] += next_token_text
state["tokens_done"] += 1
if state["tokens_done"] >= TOKENS_PER_PROMPT:
state["finished"] = True
log.info("Prompt complete. Oracle will pick a new one next cycle.")
_write_json(STATE_PATH, state)
time.sleep(SECS_BETWEEN_TOKENS) # pacing
threading.Thread(target=oracle_loop, daemon=True).start()
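# The generator runs as a daemon thread so the process can exit without joining it;
# no progress is lost on restart because every token is persisted to current_state.json.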
# --------------------------------------------------------------------------- #
# Gradio Interface #
# --------------------------------------------------------------------------- #
def human_readable_elapsed(start: float) -> str:
delta = int(time.time() - start)
h, rem = divmod(delta, 3600)
m, s = divmod(rem, 60)
return f"{h}h {m}m {s}s"
def get_current_state() -> Dict[str, Any]:
with lock:
state = _read_json(STATE_PATH, {})
if not state:
return {"prompt": "…loading…", "generated": "", "elapsed": "0h 0m 0s"}
return {
"prompt": state["prompt"],
"generated": state["generated"],
"elapsed": human_readable_elapsed(state["start_time"])
}
def record_guess(full_guess: str, idea_guess: str):
state = get_current_state()
guess_text = full_guess.strip() or idea_guess.strip()
if not guess_text:
return gr.update(value="⚠️ Please enter a guess in one of the boxes …"), gr.update()
guess_type = "full" if full_guess.strip() else "idea"
record = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"prompt": state["prompt"],
"point‑in‑time": state["elapsed"],
"response‑point": state["generated"],
"user‑guess": guess_text,
"guess‑type": guess_type
}
# Append to JSONL (data.json)
with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
f.write(json.dumps(record, ensure_ascii=False) + "\n")
log.info(f"Recorded {guess_type} guess ({len(guess_text)} chars).")
return gr.update(value="✅ Guess recorded – check back when the Oracle finishes!"), gr.update(value="")
with gr.Blocks(title="What Comes Next", theme="gradio/soft") as demo:
gr.Markdown("""# ✨ What Comes Next
A global, slow‑burn guessing game. The Oracle is continuously writing its story.
Read the prompt, see the Oracle’s progress, and predict **what comes next**!
*(FP32 CPU inference – deliberately unhurried.)*""")
### Live Oracle view
prompt_box = gr.Markdown(label="🔮 Current Oracle Prompt")
oracle_box = gr.Textbox(label="📜 Oracle’s current text", lines=10, interactive=False)
elapsed_box = gr.Textbox(label="⏱️ Elapsed", interactive=False)
### Guess inputs
gr.Markdown("**Make your prediction:** Fill **either** the exact continuation *or* a general idea.")
with gr.Row():
full_guess = gr.Textbox(label="🧠 Exact continuation (full)")
idea_guess = gr.Textbox(label="💡 General idea")
submit_btn = gr.Button("Submit Guess")
status_msg = gr.Textbox(label="Status", interactive=False)
### Refresh button
refresh_btn = gr.Button("🔄 Refresh Oracle progress")
def refresh():
st = get_current_state()
return st["prompt"], st["generated"], st["elapsed"]
refresh_btn.click(refresh, outputs=[prompt_box, oracle_box, elapsed_box])
demo.load(refresh, outputs=[prompt_box, oracle_box, elapsed_box]) # auto‑load on launch
submit_btn.click(record_guess,
inputs=[full_guess, idea_guess],
outputs=[status_msg, full_guess]) # clear full_guess box on success
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)