#!/usr/bin/env python3
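"""What Comes Next: a slow-oracle guessing game.

A background thread has a Llama model generate one token every
SECS_PER_TOKEN seconds, continuing a randomly chosen hidden prompt.
Visitors watch the text grow in a Gradio UI and submit guesses (either
a verbatim continuation or a general idea of where the text is going);
guesses are appended to DATA_PATH as JSON lines.
"""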
import json
import logging
import os
import random
import threading
import time
from datetime import datetime, timezone

import torch

# CPU-only deployment: use every available core for inference.
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(os.cpu_count())

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
PROMPTS_PATH = "full_prompts.json"  # pool of starting prompts
STATE_PATH = "current_state.json"   # persisted generation state
DATA_PATH = "data.json"             # JSONL log of player guesses
TOKENS_PER_PROMPT = 2048            # tokens to generate before starting a new prompt
SECS_PER_TOKEN = 15                 # pause between generated tokens
TEMP = 0.9                          # sampling temperature
TOP_P = 0.95                        # nucleus-sampling cutoff
MAX_CTX = 8192                      # maximum context length fed to the model


logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

def _rj(path, default):
    """Read JSON from *path*, returning *default* if missing or malformed."""
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        return default

def _aw(path, obj):
    """Atomically write *obj* as JSON: write a temp file, then rename over the target."""
    tmp = path + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp, path)  # atomic on POSIX

prompts = _rj(PROMPTS_PATH, [])
if not prompts:
    raise RuntimeError(f"No prompts found in {PROMPTS_PATH}")

hf_token = os.environ.get("HF_READ_TOKEN")
log.info("Loading model %s ...", MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # full precision for CPU inference
    low_cpu_mem_usage=False,
    token=hf_token,
)
model.to("cpu")
model.eval()
log.info("Model ready.")

# Guards STATE_PATH and DATA_PATH access shared between the UI and the worker thread.
lock = threading.Lock()

def _init():
    """Load the current state, starting a fresh prompt if none is active.

    State keys: i (prompt index), p (prompt text), g (generated text),
    c (token count), t (start timestamp), finished (bool).
    """
    s = _rj(STATE_PATH, {})
    if not s or s.get("finished"):
        i = random.randrange(len(prompts))
        s = {"i": i, "p": prompts[i], "g": "", "c": 0, "t": time.time(), "finished": False}
        _aw(STATE_PATH, s)
    return s

def _es(start):
    """Format the seconds elapsed since *start* as 'Hh Mm Ss'."""
    d = int(time.time() - start)
    h, r = divmod(d, 3600)
    m, s = divmod(r, 60)
    return f"{h}h {m}m {s}s"

def _loop():
    """Background worker: append one sampled token to the state every tick."""
    while True:
        with lock:
            s = _init()
        if s["finished"]:
            time.sleep(SECS_PER_TOKEN)
            continue
        context = s["p"] + s["g"]
        ids = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CTX).input_ids
        with torch.no_grad():
            out = model.generate(
                ids, max_new_tokens=1, do_sample=True, temperature=TEMP,
                top_p=TOP_P, pad_token_id=tokenizer.eos_token_id,  # pad id silences a generate() warning
            )
        nt = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        with lock:
            s["g"] += nt
            s["c"] += 1
            if s["c"] >= TOKENS_PER_PROMPT:
                s["finished"] = True
            _aw(STATE_PATH, s)
        time.sleep(SECS_PER_TOKEN)

threading.Thread(target=_loop, daemon=True).start()

def _fetch():
    """Return (prompt, generated text, elapsed time) for the UI."""
    s = _rj(STATE_PATH, {})
    if not s:
        return "...", "", "0h 0m 0s"
    return s["p"], s["g"], _es(s["t"])

def _sg(full_guess, idea_guess):
    """Log a player's guess: a verbatim continuation ('full') or a general idea."""
    f1, i1 = full_guess.strip(), idea_guess.strip()
    if not (f1 or i1):
        return gr.update(value="Enter a guess first."), gr.update(), gr.update()
    p, g, e = _fetch()
    guess = f1 or i1
    gt = "full" if f1 else "idea"
    r = {"ts": datetime.now(timezone.utc).isoformat(), "prompt": p, "time": e,
         "resp": g, "guess": guess, "type": gt}
    with lock, open(DATA_PATH, "a", encoding="utf-8") as f:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")
    return gr.update(value="Guess logged."), gr.update(value=""), gr.update(value="")
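# Example of a record appended to data.json, one JSON object per line
# (values illustrative; keys match the dict built in _sg above):
# {"ts": "2025-01-01T00:00:00+00:00", "prompt": "...", "time": "1h 2m 3s",
#  "resp": "...", "guess": "...", "type": "full"}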

with gr.Blocks(theme="darkdefault") as demo:
    gr.Markdown("# What Comes Next")
    prm = gr.Markdown()
    txt = gr.Textbox(lines=10, interactive=False, label="oracle")
    tme = gr.Textbox(interactive=False, label="time")
    rbtn = gr.Button("refresh")
    full = gr.Textbox(label="full")
    idea = gr.Textbox(label="idea")
    send = gr.Button("send")
    st = gr.Textbox(interactive=False, label="status")

    demo.load(_fetch, outputs=[prm, txt, tme])
    rbtn.click(_fetch, outputs=[prm, txt, tme])
    send.click(_sg, inputs=[full, idea], outputs=[st, full, idea])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)