#!/usr/bin/env python3
"""
what_comes_next.py – Hugging Face Space implementation of **What Comes Next**
A global, slow-burn guessing game powered by Llama-3.1-8B-Instruct (FP32, CPU-only).

HOW IT WORKS
============
• One shared model generates a single, very long completion (≈2k tokens) for a chosen
  prompt in *full precision* on CPU. One token is sampled roughly every 15 s (sleep
  interval plus CPU generation time), so a prompt unfolds over roughly 10 hours.
  All visitors see the same progress in real time.
• Players read the partial output and may submit **either**
  🧠 Exact continuation (full guess) **or** 💡 General idea (summary guess).
• Each guess is appended to `data.json` with prompt, Oracle progress, timestamp & type.
• Offline scoring (not included here) can later measure similarity vs the final text.
"""
from __future__ import annotations

import os
import json
import time
import random
import threading
import logging
from datetime import datetime, timezone
from typing import Dict, Any

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
###############################################################################
#                                Configuration                                #
###############################################################################
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"   # full-precision model
PROMPTS_PATH = "full_prompts.json"                # 100 full prompts
STATE_PATH = "current_state.json"                 # persistent Oracle state
DATA_PATH = "data.json"                           # JSONL log of guesses
TOKENS_PER_PROMPT = 2048    # stop after N generated tokens
SECS_BETWEEN_TOKENS = 15    # sleep between tokens (~10 h per prompt incl. CPU generation time)
TEMPERATURE = 0.9           # higher creativity, as requested
TOP_P = 0.95                # nucleus sampling
MAX_CONTEXT_TOKENS = 8192   # safety cap

###############################################################################
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
log = logging.getLogger("what-comes-next")
###############################################################################
#                               Utility helpers                               #
###############################################################################
def _read_json(path: str, default: Any):
    """Read JSON from *path*; return *default* if the file does not exist."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return default


def _atomic_write(path: str, obj: Any):
    """Write JSON atomically: dump to a temp file, then rename over the target."""
    tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp, path)


def load_prompts() -> list[str]:
    """Load the prompt pool from PROMPTS_PATH, validating its shape."""
    if not os.path.exists(PROMPTS_PATH):
        raise FileNotFoundError(f"Missing {PROMPTS_PATH}. Please add 100 prompts.")
    with open(PROMPTS_PATH, "r", encoding="utf-8") as f:
        prompts = json.load(f)
    if not isinstance(prompts, list) or not prompts:
        raise ValueError("full_prompts.json must be a non-empty JSON array of strings")
    return prompts
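# Expected shape of full_prompts.json: a plain JSON array of prompt strings.
# The entries below are invented examples, not the project's actual prompts:
#
#   ["Write a story that begins at the end of the world...",
#    "Explain, step by step, how a lighthouse keeper spends a stormy night...",
#    ...]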
###############################################################################
#                                Model loading                                #
###############################################################################
log.info("Loading Llama-3.1-8B-Instruct (FP32 CPU-only)… this can take a while.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    device_map={"": "cpu"},  # force CPU placement
)
model.eval()
log.info("Model ready – Oracle awakened.")
###############################################################################
#                                 Global state                                #
###############################################################################
lock = threading.Lock()   # guards Oracle state + files
prompts = load_prompts()  # list of 100 strings

###############################################################################
#                          Oracle generation thread                           #
###############################################################################
def _init_state() -> Dict[str, Any]:
    """Return existing state or create a fresh one if none exists or the last prompt finished."""
    state = _read_json(STATE_PATH, {})
    if not state or state.get("finished"):
        prompt_idx = random.randrange(len(prompts))
        state = {
            "prompt_idx": prompt_idx,
            "prompt": prompts[prompt_idx],
            "generated": "",      # text so far
            "tokens_done": 0,
            "start_time": time.time(),
            "finished": False,
        }
        _atomic_write(STATE_PATH, state)
        log.info(f"New Oracle prompt #{prompt_idx}: {state['prompt'][:80]}…")
    return state
def _elapsed_str(start: float) -> str:
    """Format the seconds elapsed since *start* as 'Xh Ym Zs'."""
    d = int(time.time() - start)
    h, r = divmod(d, 3600)
    m, s = divmod(r, 60)
    return f"{h}h {m}m {s}s"


def oracle_loop():
    """Background thread: sample one token at a time and persist progress to disk."""
    while True:
        with lock:
            state = _init_state()
        if state["finished"]:
            time.sleep(SECS_BETWEEN_TOKENS)
            continue
        # Build context: prompt + generated so far
        context = state["prompt"] + state["generated"]
        enc = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS)
        # Sample exactly one new token
        with torch.no_grad():
            out = model.generate(
                **enc,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                pad_token_id=tokenizer.eos_token_id,  # avoid missing-pad-token warning
            )
        next_token = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        with lock:
            state["generated"] += next_token
            state["tokens_done"] += 1
            if state["tokens_done"] >= TOKENS_PER_PROMPT:
                state["finished"] = True
                log.info("Prompt completed – Oracle will select a new one shortly.")
            _atomic_write(STATE_PATH, state)
        time.sleep(SECS_BETWEEN_TOKENS)


threading.Thread(target=oracle_loop, daemon=True).start()
###############################################################################
#                               Gradio interface                              #
###############################################################################
def fetch_state() -> tuple[str, str, str]:
    """Return (prompt, generated text, elapsed time) for display in the UI."""
    state = _read_json(STATE_PATH, {})
    if not state:
        return "Loading…", "", "0h 0m 0s"
    return state["prompt"], state["generated"], _elapsed_str(state["start_time"])


def submit_guess(full: str, idea: str):
    """Validate a player's guess and append it to data.json (JSONL, one record per line)."""
    full = full.strip()
    idea = idea.strip()
    if not full and not idea:
        return gr.update(value="⚠️ Enter a guess in one of the fields."), gr.update(), gr.update()
    prompt, generated, elapsed = fetch_state()
    guess_text = full or idea
    guess_type = "full" if full else "idea"
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "prompt": prompt,
        "point-in-time": elapsed,
        "response-point": generated,
        "user-guess": guess_text,
        "guess-type": guess_type,
    }
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    log.info(f"Logged {guess_type} guess ({len(guess_text)} chars).")
    return gr.update(value="✅ Guess recorded – thanks!"), gr.update(value=""), gr.update(value="")
with gr.Blocks(title="What Comes Next", theme="gradio/soft") as demo:
    gr.Markdown("""# 🔮 What Comes Next
Watch the Oracle craft an extended response – **one token at a time**. Predict its
next words or general direction and see how close you were when the tale concludes.
(All inputs are stored in `data.json` for research.)""")

    prompt_md = gr.Markdown()
    oracle_box = gr.Textbox(lines=10, interactive=False, label="📜 Oracle text so far")
    elapsed_tb = gr.Textbox(interactive=False, label="⏱ Elapsed time")
    refresh_btn = gr.Button("🔄 Refresh")

    with gr.Row():
        exact_tb = gr.Textbox(label="🧠 Exact continuation (full)")
        idea_tb = gr.Textbox(label="💡 General idea")
    submit_btn = gr.Button("Submit Guess")
    status_tb = gr.Textbox(interactive=False, label="Status")

    # Actions
    refresh_btn.click(fetch_state, outputs=[prompt_md, oracle_box, elapsed_tb])
    demo.load(fetch_state, outputs=[prompt_md, oracle_box, elapsed_tb])
    submit_btn.click(submit_guess,
                     inputs=[exact_tb, idea_tb],
                     outputs=[status_tb, exact_tb, idea_tb])
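# ---------------------------------------------------------------------------
# Offline scoring sketch (illustrative only). The module docstring notes that
# scoring is not part of this Space; this helper shows one way it *could* work,
# assuming a logged guess and the Oracle's final text are available. The
# difflib-based similarity metric is an assumption, not the project's official
# scorer, and nothing in the app calls this function.
# ---------------------------------------------------------------------------
def score_guess_offline(guess: str, final_text: str, already_seen: str = "") -> float:
    """Return a rough 0..1 similarity between a guess and the not-yet-seen continuation."""
    from difflib import SequenceMatcher
    remaining = final_text[len(already_seen):]  # portion the player was trying to predict
    return SequenceMatcher(None, guess.strip(), remaining.strip()).ratio()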
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)