#!/usr/bin/env python3
# What Comes Next (sloppy version)
"""Gradio app where users guess how an LLM will continue a prompt.

A background "oracle" thread samples one token every ``SECS_BETWEEN_TOKENS``
seconds from ``MODEL_NAME`` and persists its progress to ``STATE_PATH``.
The UI shows the prompt and the text generated so far; user guesses are
appended as JSON lines to ``DATA_PATH``.
"""
import json
import logging
import os
import random
import threading
import time
from datetime import datetime, timezone

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
PROMPTS_PATH = "full_prompts.json"
STATE_PATH = "current_state.json"
DATA_PATH = "data.json"
TOKENS_PER_PROMPT = 2048
SECS_BETWEEN_TOKENS = 15
TEMPERATURE = 0.9
TOP_P = 0.95
MAX_CONTEXT_TOKENS = 8192

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)


def _read_json(p, d):
    """Return the parsed JSON content of file *p*, or the default *d*.

    A missing file silently yields *d*; an unreadable or corrupt file also
    yields *d* but is logged so the problem is visible.
    """
    try:
        with open(p, encoding="utf-8") as fh:
            return json.load(fh)
    except FileNotFoundError:
        return d
    except (OSError, ValueError) as err:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        log.warning("could not read %s (%s); using default", p, err)
        return d


def _atomic_write(p, o):
    """Write *o* as pretty-printed JSON to *p* atomically.

    The data goes to ``p + ".tmp"`` first and is then renamed over *p*, so
    readers never observe a half-written file.
    """
    t = p + ".tmp"
    with open(t, "w", encoding="utf-8") as fh:
        json.dump(o, fh, ensure_ascii=False, indent=2)
        fh.flush()
        os.fsync(fh.fileno())
    # The temp file is closed before the rename; renaming an open file can
    # fail on Windows and would not guarantee the bytes are on disk.
    os.replace(t, p)


def load_prompts():
    """Load the list of prompts from PROMPTS_PATH.

    Raises:
        FileNotFoundError: if the file is missing, unreadable, or empty.
    """
    prompt_list = _read_json(PROMPTS_PATH, [])
    if not prompt_list:
        raise FileNotFoundError(f"no prompts found in {PROMPTS_PATH!r}")
    return prompt_list


# Load the model once at import time (uses HF_READ_TOKEN for hub access).
tok = os.environ.get("HF_READ_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float32, low_cpu_mem_usage=True, token=tok
)
model.to(torch.device("cpu"))
model.eval()


def _eos_token_ids():
    """Return the set of token ids that end a generation for this model."""
    ids = getattr(model.generation_config, "eos_token_id", None)
    if ids is None:
        ids = tokenizer.eos_token_id
    if ids is None:
        return set()
    return set(ids) if isinstance(ids, (list, tuple)) else {ids}


EOS_TOKEN_IDS = _eos_token_ids()
prompts = load_prompts()
lock = threading.Lock()  # serialises access to STATE_PATH and DATA_PATH

# Keys every persisted state must have; anything else is treated as corrupt.
_STATE_KEYS = ("prompt", "generated", "tokens_done", "start_time")


# main loop: oracle generation
def _init_state():
    """Return the persisted oracle state, starting a new random prompt if needed.

    A new round starts when there is no usable state file, when the state
    is incomplete, or when the previous round is finished. Called by the
    oracle thread with ``lock`` held.
    """
    s = _read_json(STATE_PATH, {})
    if (
        not isinstance(s, dict)
        or s.get("finished")
        or any(k not in s for k in _STATE_KEYS)
    ):
        i = random.randrange(len(prompts))
        s = {
            "prompt_idx": i,
            "prompt": prompts[i],
            "generated": "",
            "tokens_done": 0,
            "start_time": time.time(),
            "finished": False,
        }
        _atomic_write(STATE_PATH, s)
    return s


def _elapsed_str(st):
    """Format the time elapsed since epoch-seconds *st* as ``"Hh Mm Ss"``."""
    d = max(0, int(time.time() - st))  # clamp in case the clock went backwards
    h, r = divmod(d, 3600)
    m, s = divmod(r, 60)
    return f"{h}h {m}m {s}s"


