#!/usr/bin/env python3
"""
what_comes_next.py – Hugging Face Space implementation of **What Comes Next**

A global, slow-burn guessing game powered by Llama-3.1-8B-Instruct (FP32, CPU-only).

HOW IT WORKS
============
• One shared model generates a single, very long completion (≈2k tokens) for a
  chosen prompt in *full precision* on CPU. One token is sampled every ~15 s,
  so a prompt unfolds over roughly 10 hours. All visitors see the same
  progress in real time.
• Players read the partial output and may submit **either**
  🧠 Exact continuation (full guess) **or** 💡 General idea (summary guess).
• Each guess is appended to `data.json` with prompt, Oracle progress,
  timestamp & type.
• Offline scoring (not included here) can later measure similarity vs the
  final text.
"""
from __future__ import annotations

import os
import json
import time
import random
import threading
import logging
from datetime import datetime, timezone
from typing import Dict, Any

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

###############################################################################
#                                Configuration                                #
###############################################################################
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"  # full-precision model
PROMPTS_PATH = "full_prompts.json"               # 100 full prompts
STATE_PATH = "current_state.json"                # persistent Oracle state
DATA_PATH = "data.json"                          # JSONL log of guesses

TOKENS_PER_PROMPT = 2048     # stop after N generated tokens
SECS_BETWEEN_TOKENS = 15     # ~10 h per prompt
TEMPERATURE = 0.9            # higher creativity, as requested
TOP_P = 0.95                 # nucleus sampling
MAX_CONTEXT_TOKENS = 8192    # safety cap

###############################################################################
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
log = logging.getLogger("what-comes-next")

###############################################################################
#                              Utility helpers                                #
###############################################################################

def _read_json(path: str, default: Any):
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return default


def _atomic_write(path: str, obj: Any):
    # Write to a temp file, then atomically replace the target.
    tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp, path)


def load_prompts() -> list[str]:
    if not os.path.exists(PROMPTS_PATH):
        raise FileNotFoundError(f"Missing {PROMPTS_PATH}. Please add 100 prompts.")
    with open(PROMPTS_PATH, "r", encoding="utf-8") as f:
        prompts = json.load(f)
    if not isinstance(prompts, list) or not prompts:
        raise ValueError("full_prompts.json must be a non-empty JSON array of strings")
    return prompts

###############################################################################
#                               Model loading                                 #
###############################################################################
log.info("Loading Llama-3.1-8B-Instruct (FP32 CPU-only)… this can take a while.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    device_map={"": "cpu"},  # force CPU placement
)
model.eval()
log.info("Model ready – Oracle awakened.")

###############################################################################
#                                Global state                                 #
###############################################################################
lock = threading.Lock()   # guards state + files
prompts = load_prompts()  # list of 100 prompt strings

###############################################################################
#                         Oracle generation thread                            #
###############################################################################

def _init_state() -> Dict[str, Any]:
    """Return the existing state, or create a fresh one if none exists or the
    previous prompt has finished."""
    state = _read_json(STATE_PATH, {})
    if not state or state.get("finished"):
        prompt_idx = random.randrange(len(prompts))
        state = {
            "prompt_idx": prompt_idx,
            "prompt": prompts[prompt_idx],
            "generated": "",   # text generated so far
            "tokens_done": 0,
            "start_time": time.time(),
            "finished": False,
        }
        _atomic_write(STATE_PATH, state)
        log.info(f"New Oracle prompt #{prompt_idx}: {state['prompt'][:80]}…")
    return state


def _elapsed_str(start: float) -> str:
    d = int(time.time() - start)
    h, r = divmod(d, 3600)
    m, s = divmod(r, 60)
    return f"{h}h {m}m {s}s"


def oracle_loop():
    while True:
        with lock:
            state = _init_state()
        if state["finished"]:
            time.sleep(SECS_BETWEEN_TOKENS)
            continue

        # Build context: prompt + everything generated so far
        context = state["prompt"] + state["generated"]
        input_ids = tokenizer(
            context,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_CONTEXT_TOKENS,
        ).input_ids

        # Sample exactly one new token
        with torch.no_grad():
            out = model.generate(
                input_ids,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMPERATURE,
                top_p=TOP_P,
            )
        next_token = tokenizer.decode(
            out[0, -1],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        with lock:
            state["generated"] += next_token
            state["tokens_done"] += 1
            if state["tokens_done"] >= TOKENS_PER_PROMPT:
                state["finished"] = True
                log.info("Prompt completed – Oracle will select a new one shortly.")
            _atomic_write(STATE_PATH, state)

        time.sleep(SECS_BETWEEN_TOKENS)


threading.Thread(target=oracle_loop, daemon=True).start()

###############################################################################
#                              Gradio interface                               #
###############################################################################

def fetch_state() -> tuple[str, str, str]:
    state = _read_json(STATE_PATH, {})
    if not state:
        return "Loading…", "", "0h 0m 0s"
    return state["prompt"], state["generated"], _elapsed_str(state["start_time"])


def submit_guess(full: str, idea: str):
    full = full.strip()
    idea = idea.strip()
    if not full and not idea:
        return gr.update(value="⚠️ Enter a guess in one of the fields."), gr.update(), gr.update()

    prompt, generated, elapsed = fetch_state()
    guess_text = full or idea
    guess_type = "full" if full else "idea"
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "prompt": prompt,
        "point-in-time": elapsed,
        "response-point": generated,
        "user-guess": guess_text,
        "guess-type": guess_type,
    }
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    log.info(f"Logged {guess_type} guess ({len(guess_text)} chars).")
    return gr.update(value="✅ Guess recorded – thanks!"), gr.update(value=""), gr.update(value="")


with gr.Blocks(title="What Comes Next", theme="gradio/soft") as demo:
    gr.Markdown(
        """# 🌌 What Comes Next
Watch the Oracle craft an extended response – **one token at a time**.
Predict its next words or general direction and see how close you were when the tale concludes.
(All inputs are stored in `data.json` for research.)"""
    )

    prompt_md = gr.Markdown()
    oracle_box = gr.Textbox(lines=10, interactive=False, label="📜 Oracle text so far")
    elapsed_tb = gr.Textbox(interactive=False, label="⏱ Elapsed time")
    refresh_btn = gr.Button("🔄 Refresh")

    with gr.Row():
        exact_tb = gr.Textbox(label="🧠 Exact continuation (full)")
        idea_tb = gr.Textbox(label="💡 General idea")
    submit_btn = gr.Button("Submit Guess")
    status_tb = gr.Textbox(interactive=False, label="Status")

    # Actions
    refresh_btn.click(fetch_state, outputs=[prompt_md, oracle_box, elapsed_tb])
    demo.load(fetch_state, outputs=[prompt_md, oracle_box, elapsed_tb])
    submit_btn.click(submit_guess, inputs=[exact_tb, idea_tb], outputs=[status_tb, exact_tb, idea_tb])


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
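
# ---------------------------------------------------------------------------
# Illustrative sketch (not executed by the Space): reading the guess log back
# for the offline scoring mentioned in the module docstring. It assumes
# data.json holds one JSON object per line, exactly as written by
# submit_guess() above; the field names mirror its `record` dict.
#
#   import json
#   with open("data.json", "r", encoding="utf-8") as f:
#       guesses = [json.loads(line) for line in f if line.strip()]
#   full_guesses = [g for g in guesses if g["guess-type"] == "full"]
#   idea_guesses = [g for g in guesses if g["guess-type"] == "idea"]
# ---------------------------------------------------------------------------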