#!/usr/bin/env python3
"""
what_comes_next.py – Hugging Face Space implementation of **What Comes Next**

A slow, contemplative global guessing game.

🔮 HOW IT WORKS 🔮
• A single Llama-3.1-8B-Instruct model (FP32 on CPU) generates one very long
  completion for a chosen mystical prompt. It runs continuously in the
  background for everyone.
• Any visitor sees the same prompt and the Oracle’s current partial response.
• Players may submit *one* of two kinds of guesses:
    1. 🧠 **Exact Completion** – the full sentence/paragraph they think the
       Oracle will eventually write.
    2. 💡 **General Idea** – a short summary of the direction or theme they
       expect.
• Each guess is recorded immediately (with timestamp, Oracle progress, etc.)
  to `data.json` (JSON Lines). When the Oracle finally finishes, offline
  evaluation can score the guesses against the final text.

The game then moves on to the next prompt and the cycle repeats.
"""
import os
import json
import time
import random
import threading
import logging
from datetime import datetime, timezone
from typing import Dict, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import gradio as gr
###############################################################################
#                                  Settings                                   #
###############################################################################
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"   # FP32, CPU-only
PROMPTS_PATH = "oracle_prompts.json"              # 100 unfinished lines
STATE_PATH = "current_state.json"                 # persistent Oracle state
DATA_PATH = "data.json"                           # JSONL of user guesses
TOKENS_PER_PROMPT = 2048                          # stop after N tokens
SECS_BETWEEN_TOKENS = 15                          # pacing (2048 × 15 s ≈ 8.5 h per prompt)
TEMPERATURE = 0.8
TOP_P = 0.95
MAX_CONTEXT_TOKENS = 8192
###############################################################################

logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
log = logging.getLogger("what-comes-next")

lock = threading.Lock()  # global file/variable lock
# --------------------------------------------------------------------------- #
#                               Helper functions                              #
# --------------------------------------------------------------------------- #
def _read_json(path: str, default: Any):
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return default


def _write_json(path: str, obj: Any):
    # Write to a temp file, then atomically replace the target so readers
    # never observe a half-written state file.
    tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp, path)
def load_prompts() -> list[str]:
    if not os.path.exists(PROMPTS_PATH):
        raise FileNotFoundError(f"Missing {PROMPTS_PATH}. Please add 100 prompts.")
    with open(PROMPTS_PATH, "r", encoding="utf-8") as f:
        return json.load(f)


prompts = load_prompts()
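# oracle_prompts.json is expected to hold a flat JSON array of unfinished
# opening lines. Illustrative shape (these are not the shipped prompts):
#
#   ["Beneath the frozen river, something began to",
#    "The cartographer drew one final island, and"]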
# --------------------------------------------------------------------------- #
#                           Model loading (FP32, CPU)                          #
# --------------------------------------------------------------------------- #
log.info("Loading Llama-3.1-8B-Instruct in FP32 on CPU (this is *slow*) …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    device_map={"": "cpu"},  # force CPU
)
model.eval()
log.info("Model loaded.")
# --------------------------------------------------------------------------- #
#                           Oracle generation thread                           #
# --------------------------------------------------------------------------- #
def init_state() -> Dict[str, Any]:
    """Return existing state or create a new one."""
    state = _read_json(STATE_PATH, {})
    if state.get("finished", False):
        state = {}  # previous prompt finished – start a new one
    if not state:
        prompt_idx = random.randrange(len(prompts))
        prompt = prompts[prompt_idx]
        state = {
            "prompt_idx": prompt_idx,
            "prompt": prompt,
            "generated": "",      # Oracle’s text so far (string)
            "start_time": time.time(),
            "finished": False,
            "tokens_done": 0,
        }
        _write_json(STATE_PATH, state)
        log.info(f"Starting new Oracle prompt #{prompt_idx}: {prompt[:60]}…")
    return state
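# A freshly initialised current_state.json therefore looks like this
# (illustrative values):
#
#   {
#     "prompt_idx": 42,
#     "prompt": "Beneath the frozen river, something began to",
#     "generated": "",
#     "start_time": 1735689600.0,
#     "finished": false,
#     "tokens_done": 0
#   }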
def oracle_loop():
    """Continuously extend the Oracle’s text by one token every SECS_BETWEEN_TOKENS."""
    while True:
        with lock:
            state = init_state()
            finished = state["finished"]
            prompt_text = state["prompt"]
            generated_text = state["generated"]
        if finished:
            # Should not happen (init_state resets finished prompts), but guard
            # anyway – and sleep outside the lock so readers are not blocked.
            time.sleep(SECS_BETWEEN_TOKENS)
            continue

        # Build input ids from the prompt plus everything generated so far.
        # Truncation keeps the *first* MAX_CONTEXT_TOKENS tokens; with a
        # 2048-token budget per prompt the limit should never be reached.
        full_input = prompt_text + generated_text
        inputs = tokenizer(full_input, return_tensors="pt",
                           truncation=True, max_length=MAX_CONTEXT_TOKENS)

        # Generate exactly ONE token.
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                pad_token_id=tokenizer.eos_token_id,  # Llama has no pad token
            )
        next_token_id = outputs[0, -1].unsqueeze(0)
        next_token_text = tokenizer.decode(next_token_id, skip_special_tokens=True,
                                           clean_up_tokenization_spaces=False)

        with lock:
            # Update and persist the shared state.
            state["generated"] += next_token_text
            state["tokens_done"] += 1
            if state["tokens_done"] >= TOKENS_PER_PROMPT:
                state["finished"] = True
                log.info("Prompt complete. Oracle will pick a new one next cycle.")
            _write_json(STATE_PATH, state)

        time.sleep(SECS_BETWEEN_TOKENS)  # pacing


threading.Thread(target=oracle_loop, daemon=True).start()
# --------------------------------------------------------------------------- #
#                               Gradio Interface                               #
# --------------------------------------------------------------------------- #
def human_readable_elapsed(start: float) -> str:
    delta = int(time.time() - start)
    h, rem = divmod(delta, 3600)
    m, s = divmod(rem, 60)
    return f"{h}h {m}m {s}s"


def get_current_state() -> Dict[str, Any]:
    with lock:
        state = _read_json(STATE_PATH, {})
    if not state:
        return {"prompt": "…loading…", "generated": "", "elapsed": "0h 0m 0s"}
    return {
        "prompt": state["prompt"],
        "generated": state["generated"],
        "elapsed": human_readable_elapsed(state["start_time"]),
    }
def record_guess(full_guess: str, idea_guess: str):
    state = get_current_state()
    guess_text = full_guess.strip() or idea_guess.strip()
    if not guess_text:
        return gr.update(value="⚠️ Please enter a guess in one of the boxes …"), gr.update()
    guess_type = "full" if full_guess.strip() else "idea"
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "prompt": state["prompt"],
        "point-in-time": state["elapsed"],
        "response-point": state["generated"],
        "user-guess": guess_text,
        "guess-type": guess_type,
    }
    # Append one record per line to data.json (JSON Lines).
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    log.info(f"Recorded {guess_type} guess ({len(guess_text)} chars).")
    return gr.update(value="✅ Guess recorded – check back when the Oracle finishes!"), gr.update(value="")
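# The module docstring notes that, once the Oracle finishes, guesses can be
# scored offline against the final text. That scoring is not implemented in
# this Space; the helper below is a minimal sketch of what it could look like.
# The name `score_guesses` and the difflib similarity metric are illustrative
# assumptions – nothing in the app calls this function.
def score_guesses(final_text: str, data_path: str = DATA_PATH):
    """Yield (record, similarity) for every guess recorded in the JSONL file."""
    import difflib  # stdlib; imported locally because only this sketch uses it
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            ratio = difflib.SequenceMatcher(None, record["user-guess"], final_text).ratio()
            yield record, ratio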
with gr.Blocks(title="What Comes Next", theme="gradio/soft") as demo:
    gr.Markdown("""# ✨ What Comes Next
A global, slow-burn guessing game. The Oracle is continuously writing its story.
Read the prompt, see the Oracle’s progress, and predict **what comes next**!

*(FP32 CPU inference – deliberately unhurried.)*""")

    ### Live Oracle view
    prompt_box = gr.Markdown(label="🔮 Current Oracle Prompt")
    oracle_box = gr.Textbox(label="📜 Oracle’s current text", lines=10, interactive=False)
    elapsed_box = gr.Textbox(label="⏱️ Elapsed", interactive=False)

    ### Guess inputs
    gr.Markdown("**Make your prediction:** Fill **either** the exact continuation *or* a general idea.")
    with gr.Row():
        full_guess = gr.Textbox(label="🧠 Exact continuation (full)")
        idea_guess = gr.Textbox(label="💡 General idea")
    submit_btn = gr.Button("Submit Guess")
    status_msg = gr.Textbox(label="Status", interactive=False)

    ### Refresh button
    refresh_btn = gr.Button("🔄 Refresh Oracle progress")

    def refresh():
        st = get_current_state()
        return st["prompt"], st["generated"], st["elapsed"]

    refresh_btn.click(refresh, outputs=[prompt_box, oracle_box, elapsed_box])
    demo.load(refresh, outputs=[prompt_box, oracle_box, elapsed_box])  # auto-load on launch

    submit_btn.click(record_guess,
                     inputs=[full_guess, idea_guess],
                     outputs=[status_msg, full_guess])  # clear full_guess box on success


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)