#!/usr/bin/env python3
import os, json, time, random, threading, logging
from datetime import datetime, timezone

import torch
torch.set_num_threads(os.cpu_count())          # use every CPU core for intra-op work
torch.set_num_interop_threads(os.cpu_count())  # and for inter-op parallelism

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
PROMPTS_PATH = "full_prompts.json"   # pool of prompts to continue
STATE_PATH = "current_state.json"    # persisted generation state
DATA_PATH = "data.json"              # log of player guesses
TOKENS_PER_PROMPT = 2048             # tokens to generate before a prompt counts as finished
SECS_PER_TOKEN = 15                  # pause between single-token generation steps
TEMP = 0.9                           # sampling temperature
TOP_P = 0.95                         # nucleus sampling threshold
MAX_CTX = 8192                       # max context length fed to the model
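# Inferred from how the files are used below: full_prompts.json should hold a JSON array of
# prompt strings, and current_state.json is rewritten atomically by the generation loop.
# At one token every SECS_PER_TOKEN seconds, finishing a prompt takes
# TOKENS_PER_PROMPT * SECS_PER_TOKEN = 2048 * 15 s, roughly 8.5 hours.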
logging.basicConfig(level=logging.INFO)
log = logging.getLogger()

def _rj(path, default):
    # read JSON from `path`; fall back to `default` on any error (missing file, bad JSON)
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return default

def _aw(path, obj):
    # atomic write: dump to a temp file, then rename it over the target
    tmp = path + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp, path)
prompts = _rj(PROMPTS_PATH, [])
if not prompts:
    raise RuntimeError(f"no prompts found in {PROMPTS_PATH}")

hf_token = os.environ.get("HF_READ_TOKEN")  # read token for the gated model repo
log.info("loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float32, low_cpu_mem_usage=False, token=hf_token
)
model.to("cpu")
model.eval()
log.info("model up")

lock = threading.Lock()  # guards current_state.json and the guess log
def _init():
    # Load the persisted state; start a fresh random prompt if none exists or the last one finished.
    s = _rj(STATE_PATH, {})
    if not s or s.get("finished"):
        i = random.randrange(len(prompts))
        # i: prompt index, p: prompt text, g: generated text, c: token count, t: start timestamp
        s = {"i": i, "p": prompts[i], "g": "", "c": 0, "t": time.time(), "finished": False}
        _aw(STATE_PATH, s)
    return s

def _es(start):
    # Elapsed time since `start`, formatted as "Xh Ym Zs".
    d = int(time.time() - start)
    h, r = divmod(d, 3600)
    m, s = divmod(r, 60)
    return f"{h}h {m}m {s}s"
def _loop():
    # Background worker: emit one token every SECS_PER_TOKEN seconds and persist the state.
    while True:
        with lock:
            s = _init()
        if s["finished"]:
            time.sleep(SECS_PER_TOKEN)
            continue
        context = s["p"] + s["g"]
        ids = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CTX).input_ids
        with torch.no_grad():
            out = model.generate(ids, max_new_tokens=1, do_sample=True, temperature=TEMP, top_p=TOP_P)
        new_tok = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        with lock:
            s["g"] += new_tok
            s["c"] += 1
            if s["c"] >= TOKENS_PER_PROMPT:
                s["finished"] = True
            _aw(STATE_PATH, s)
        time.sleep(SECS_PER_TOKEN)

threading.Thread(target=_loop, daemon=True).start()
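# Note: each step re-runs the model over the full prompt plus everything generated so far
# (no KV cache is kept between iterations), so the per-token cost grows as the continuation grows.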
def _fetch():
    # Latest state for the UI: (prompt, generation so far, elapsed time).
    s = _rj(STATE_PATH, {})
    if not s:
        return "...", "", "0h 0m 0s"
    return s["p"], s["g"], _es(s["t"])
def _sg(full_guess, idea_guess):
    # Log a player's guess; the exact "full" guess takes precedence over the rough "idea" guess.
    f1, i1 = full_guess.strip(), idea_guess.strip()
    if not (f1 or i1):
        return gr.update(value="eh?"), gr.update(), gr.update()
    p, g, e = _fetch()
    guess = f1 or i1
    gt = "full" if f1 else "idea"
    r = {"ts": datetime.now(timezone.utc).isoformat(), "prompt": p, "time": e, "resp": g, "guess": guess, "type": gt}
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    return gr.update(value="ok logged"), gr.update(value=""), gr.update(value="")
with gr.Blocks(theme="darkdefault") as demo:
    gr.Markdown("# What Comes Next")
    prm = gr.Markdown()
    txt = gr.Textbox(lines=10, interactive=False, label="oracle")
    tme = gr.Textbox(interactive=False, label="time")
    rbtn = gr.Button("refresh")
    full = gr.Textbox(label="full")
    idea = gr.Textbox(label="idea")
    send = gr.Button("send")
    st = gr.Textbox(interactive=False, label="status")
    demo.load(_fetch, outputs=[prm, txt, tme])
    rbtn.click(_fetch, outputs=[prm, txt, tme])
    send.click(_sg, inputs=[full, idea], outputs=[st, full, idea])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)