#!/usr/bin/env python3
"""
what_comes_next.py – Hugging Face Space implementation of **What Comes Next**
A global, slow-burn guessing game powered by Llama-3.1-8B-Instruct (FP32, CPU-only).

HOW IT WORKS
============
• One shared model generates a single, very long completion (≈2k tokens) for a chosen
  prompt in *full precision* on CPU.  One token is sampled every ~15 s, so a prompt
  unfolds over roughly 10 hours.  All visitors see the same progress in real time.
• Players read the partial output and may submit **either**
    🧠 Exact continuation (full guess) **or** 💡 General idea (summary guess).
• Each guess is appended to `data.json` with prompt, Oracle progress, timestamp & type.
• Offline scoring (not included here) can later measure similarity vs the final text.
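• A logged record (one JSON object per line, written by `submit_guess` below) looks like:
    {"timestamp": "<UTC ISO-8601>", "prompt": "…", "point-in-time": "1h 2m 3s",
     "response-point": "…", "user-guess": "…", "guess-type": "full" | "idea"}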
"""

from __future__ import annotations

import os
import json
import time
import random
import threading
import logging
from datetime import datetime, timezone
from typing import Dict, Any

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

###############################################################################
# Configuration                                                                #
###############################################################################
MODEL_NAME           = "meta-llama/Llama-3.1-8B-Instruct"  # full-precision model
PROMPTS_PATH         = "full_prompts.json"                 # 100 full prompts
STATE_PATH           = "current_state.json"                # persistent Oracle state
DATA_PATH            = "data.json"                         # JSONL log of guesses

TOKENS_PER_PROMPT    = 2048        # stop after N generated tokens
SECS_BETWEEN_TOKENS  = 15          # ~10 h per prompt
TEMPERATURE          = 0.9         # higher creativity, as requested
TOP_P                = 0.95        # nucleus sampling
MAX_CONTEXT_TOKENS   = 8192        # safety cap
###############################################################################

logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
log = logging.getLogger("what-comes-next")

###############################################################################
# Utility helpers                                                              #
###############################################################################

def _read_json(path: str, default: Any):
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return default


def _atomic_write(path: str, obj: Any):
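    """Write obj to a temp file, then os.replace() it into place so readers never see a partial write."""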
    tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp, path)


def load_prompts() -> list[str]:
    if not os.path.exists(PROMPTS_PATH):
        raise FileNotFoundError(f"Missing {PROMPTS_PATH}. Please add 100 prompts.")
    with open(PROMPTS_PATH, "r", encoding="utf-8") as f:
        prompts = json.load(f)
    if not isinstance(prompts, list) or not prompts:
        raise ValueError("full_prompts.json must be a non-empty JSON array of strings")
    return prompts

###############################################################################
# Model loading                                                                #
###############################################################################
log.info("Loading Llama-3.1-8B-Instruct (FP32 CPU-only)… this can take a while.")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    device_map={"": "cpu"},  # force CPU placement
)
model.eval()
log.info("Model ready – Oracle awakened.")

###############################################################################
# Global state                                                                 #
###############################################################################
lock = threading.Lock()          # guard state + files
prompts = load_prompts()         # list of 100 strings

###############################################################################
# Oracle generation thread                                                     #
###############################################################################

def _init_state() -> Dict[str, Any]:
    """Return existing state or create a fresh one if none/finished."""
    state = _read_json(STATE_PATH, {})
    if not state or state.get("finished"):
        prompt_idx = random.randrange(len(prompts))
        state = {
            "prompt_idx": prompt_idx,
            "prompt": prompts[prompt_idx],
            "generated": "",          # text so far
            "tokens_done": 0,
            "start_time": time.time(),
            "finished": False
        }
        _atomic_write(STATE_PATH, state)
        log.info(f"New Oracle prompt #{prompt_idx}: {state['prompt'][:80]}…")
    return state


def _elapsed_str(start: float) -> str:
    d = int(time.time() - start)
    h, r = divmod(d, 3600)
    m, s = divmod(r, 60)
    return f"{h}h {m}m {s}s"


def oracle_loop():
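    """Background worker: sample one token every ~SECS_BETWEEN_TOKENS seconds and persist progress."""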
    while True:
        with lock:
            state = _init_state()
        if state["finished"]:
            time.sleep(SECS_BETWEEN_TOKENS)
            continue

        # Build context: prompt + generated so far
        context = state["prompt"] + state["generated"]
        input_ids = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS).input_ids
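        # truncation=True drops tokens from the end of the sequence by default, so the cap only
        # guards against pathological lengths; a typical prompt plus 2048 generated tokens
        # should stay well below MAX_CONTEXT_TOKENS.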

        # Sample one token
        with torch.no_grad():
            out = model.generate(
                input_ids,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                pad_token_id=tokenizer.eos_token_id,  # avoid the missing-pad-token warning
            )
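        # out holds the full context plus the single new token; out[0, -1] is that token's id.
        # Decoding tokens one at a time can produce replacement characters when the byte-level
        # BPE splits a multi-byte character across tokens; a minor caveat for this slow stream.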
        next_token = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)

        with lock:
            state["generated"] += next_token
            state["tokens_done"] += 1
            if state["tokens_done"] >= TOKENS_PER_PROMPT:
                state["finished"] = True
                log.info("Prompt completed – Oracle will select a new one shortly.")
            _atomic_write(STATE_PATH, state)
        time.sleep(SECS_BETWEEN_TOKENS)

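# Run the Oracle in a daemon thread so generation continues while Gradio serves requests.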
threading.Thread(target=oracle_loop, daemon=True).start()

###############################################################################
# Gradio interface                                                             #
###############################################################################

def fetch_state() -> tuple[str, str, str]:
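    """Return (prompt, generated text so far, elapsed-time string) for display in the UI."""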
    state = _read_json(STATE_PATH, {})
    if not state:
        return "Loading…", "", "0h 0m 0s"
    return state["prompt"], state["generated"], _elapsed_str(state["start_time"])


def submit_guess(full: str, idea: str):
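    """Validate and log a player's guess (exact continuation or general idea) as one JSONL record."""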
    full = full.strip()
    idea = idea.strip()
    if not full and not idea:
        return gr.update(value="⚠️ Enter a guess in one of the fields."), gr.update(), gr.update()

    prompt, generated, elapsed = fetch_state()
    guess_text = full or idea
    guess_type = "full" if full else "idea"

    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "prompt": prompt,
        "point-in-time": elapsed,
        "response-point": generated,
        "user-guess": guess_text,
        "guess-type": guess_type
    }
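    # Append one JSON object per line under the lock so concurrent submissions never interleave.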
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    log.info(f"Logged {guess_type} guess ({len(guess_text)} chars).")
    return gr.update(value="βœ… Guess recorded – thanks!"), gr.update(value=""), gr.update(value="")


with gr.Blocks(title="What Comes Next", theme="gradio/soft") as demo:
    gr.Markdown("""# 🌌 What Comes Next
Watch the Oracle craft an extended response – **one token at a time**. Predict its
next words or general direction and see how close you were when the tale concludes.
(All inputs are stored in `data.json` for research.)""")

    prompt_md   = gr.Markdown()
    oracle_box  = gr.Textbox(lines=10, interactive=False, label="📜 Oracle text so far")
    elapsed_tb  = gr.Textbox(interactive=False, label="⏱ Elapsed time")

    refresh_btn = gr.Button("🔄 Refresh")

    with gr.Row():
        exact_tb = gr.Textbox(label="🧠 Exact continuation (full)")
        idea_tb  = gr.Textbox(label="💡 General idea")
    submit_btn  = gr.Button("Submit Guess")
    status_tb   = gr.Textbox(interactive=False, label="Status")

    # Actions
    refresh_btn.click(fetch_state, outputs=[prompt_md, oracle_box, elapsed_tb])
    demo.load(fetch_state, outputs=[prompt_md, oracle_box, elapsed_tb])

    submit_btn.click(submit_guess,
                     inputs=[exact_tb, idea_tb],
                     outputs=[status_tb, exact_tb, idea_tb])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)