#!/usr/bin/env python3
"""
what_comes_next.py – Hugging Face Space implementation of **What Comes Next**
A global, slow-burn guessing game powered by Llama-3.1-8B-Instruct (FP32, CPU-only).
HOW IT WORKS
============
• One shared model generates a single, very long completion (≈2k tokens) for a chosen
  prompt in *full precision* on CPU. One token is sampled every ~15 s, so a prompt
  unfolds over roughly 10 hours (2048 tokens × 15 s ≈ 8.5 h of pacing, plus the CPU
  generation time for each token). All visitors see the same progress in real time.
• Players read the partial output and may submit **either**
  🧠 an exact continuation (full guess) **or** 💡 a general idea (summary guess).
• Each guess is appended to `data.json` with prompt, Oracle progress, timestamp & type.
• Offline scoring (not included here) can later measure similarity vs the final text;
  an illustrative sketch appears near the end of this file.
"""
from __future__ import annotations
import os
import json
import time
import random
import threading
import logging
from datetime import datetime, timezone
from typing import Dict, Any
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
###############################################################################
# Configuration #
###############################################################################
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct" # full-precision model
PROMPTS_PATH = "full_prompts.json" # 100 full prompts
STATE_PATH = "current_state.json" # persistent Oracle state
DATA_PATH = "data.json" # JSONL log of guesses
TOKENS_PER_PROMPT = 2048 # stop after N generated tokens
SECS_BETWEEN_TOKENS = 15 # ~10 h per prompt
TEMPERATURE = 0.9 # higher creativity, as requested
TOP_P = 0.95 # nucleus sampling
MAX_CONTEXT_TOKENS = 8192 # safety cap
###############################################################################
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
log = logging.getLogger("what-comes-next")
###############################################################################
# Utility helpers #
###############################################################################
def _read_json(path: str, default: Any):
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return default

def _atomic_write(path: str, obj: Any):
    tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp, path)

def load_prompts() -> list[str]:
    if not os.path.exists(PROMPTS_PATH):
        raise FileNotFoundError(f"Missing {PROMPTS_PATH}. Please add 100 prompts.")
    with open(PROMPTS_PATH, "r", encoding="utf-8") as f:
        prompts = json.load(f)
    if not isinstance(prompts, list) or not prompts:
        raise ValueError("full_prompts.json must be a non-empty JSON array of strings")
    return prompts
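# `load_prompts()` above expects `full_prompts.json` to be a plain JSON array of
# prompt strings. The two entries below are illustrative placeholders only, not
# the prompts actually shipped with the Space:
#
#   [
#     "Write the opening scene of a mystery set aboard a generation ship.",
#     "Explain, step by step, how a lock-free queue works."
#   ]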
###############################################################################
# Model loading #
###############################################################################
log.info("Loading Llama-3.1-8B-Instruct (FP32 CPU-only)… this can take a while.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# FP32 weights for an 8B-parameter model need roughly 8e9 × 4 bytes ≈ 32 GB of RAM.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    device_map={"": "cpu"},  # force CPU placement
)
model.eval()
log.info("Model ready – Oracle awakened.")
###############################################################################
# Global state #
###############################################################################
lock = threading.Lock() # guard state + files
prompts = load_prompts() # list of 100 strings
###############################################################################
# Oracle generation thread #
###############################################################################
def _init_state() -> Dict[str, Any]:
    """Return existing state or create a fresh one if none/finished."""
    state = _read_json(STATE_PATH, {})
    if not state or state.get("finished"):
        prompt_idx = random.randrange(len(prompts))
        state = {
            "prompt_idx": prompt_idx,
            "prompt": prompts[prompt_idx],
            "generated": "",  # text so far
            "tokens_done": 0,
            "start_time": time.time(),
            "finished": False,
        }
        _atomic_write(STATE_PATH, state)
        log.info(f"New Oracle prompt #{prompt_idx}: {state['prompt'][:80]}…")
    return state

def _elapsed_str(start: float) -> str:
    d = int(time.time() - start)
    h, r = divmod(d, 3600)
    m, s = divmod(r, 60)
    return f"{h}h {m}m {s}s"
def oracle_loop():
    """Background thread: grow the shared completion one token at a time."""
    while True:
        with lock:
            state = _init_state()
            if state["finished"]:
                time.sleep(SECS_BETWEEN_TOKENS)
                continue
            # Build context: prompt + generated so far
            context = state["prompt"] + state["generated"]
        # Tokenise and sample outside the lock so guess submissions stay responsive.
        input_ids = tokenizer(
            context, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS
        ).input_ids
        # Sample one token
        with torch.no_grad():
            out = model.generate(
                input_ids,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMPERATURE,
                top_p=TOP_P,
            )
        next_token = tokenizer.decode(
            out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        with lock:
            state["generated"] += next_token
            state["tokens_done"] += 1
            if state["tokens_done"] >= TOKENS_PER_PROMPT:
                state["finished"] = True
                log.info("Prompt completed – Oracle will select a new one shortly.")
            _atomic_write(STATE_PATH, state)
        time.sleep(SECS_BETWEEN_TOKENS)
threading.Thread(target=oracle_loop, daemon=True).start()
###############################################################################
# Gradio interface #
###############################################################################
def fetch_state() -> tuple[str, str, str]:
    state = _read_json(STATE_PATH, {})
    if not state:
        return "Loading…", "", "0h 0m 0s"
    return state["prompt"], state["generated"], _elapsed_str(state["start_time"])
def submit_guess(full: str, idea: str):
    full = full.strip()
    idea = idea.strip()
    if not full and not idea:
        return gr.update(value="⚠️ Enter a guess in one of the fields."), gr.update(), gr.update()
    prompt, generated, elapsed = fetch_state()
    guess_text = full or idea
    guess_type = "full" if full else "idea"
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "prompt": prompt,
        "point-in-time": elapsed,
        "response-point": generated,
        "user-guess": guess_text,
        "guess-type": guess_type,
    }
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    log.info(f"Logged {guess_type} guess ({len(guess_text)} chars).")
    return gr.update(value="✅ Guess recorded – thanks!"), gr.update(value=""), gr.update(value="")
with gr.Blocks(title="What Comes Next", theme="gradio/soft") as demo:
    gr.Markdown("""# 🌌 What Comes Next
Watch the Oracle craft an extended response – **one token at a time**. Predict its
next words or general direction and see how close you were when the tale concludes.
(All inputs are stored in `data.json` for research.)""")
    prompt_md = gr.Markdown()
    oracle_box = gr.Textbox(lines=10, interactive=False, label="📜 Oracle text so far")
    elapsed_tb = gr.Textbox(interactive=False, label="⏱ Elapsed time")
    refresh_btn = gr.Button("🔄 Refresh")
    with gr.Row():
        exact_tb = gr.Textbox(label="🧠 Exact continuation (full)")
        idea_tb = gr.Textbox(label="💡 General idea")
    submit_btn = gr.Button("Submit Guess")
    status_tb = gr.Textbox(interactive=False, label="Status")
    # Actions
    refresh_btn.click(fetch_state, outputs=[prompt_md, oracle_box, elapsed_tb])
    demo.load(fetch_state, outputs=[prompt_md, oracle_box, elapsed_tb])
    submit_btn.click(submit_guess,
                     inputs=[exact_tb, idea_tb],
                     outputs=[status_tb, exact_tb, idea_tb])
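###############################################################################
# Offline scoring sketch (illustrative only)                                  #
###############################################################################
# The module docstring notes that guesses logged to `data.json` can later be
# scored against the Oracle's final text. The helper below is a minimal sketch
# of such a scorer, assuming a character-level difflib ratio is an acceptable
# similarity measure. It is not wired into the Space, and the function name and
# scoring choice are illustrative assumptions rather than part of the original
# design.
def score_guesses_offline(final_text: str, data_path: str = DATA_PATH) -> list[dict]:
    """Attach a naive similarity score to every guess logged in `data_path`."""
    from difflib import SequenceMatcher
    scored = []
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            seen = rec.get("response-point", "")
            # Compare the guess against what the Oracle produced *after* the
            # point the player had seen, when that prefix can be matched.
            target = final_text[len(seen):] if final_text.startswith(seen) else final_text
            ratio = SequenceMatcher(None, rec.get("user-guess", ""), target).ratio()
            scored.append({**rec, "similarity": round(ratio, 4)})
    return scored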
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)