#!/usr/bin/env python3
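"""What Comes Next: a slow-reveal oracle.

A background thread samples one token every SECS_PER_TOKEN seconds from a
CPU-hosted Llama model; visitors watch the text grow and submit guesses about
how it will continue, which are appended to a JSON Lines log.
"""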
import os, json, time, random, threading, logging
from datetime import datetime, timezone
import torch
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(os.cpu_count())
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
PROMPTS_PATH = "full_prompts.json"  # list of prompt strings
STATE_PATH = "current_state.json"   # persisted generation state
DATA_PATH = "data.json"             # append-only guess log (one JSON object per line)
TOKENS_PER_PROMPT = 2048  # tokens to reveal before a prompt is finished
SECS_PER_TOKEN = 15       # pause between revealed tokens
TEMP = 0.9
TOP_P = 0.95
MAX_CTX = 8192  # truncation limit when re-tokenizing prompt + generation
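# Pacing: at one token per 15 s, a 2048-token generation runs for ~8.5 hours.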
logging.basicConfig(level=logging.INFO)
log = logging.getLogger()
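# Tiny JSON helpers: _rj reads with a fallback default; _aw writes through a
# temp file + os.replace so readers never observe a half-written state file.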
def _rj(p, d):
    try:
        with open(p, encoding="utf-8") as f: return json.load(f)
    except (OSError, json.JSONDecodeError): return d

def _aw(p, o):
    t = p + ".tmp"
    with open(t, "w", encoding="utf-8") as f: f.write(json.dumps(o, ensure_ascii=False, indent=2))
    os.replace(t, p)
prompts = _rj(PROMPTS_PATH, [])
if not prompts:
    raise RuntimeError(f"No prompts found in {PROMPTS_PATH}")
tok = os.environ.get("HF_READ_TOKEN")  # read token for the gated Llama repo
log.info("loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32,
                                             low_cpu_mem_usage=False, token=tok)
model.to("cpu")
model.eval()
log.info("model up")
lock = threading.Lock()
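# One lock serializes state access between the generator thread and Gradio's
# request handlers.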
def _init():
    s = _rj(STATE_PATH, {})
    if not s or s.get("finished"):
        # fresh prompt: index, prompt text, generated text, token count, start time
        i = random.randrange(len(prompts))
        s = {"i": i, "p": prompts[i], "g": "", "c": 0, "t": time.time(), "finished": False}
        _aw(STATE_PATH, s)
    return s
def _es(st):
    d = int(time.time() - st)
    h, r = divmod(d, 3600)
    m, s = divmod(r, 60)
    return f"{h}h {m}m {s}s"
def _loop():
    while True:
        with lock:
            s = _init()
        if s["finished"]:
            time.sleep(SECS_PER_TOKEN)
            continue
        enc = tokenizer(s["p"] + s["g"], return_tensors="pt", truncation=True, max_length=MAX_CTX)
        with torch.no_grad():
            out = model.generate(enc.input_ids, attention_mask=enc.attention_mask,
                                 max_new_tokens=1, do_sample=True, temperature=TEMP,
                                 top_p=TOP_P, pad_token_id=tokenizer.eos_token_id)
        nt = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        with lock:
            s["g"] += nt
            s["c"] += 1
            if s["c"] >= TOKENS_PER_PROMPT:
                s["finished"] = True
            _aw(STATE_PATH, s)
        time.sleep(SECS_PER_TOKEN)
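# daemon=True ties the generator thread's lifetime to the server process.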
threading.Thread(target=_loop, daemon=True).start()
def _fetch():
    s = _rj(STATE_PATH, {})
    if not s:
        return "...", "", "0h 0m 0s"
    return s["p"], s["g"], _es(s["t"])
def _sg(f, i):
    f1, i1 = f.strip(), i.strip()
    if not (f1 or i1):
        return gr.update(value="Enter a guess first."), gr.update(), gr.update()
    p, g, e = _fetch()
    guess = f1 or i1
    gt = "full" if f1 else "idea"
    r = {"ts": datetime.now(timezone.utc).isoformat(), "prompt": p, "time": e,
         "resp": g, "guess": guess, "type": gt}
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as fh:
            fh.write(json.dumps(r, ensure_ascii=False) + "\n")
    return gr.update(value="Guess logged."), gr.update(value=""), gr.update(value="")
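# UI: current prompt, the oracle's text so far, elapsed time, and guess fields.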
with gr.Blocks(theme="darkdefault") as demo:
    gr.Markdown("# What Comes Next")
    prm = gr.Markdown()
    txt = gr.Textbox(lines=10, interactive=False, label="oracle")
    tme = gr.Textbox(interactive=False, label="time")
    rbtn = gr.Button("refresh")
    full = gr.Textbox(label="full")
    idea = gr.Textbox(label="idea")
    send = gr.Button("send")
    st = gr.Textbox(interactive=False, label="status")
    demo.load(_fetch, outputs=[prm, txt, tme])
    rbtn.click(_fetch, outputs=[prm, txt, tme])
    send.click(_sg, inputs=[full, idea], outputs=[st, full, idea])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
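# A minimal sketch of reading the guess log back for analysis (hypothetical
# helper, not part of the app; data.json is JSON Lines despite the extension):
#
#   import json
#   with open("data.json", encoding="utf-8") as f:
#       guesses = [json.loads(line) for line in f if line.strip()]
#   print(len(guesses), "guesses recorded")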