#!/usr/bin/env python3
"""
what_comes_next.py – Hugging Face Space implementation of **What Comes Next**
A slow, contemplative global guessing game.

🔮  HOW IT WORKS  🔮
• A single Llama-3.1-8B-Instruct model (FP32 on CPU) is generating one very long completion
  for a chosen mystical prompt. It runs continuously in the background for everyone.
• Any visitor sees the same prompt and the Oracle’s current partial response.
• Players may submit *one* of two kinds of guesses:
    1. 🧠  **Exact Completion** – the full sentence/paragraph they think the Oracle will
       eventually write.
    2. 💡  **General Idea**   – a short summary of the direction or theme they expect.
• Each guess is recorded immediately (with timestamp, Oracle progress, etc.) to
  `data.json` (JSON Lines). When the Oracle finally finishes, offline evaluation can
  score the guesses against the final text.

The game then moves on to the next prompt and the cycle repeats.
"""

import os
import json
import time
import random
import threading
import logging
from datetime import datetime, timezone
from typing import Dict, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

###############################################################################
# Settings                                                                     #
###############################################################################
MODEL_NAME           = "meta-llama/Llama-3.1-8B-Instruct"       # FP32, CPU-only
PROMPTS_PATH         = "oracle_prompts.json"                    # 100 unfinished lines
STATE_PATH           = "current_state.json"                     # persistent Oracle state
DATA_PATH            = "data.json"                              # JSONL of user guesses
TOKENS_PER_PROMPT    = 2048                                      # stop after N tokens
SECS_BETWEEN_TOKENS  = 15                                        # pacing (≈8.5 h / prompt)
TEMPERATURE          = 0.8
TOP_P                = 0.95
MAX_CONTEXT_TOKENS   = 8192
###############################################################################
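# Pacing arithmetic (assuming uninterrupted generation):
#   TOKENS_PER_PROMPT * SECS_BETWEEN_TOKENS = 2048 * 15 s = 30,720 s ≈ 8.5 h per prompt.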

logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
log = logging.getLogger("what-comes-next")

lock = threading.Lock()  # global file/variable lock

# --------------------------------------------------------------------------- #
# Helper functions                                                             #
# --------------------------------------------------------------------------- #

def _read_json(path: str, default: Any):
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        # Treat a missing or corrupt file the same way: fall back to the default.
        return default


def _write_json(path: str, obj: Any):
    """Write JSON atomically: dump to a temp file, then os.replace() it into
    place so concurrent readers never observe a half-written file."""
    tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp, path)


def load_prompts() -> list[str]:
    if not os.path.exists(PROMPTS_PATH):
        raise FileNotFoundError(f"Missing {PROMPTS_PATH}. Please add 100 prompts.")
    with open(PROMPTS_PATH, "r", encoding="utf-8") as f:
        return json.load(f)


prompts = load_prompts()
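# oracle_prompts.json is assumed to hold a flat JSON array of prompt strings,
# e.g. ["The last gate of the city opened onto", "Beneath the seventh moon,"].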

# --------------------------------------------------------------------------- #
# Model loading (FP32, CPU)                                                    #
# --------------------------------------------------------------------------- #
log.info("Loading Llama-3.1-8B-Instruct in FP32 on CPU (this is *slow*) …")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.truncation_side = "left"  # if the context overflows, keep the newest text
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    device_map={"": "cpu"},  # force CPU
)
model.eval()
log.info("Model loaded.")

# --------------------------------------------------------------------------- #
# Oracle generation thread                                                     #
# --------------------------------------------------------------------------- #

def init_state() -> Dict[str, Any]:
    """Return existing state or create a new one."""
    state = _read_json(STATE_PATH, {})
    if state.get("finished", False):
        state = {}  # finished, start new prompt
    if not state:
        prompt_idx = random.randrange(len(prompts))
        prompt = prompts[prompt_idx]
        state = {
            "prompt_idx": prompt_idx,
            "prompt": prompt,
            "generated": "",         # Oracle’s text so far (string)
            "start_time": time.time(),
            "finished": False,
            "tokens_done": 0
        }
        _write_json(STATE_PATH, state)
        log.info(f"Starting new Oracle prompt #{prompt_idx}: {prompt[:60]}…")
    return state


def oracle_loop():
    """Continuously extend the Oracle’s text by one token every SECS_BETWEEN_TOKENS."""
    while True:
        with lock:
            state = init_state()
            if state["finished"]:
                # Should not happen, but guard anyway
                time.sleep(SECS_BETWEEN_TOKENS)
                continue
            prompt_text = state["prompt"]
            generated_text = state["generated"]
            tokens_done = state["tokens_done"]

        # Build input_ids (prompt + generated so far)
        full_input = prompt_text + generated_text
        input_ids = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS).input_ids

        # Generate ONE token
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMPERATURE,
                top_p=TOP_P,
            )
            next_token_id = outputs[0, -1].unsqueeze(0)
            next_token_text = tokenizer.decode(next_token_id, skip_special_tokens=True, clean_up_tokenization_spaces=False)

        with lock:
            # Update state
            state["generated"] += next_token_text
            state["tokens_done"] += 1
            if state["tokens_done"] >= TOKENS_PER_PROMPT:
                state["finished"] = True
                log.info("Prompt complete. Oracle will pick a new one next cycle.")
            _write_json(STATE_PATH, state)
        time.sleep(SECS_BETWEEN_TOKENS)  # pacing

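# Note: each step above re-encodes prompt + all generated text from scratch; no
# KV cache is kept between iterations. That costs O(n) work per emitted token,
# which is acceptable here only because the loop runs once per SECS_BETWEEN_TOKENS.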

threading.Thread(target=oracle_loop, daemon=True).start()

# --------------------------------------------------------------------------- #
# Gradio Interface                                                             #
# --------------------------------------------------------------------------- #

def human_readable_elapsed(start: float) -> str:
    delta = int(time.time() - start)
    h, rem = divmod(delta, 3600)
    m, s = divmod(rem, 60)
    return f"{h}h {m}m {s}s"


def get_current_state() -> Dict[str, Any]:
    with lock:
        state = _read_json(STATE_PATH, {})
    if not state:
        return {"prompt": "…loading…", "generated": "", "elapsed": "0h 0m 0s"}
    return {
        "prompt": state["prompt"],
        "generated": state["generated"],
        "elapsed": human_readable_elapsed(state["start_time"])
    }


def record_guess(full_guess: str, idea_guess: str):
    state = get_current_state()
    guess_text = full_guess.strip() or idea_guess.strip()
    if not guess_text:
        return gr.update(value="⚠️ Please enter a guess in one of the boxes …"), gr.update()
    guess_type = "full" if full_guess.strip() else "idea"
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "prompt": state["prompt"],
        "point‑in‑time": state["elapsed"],
        "response‑point": state["generated"],
        "user‑guess": guess_text,
        "guess‑type": guess_type
    }
    # Append to JSONL (data.json)
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    log.info(f"Recorded {guess_type} guess ({len(guess_text)} chars).")
    return gr.update(value="✅ Guess recorded – check back when the Oracle finishes!"), gr.update(value="")

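# The docstring above promises offline scoring once the Oracle finishes a prompt.
# A minimal sketch of such a scorer, as an assumption about how evaluation might
# work (difflib similarity ratio; not wired into the UI):
import difflib

def score_guess(final_text: str, guess: str) -> float:
    """Hypothetical offline helper: similarity in [0, 1] between the Oracle's
    final text and a recorded guess."""
    return difflib.SequenceMatcher(None, final_text, guess).ratio()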

with gr.Blocks(title="What Comes Next", theme="gradio/soft") as demo:
    gr.Markdown("""# ✨ What Comes Next
A global, slow-burn guessing game. The Oracle is continuously writing its story.
Read the prompt, see the Oracle’s progress, and predict **what comes next**!  
*(FP32 CPU inference – deliberately unhurried.)*""")

    ### Live Oracle view
    prompt_box     = gr.Markdown(label="🔮 Current Oracle Prompt")
    oracle_box     = gr.Textbox(label="📜 Oracle’s current text", lines=10, interactive=False)
    elapsed_box    = gr.Textbox(label="⏱️ Elapsed", interactive=False)

    ### Guess inputs
    gr.Markdown("**Make your prediction:** Fill **either** the exact continuation *or* a general idea.")
    with gr.Row():
        full_guess  = gr.Textbox(label="🧠 Exact continuation (full)")
        idea_guess  = gr.Textbox(label="💡 General idea")
    submit_btn     = gr.Button("Submit Guess")
    status_msg     = gr.Textbox(label="Status", interactive=False)

    ### Refresh button
    refresh_btn    = gr.Button("🔄 Refresh Oracle progress")

    def refresh():
        st = get_current_state()
        return st["prompt"], st["generated"], st["elapsed"]

    refresh_btn.click(refresh, outputs=[prompt_box, oracle_box, elapsed_box])
    demo.load(refresh, outputs=[prompt_box, oracle_box, elapsed_box])  # auto-load on launch

    submit_btn.click(record_guess,
                     inputs=[full_guess, idea_guess],
                     outputs=[status_msg, full_guess])  # clear full_guess box on success

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
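# Local usage: `python what_comes_next.py`, then open http://localhost:7860.
# On a Hugging Face Space, point the Space's app_file at this script (an
# assumption about the Space configuration; Spaces default to app.py).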