def _oracle_step():
    """Sample one token for the active prompt and persist the updated state."""
    with lock:
        s = _init_state()
    if s["finished"]:
        return
    context = s["prompt"] + s["generated"]
    enc = tokenizer(context, return_tensors="pt")
    # Keep the *last* MAX_CONTEXT_TOKENS tokens. The tokenizer's own
    # truncation cuts from the right by default, which would drop the very
    # text we are trying to continue.
    input_ids = enc.input_ids[:, -MAX_CONTEXT_TOKENS:]
    attention_mask = enc.attention_mask[:, -MAX_CONTEXT_TOKENS:]
    with torch.no_grad():
        out = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=1,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,
        )
    new_id = int(out[0, -1])
    new_text = tokenizer.decode(
        [new_id], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    with lock:
        s["generated"] += new_text
        s["tokens_done"] += 1
        # An end-of-sequence token decodes to "" and would never reach the
        # context, so the model could keep sampling EOS and the visible text
        # would freeze; treat it as the end of the round instead.
        if s["tokens_done"] >= TOKENS_PER_PROMPT or new_id in EOS_TOKEN_IDS:
            s["finished"] = True
        _atomic_write(STATE_PATH, s)


def oracle_loop():
    """Run forever, generating one token every SECS_BETWEEN_TOKENS seconds."""
    while True:
        try:
            _oracle_step()
        except Exception:
            # Top-level boundary of a daemon thread: an uncaught error would
            # silently kill the oracle and freeze the UI, so log and retry.
            log.exception("oracle step failed")
        time.sleep(SECS_BETWEEN_TOKENS)


threading.Thread(target=oracle_loop, name="oracle", daemon=True).start()


# ui
def _snapshot():
    """Return the current state as shown to users, or None before the first round."""
    s = _read_json(STATE_PATH, {})
    if not isinstance(s, dict) or not s:
        return None
    return {
        "prompt": s.get("prompt", ""),
        "generated": s.get("generated", ""),
        "elapsed": _elapsed_str(s.get("start_time", time.time())),
    }


def _refresh_view():
    """UI callback: (prompt, generated, elapsed, snapshot-for-this-session)."""
    snap = _snapshot()
    if snap is None:
        return "Loading...", "", "0h 0m 0s", None
    return snap["prompt"], snap["generated"], snap["elapsed"], snap


def fetch_state():
    """Return ``(prompt, generated_text, elapsed)`` for display."""
    return _refresh_view()[:3]


def submit_guess(full, idea, shown=None):
    """Append a user's guess to DATA_PATH as one JSON line.

    Args:
        full: text of the "full guess" box; takes precedence when non-blank.
        idea: text of the "idea" box; used only when *full* is blank.
        shown: snapshot dict (``prompt``/``generated``/``elapsed``) of what
            the user was looking at when guessing. When None, the current
            server state is used instead.

    Returns:
        Gradio updates for (status box, full-guess box, idea box).
    """
    f = (full or "").strip()
    i = (idea or "").strip()
    if not (f or i):
        return gr.update(value="enter guess!"), gr.update(), gr.update()
    if shown is None:
        shown = _snapshot()
    if shown is None:
        # No round has started yet; don't record a guess against a placeholder.
        return gr.update(value="oracle not ready yet, try again soon"), gr.update(), gr.update()
    guess, gt = (f, "full") if f else (i, "idea")
    r = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "prompt": shown["prompt"],
        "point-in-time": shown["elapsed"],
        "response-point": shown["generated"],
        "user-guess": guess,
        "guess-type": gt,
    }
    line = json.dumps(r, ensure_ascii=False) + "\n"
    with lock, open(DATA_PATH, "a", encoding="utf-8") as fh:
        fh.write(line)
    return gr.update(value="logged!"), gr.update(value=""), gr.update(value="")


with gr.Blocks(title="What Comes Next") as demo:
    gr.Markdown("# What Comes Next - sloppy")
    prompt_md = gr.Markdown()
    oracle_box = gr.Textbox(lines=10, interactive=False, label="oracle")
    elapsed_box = gr.Textbox(interactive=False, label="time")
    refresh_btn = gr.Button("refresh")
    full_box = gr.Textbox(label="full guess")
    idea_box = gr.Textbox(label="idea")
    send_btn = gr.Button("send")
    status_box = gr.Textbox(interactive=False, label="st")
    # Per-session copy of what this user last saw, so a guess is logged
    # against the text the user was actually looking at.
    shown_state = gr.State(None)

    view_outputs = [prompt_md, oracle_box, elapsed_box, shown_state]
    demo.load(_refresh_view, outputs=view_outputs)
    refresh_btn.click(_refresh_view, outputs=view_outputs)
    send_btn.click(
        submit_guess,
        inputs=[full_box, idea_box, shown_state],
        outputs=[status_box, full_box, idea_box],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)