ProCreations committed on
Commit 23814a9 · verified · 1 Parent(s): c08680c

Update app.py

Files changed (1)
  1. app.py +49 -58
app.py CHANGED
@@ -2,102 +2,93 @@
 
 
 
+
 import os, json, time, random, threading, logging
 from datetime import datetime, timezone
 import torch; torch.set_num_threads(os.cpu_count()); torch.set_num_interop_threads(os.cpu_count())
-import gradio as gr
 
+import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from gradio.themes import Dark
 
 MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
 PROMPTS_PATH = "full_prompts.json"
 STATE_PATH = "current_state.json"
 DATA_PATH = "data.json"
-
 TOKENS_PER_PROMPT = 2048
 SECS_PER_TOKEN = 15
 TEMP = 0.9; TOP_P = 0.95; MAX_CTX = 8192
 
+
 logging.basicConfig(level=logging.INFO)
 log = logging.getLogger()
 
-# read or write json
-
-def _rj(p, d):
-    try: return json.load(open(p, encoding="utf-8"))
+def _rj(p,d):
+
+    try: return json.load(open(p,encoding="utf-8"))
     except: return d
+
 
-def _aw(p, o):
-    t = p + ".tmp"; open(t, "w", encoding="utf-8").write(json.dumps(o, ensure_ascii=False, indent=2)); os.replace(t, p)
+def _aw(p,o):
+    t=p+".tmp"; open(t,"w",encoding="utf-8").write(json.dumps(o,ensure_ascii=False,indent=2)); os.replace(t,p)
 
-# load prompts
-tmp = _rj(PROMPTS_PATH, [])
-if not tmp: raise Exception("no prompts")
-prompts = tmp
+prompts=_rj(PROMPTS_PATH,[])
+if not prompts: raise Exception("no prompts")
 
-# load model
-tok = os.environ.get("HF_READ_TOKEN")
+tok=os.environ.get("HF_READ_TOKEN")
 log.info("loading model...")
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32, low_cpu_mem_usage=False, token=tok)
-model.to("cpu"); model.eval()
+tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME,token=tok)
+model=AutoModelForCausalLM.from_pretrained(MODEL_NAME,torch_dtype=torch.float32,low_cpu_mem_usage=False,token=tok)
+model.to("cpu");model.eval()
 log.info("model up")
 
-lock = threading.Lock()
+lock=threading.Lock()
 
 def _init():
-    s = _rj(STATE_PATH, {})
+    s=_rj(STATE_PATH,{})
     if not s or s.get("finished"):
-        i = random.randrange(len(prompts))
-        s = {"i": i, "p": prompts[i], "g": "", "c": 0, "t": time.time(), "finished": False}
-        _aw(STATE_PATH, s)
+        i=random.randrange(len(prompts))
+        s={"i":i,"p":prompts[i],"g":"","c":0,"t":time.time(),"finished":False}
+        _aw(STATE_PATH,s)
     return s
 
-# elapsed time
-
 def _es(st):
-    d = int(time.time() - st); h, r = divmod(d, 3600); m, s = divmod(r, 60)
+    d=int(time.time()-st);h,r=divmod(d,3600);m,s=divmod(r,60)
     return f"{h}h {m}m {s}s"
 
-# oracle loop
-
 def _loop():
     while True:
-        with lock: s = _init()
+        with lock: s=_init()
         if s["finished"]: time.sleep(SECS_PER_TOKEN); continue
-        c = s["p"] + s["g"]
-        ids = tokenizer(c, return_tensors="pt", truncation=True, max_length=MAX_CTX).input_ids
-        with torch.no_grad(): out = model.generate(ids, max_new_tokens=1, do_sample=True, temperature=TEMP, top_p=TOP_P)
-        nt = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        c=s["p"]+s["g"]
+        ids=tokenizer(c,return_tensors="pt",truncation=True,max_length=MAX_CTX).input_ids
+        with torch.no_grad(): out=model.generate(ids,max_new_tokens=1,do_sample=True,temperature=TEMP,top_p=TOP_P)
+        nt=tokenizer.decode(out[0,-1],skip_special_tokens=True,clean_up_tokenization_spaces=False)
         with lock:
-            s["g"] += nt; s["c"] += 1
-            if s["c"] >= TOKENS_PER_PROMPT: s["finished"] = True
-            _aw(STATE_PATH, s)
+            s["g"]+=nt; s["c"]+=1
+            if s["c"]>=TOKENS_PER_PROMPT: s["finished"]=True
+            _aw(STATE_PATH,s)
         time.sleep(SECS_PER_TOKEN)
-
-threading.Thread(target=_loop, daemon=True).start()
-
-# ui
+threading.Thread(target=_loop,daemon=True).start()
 
 def _fetch():
-    s = _rj(STATE_PATH, {})
-    if not s: return "...", "", "0h 0m 0s"
-    return s["p"], s["g"], _es(s["t"])
-
-def _sg(f, i):
-    f1, f2 = f.strip(), i.strip()
-    if not (f1 or f2): return gr.update(value="eh?"), gr.update(), gr.update()
-    p, g, e = _fetch(); guess = f1 or f2; gt = "full" if f1 else "idea"
-    r = {"ts": datetime.now(timezone.utc).isoformat(), "p": p, "time": e, "resp": g, "guess": guess, "type": gt}
-    with lock: open(DATA_PATH, "a", encoding="utf-8").write(json.dumps(r, ensure_ascii=False) + "\n")
-    return gr.update(value="ok logged"), gr.update(value=""), gr.update(value="")
-
-with gr.Blocks(theme=Dark()) as demo:
+    s=_rj(STATE_PATH,{})
+    if not s: return "...","","0h 0m 0s"
+    return s["p"],s["g"],_es(s["t"])
+
+def _sg(f,i):
+    f1,i1=f.strip(),i.strip()
+    if not(f1 or i1): return gr.update(value="eh?"),gr.update(),gr.update()
+    p,g,e=_fetch();guess=f1 or i1;gt="full" if f1 else "idea"
+    r={"ts":datetime.now(timezone.utc).isoformat(),"prompt":p,"time":e,"resp":g,"guess":guess,"type":gt}
+    with lock: open(DATA_PATH,"a",encoding="utf-8").write(json.dumps(r,ensure_ascii=False)+"\n")
+    return gr.update(value="ok logged"),gr.update(value=""),gr.update(value="")
+
+with gr.Blocks(theme="darkdefault") as demo:
     gr.Markdown("# What Comes Next")
-    prm = gr.Markdown(); txt = gr.Textbox(lines=10, interactive=False, label="oracle"); tme = gr.Textbox(interactive=False, label="time")
-    rbtn = gr.Button("refresh"); full = gr.Textbox(label="full"); idea = gr.Textbox(label="idea"); send = gr.Button("send"); st = gr.Textbox(interactive=False, label="st")
-    demo.load(_fetch, outputs=[prm, txt, tme]); rbtn.click(_fetch, outputs=[prm, txt, tme]); send.click(_sg, inputs=[full, idea], outputs=[st, full, idea])
+    prm=gr.Markdown();txt=gr.Textbox(lines=10,interactive=False,label="oracle");tme=gr.Textbox(interactive=False,label="time")
+    rbtn=gr.Button("refresh");full=gr.Textbox(label="full");idea=gr.Textbox(label="idea");send=gr.Button("send");st=gr.Textbox(interactive=False,label="status")
+    demo.load(_fetch,outputs=[prm,txt,tme])
+    rbtn.click(_fetch,outputs=[prm,txt,tme])
+    send.click(_sg,inputs=[full,idea],outputs=[st,full,idea])
 
-if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+if __name__=="__main__": demo.launch(server_name="0.0.0.0",server_port=7860)
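
For context on the state handling kept by this commit: _aw writes the whole state dict to a ".tmp" file and swaps it into place with os.replace, so the background generation thread and the Gradio callbacks reading current_state.json via _rj never observe a half-written file. A minimal standalone sketch of that write-temp-then-rename pattern, with error handling made explicit (the helper names read_json/write_json_atomic and the demo_state dict are illustrative, not part of the Space's code):

import json, os, time

STATE_PATH = "current_state.json"  # same file name the app uses

def read_json(path, default):
    # Return parsed JSON, or the default if the file is missing or corrupt.
    try:
        with open(path, encoding="utf-8") as fh:
            return json.load(fh)
    except (FileNotFoundError, json.JSONDecodeError):
        return default

def write_json_atomic(path, obj):
    # Serialize to a temp file, then atomically replace the target so a
    # concurrent reader sees either the old or the new state, never a partial write.
    tmp = path + ".tmp"
    with open(tmp, "w", encoding="utf-8") as fh:
        json.dump(obj, fh, ensure_ascii=False, indent=2)
    os.replace(tmp, path)

if __name__ == "__main__":
    demo_state = {"p": "Once upon a time", "g": "", "c": 0, "t": time.time(), "finished": False}
    write_json_atomic(STATE_PATH, demo_state)
    print(read_json(STATE_PATH, {}))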