Update app.py
app.py CHANGED
@@ -2,102 +2,93 @@
+
import os, json, time, random, threading, logging
from datetime import datetime, timezone
import torch; torch.set_num_threads(os.cpu_count()); torch.set_num_interop_threads(os.cpu_count())
-import gradio as gr

+import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
-from gradio.themes import Dark

MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
PROMPTS_PATH = "full_prompts.json"
STATE_PATH = "current_state.json"
DATA_PATH = "data.json"
-
TOKENS_PER_PROMPT = 2048
SECS_PER_TOKEN = 15
TEMP = 0.9; TOP_P = 0.95; MAX_CTX = 8192

+
logging.basicConfig(level=logging.INFO)
log = logging.getLogger()

-
-
-
-    try: return json.load(open(p, encoding="utf-8"))
+def _rj(p,d):
+
+    try: return json.load(open(p,encoding="utf-8"))
    except: return d
+

-def _aw(p,
-    t
+def _aw(p,o):
+    t=p+".tmp"; open(t,"w",encoding="utf-8").write(json.dumps(o,ensure_ascii=False,indent=2)); os.replace(t,p)

-
-
-if not tmp: raise Exception("no prompts")
-prompts = tmp
+prompts=_rj(PROMPTS_PATH,[])
+if not prompts: raise Exception("no prompts")

-
-tok = os.environ.get("HF_READ_TOKEN")
+tok=os.environ.get("HF_READ_TOKEN")
log.info("loading model...")
-tokenizer
-model
-model.to("cpu");
+tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME,token=tok)
+model=AutoModelForCausalLM.from_pretrained(MODEL_NAME,torch_dtype=torch.float32,low_cpu_mem_usage=False,token=tok)
+model.to("cpu");model.eval()
log.info("model up")

-lock
+lock=threading.Lock()

def _init():
-    s
+    s=_rj(STATE_PATH,{})
    if not s or s.get("finished"):
-        i
-        s
-        _aw(STATE_PATH,
+        i=random.randrange(len(prompts))
+        s={"i":i,"p":prompts[i],"g":"","c":0,"t":time.time(),"finished":False}
+        _aw(STATE_PATH,s)
    return s

-# elapsed time
-
def _es(st):
-    d
+    d=int(time.time()-st);h,r=divmod(d,3600);m,s=divmod(r,60)
    return f"{h}h {m}m {s}s"

-# oracle loop
-
def _loop():
    while True:
-        with lock: s
+        with lock: s=_init()
        if s["finished"]: time.sleep(SECS_PER_TOKEN); continue
-        c
-        ids
-        with torch.no_grad(): out
-        nt
+        c=s["p"]+s["g"]
+        ids=tokenizer(c,return_tensors="pt",truncation=True,max_length=MAX_CTX).input_ids
+        with torch.no_grad(): out=model.generate(ids,max_new_tokens=1,do_sample=True,temperature=TEMP,top_p=TOP_P)
+        nt=tokenizer.decode(out[0,-1],skip_special_tokens=True,clean_up_tokenization_spaces=False)
        with lock:
-            s["g"]
-            if s["c"]
-            _aw(STATE_PATH,
+            s["g"]+=nt; s["c"]+=1
+            if s["c"]>=TOKENS_PER_PROMPT: s["finished"]=True
+            _aw(STATE_PATH,s)
        time.sleep(SECS_PER_TOKEN)
-
-threading.Thread(target=_loop, daemon=True).start()
-
-# ui
+threading.Thread(target=_loop,daemon=True).start()

def _fetch():
-    s
-    if not s: return "...",
-    return s["p"],
-
-def _sg(f,
-    f1,
-    if not
-    p,
-    r
-    with lock: open(DATA_PATH,
-    return gr.update(value="ok logged"),
-
-with gr.Blocks(theme=
+    s=_rj(STATE_PATH,{})
+    if not s: return "...","","0h 0m 0s"
+    return s["p"],s["g"],_es(s["t"])
+
+def _sg(f,i):
+    f1,i1=f.strip(),i.strip()
+    if not(f1 or i1): return gr.update(value="eh?"),gr.update(),gr.update()
+    p,g,e=_fetch();guess=f1 or i1;gt="full" if f1 else "idea"
+    r={"ts":datetime.now(timezone.utc).isoformat(),"prompt":p,"time":e,"resp":g,"guess":guess,"type":gt}
+    with lock: open(DATA_PATH,"a",encoding="utf-8").write(json.dumps(r,ensure_ascii=False)+"\n")
+    return gr.update(value="ok logged"),gr.update(value=""),gr.update(value="")
+
+with gr.Blocks(theme="darkdefault") as demo:
    gr.Markdown("# What Comes Next")
-    prm
-    rbtn
-    demo.load(_fetch,
+    prm=gr.Markdown();txt=gr.Textbox(lines=10,interactive=False,label="oracle");tme=gr.Textbox(interactive=False,label="time")
+    rbtn=gr.Button("refresh");full=gr.Textbox(label="full");idea=gr.Textbox(label="idea");send=gr.Button("send");st=gr.Textbox(interactive=False,label="status")
+    demo.load(_fetch,outputs=[prm,txt,tme])
+    rbtn.click(_fetch,outputs=[prm,txt,tme])
+    send.click(_sg,inputs=[full,idea],outputs=[st,full,idea])

-if __name__
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+if __name__=="__main__": demo.launch(server_name="0.0.0.0",server_port=7860)
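
The new _aw helper persists current_state.json by writing the whole JSON document to a sibling .tmp file and then swapping it into place with os.replace, so _rj (and anything else polling the file) never sees a half-written state. A minimal standalone sketch of that write-then-rename pattern; the names atomic_write_json, read_json and demo_state.json are illustrative, not part of the Space:

import json, os

def atomic_write_json(path, obj):
    # Write the full document to a temp file next to the target first...
    tmp = path + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    # ...then atomically replace the target; readers see old or new, never partial.
    os.replace(tmp, path)

def read_json(path, default):
    # Mirror of _rj: fall back to a default if the file is missing or invalid.
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return default

if __name__ == "__main__":
    atomic_write_json("demo_state.json", {"g": "", "c": 0, "finished": False})
    print(read_json("demo_state.json", {}))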
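_loop asks the model for exactly one new token per pass (max_new_tokens=1), appends it to the shared state, and sleeps SECS_PER_TOKEN seconds, so one prompt runs for roughly 2048 x 15 s, about 8.5 hours, before finished is set. A rough sketch of that one-token-per-tick pacing, assuming the small public gpt2 checkpoint as a stand-in for the gated Llama model and much smaller constants for the demo:

import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL = "gpt2"           # stand-in; the Space uses meta-llama/Llama-3.1-8B-Instruct
TOKENS_PER_PROMPT = 5    # shortened from 2048 for the demo
SECS_PER_TOKEN = 1       # shortened from 15

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL)
model.eval()

text = "Once upon a time"
for _ in range(TOKENS_PER_PROMPT):
    ids = tokenizer(text, return_tensors="pt").input_ids
    with torch.no_grad():
        out = model.generate(ids, max_new_tokens=1, do_sample=True,
                             temperature=0.9, top_p=0.95)
    # Decode only the newly sampled token and append it to the running text.
    text += tokenizer.decode(out[0, -1], skip_special_tokens=True)
    time.sleep(SECS_PER_TOKEN)  # one token per tick, as in the Space's loop

print(text)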
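_sg appends each submitted guess to data.json as one JSON object per line (JSON Lines) while holding the lock. A short sketch of reading those records back; the field names match what _sg writes, but the reader itself is only an assumption about how the log might be inspected:

import json

records = []
with open("data.json", encoding="utf-8") as f:  # assumes data.json exists
    for line in f:
        line = line.strip()
        if line:                       # skip any blank lines
            records.append(json.loads(line))

for r in records:
    # Fields written by _sg: ts, prompt, time, resp, guess, type ("full" or "idea")
    print(r["ts"], r["type"], r["guess"][:40])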