import os
import json
import random
import threading
import logging
import sqlite3
from datetime import datetime

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from sentence_transformers import SentenceTransformer, util
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load Oracle model (FP32, CPU-only)
logger.info("Loading Oracle model...")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.float32,
    device_map="cpu"
)
model.eval()
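# Note: meta-llama/Llama-3.1-8B-Instruct is a gated repository, so the Space is assumed to
# have an access token with the license accepted; FP32 weights for an 8B model take roughly
# 32 GB of RAM, so CPU-only inference here will be slow and memory-hungry.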
# Load SentenceTransformer for semantic similarity
logger.info("Loading SentenceTransformer model...")
st_model = SentenceTransformer('all-MiniLM-L6-v2')
# Database setup (SQLite)
DB_PATH = "game_data.db"
conn = sqlite3.connect(DB_PATH, check_same_thread=False)
c = conn.cursor()
c.execute("""
    CREATE TABLE IF NOT EXISTS rounds (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        timestamp TEXT,
        prompt TEXT,
        full_guess TEXT,
        idea_guess TEXT,
        completion TEXT,
        score_full INTEGER,
        score_idea INTEGER,
        round_points INTEGER
    )
""")
conn.commit()
# Load prompts from JSON
PROMPTS_PATH = "oracle_prompts.json"
with open(PROMPTS_PATH, 'r') as f:
    PROMPTS = json.load(f)
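# oracle_prompts.json is assumed to be a flat JSON array of prompt strings, e.g.
# ["The ancient stars whisper that", "In the final hour, the traveler will"];
# the shuffle/choice logic below treats PROMPTS as a list of strings.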
# Helper functions
def get_next_prompt(state):
    """Draw the next prompt, reshuffling the full deck once it has been exhausted."""
    if not state["prompts"]:
        prompts = PROMPTS.copy()
        random.shuffle(prompts)
        state["prompts"] = prompts
        state["used"] = []
    prompt = state["prompts"].pop(0)
    state["used"].append(prompt)
    state["round"] += 1
    return prompt
def compute_score(guess, completion):
    """Score a guess by cosine similarity to the Oracle's completion (5/3/1/0 points)."""
    if not guess.strip():
        return 0
    emb_guess = st_model.encode(guess, convert_to_tensor=True)
    emb_comp = st_model.encode(completion, convert_to_tensor=True)
    cos_sim = util.pytorch_cos_sim(emb_guess, emb_comp).item()
    if cos_sim > 0.9:
        return 5
    elif cos_sim > 0.7:
        return 3
    elif cos_sim > 0.5:
        return 1
    else:
        return 0
def log_round(prompt, full_guess, idea_guess, completion, score_full, score_idea, round_points):
    ts = datetime.utcnow().isoformat()
    c.execute(
        "INSERT INTO rounds (timestamp, prompt, full_guess, idea_guess, completion, score_full, score_idea, round_points) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
        (ts, prompt, full_guess, idea_guess, completion, score_full, score_idea, round_points)
    )
    conn.commit()
    logger.info(f"Round logged at {ts}")
def play_round(full_guess, idea_guess, state):
    """Stream the Oracle's completion, then score both guesses against it."""
    prompt = state.get("current_prompt", "")
    if not prompt:
        yield "⚠️ Press 'Next Round' to draw a prompt first.", "", "", state["score"]
        return
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a background thread so tokens can be streamed to the UI.
    def generate():
        model.generate(
            input_ids=input_ids,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.8,
            streamer=streamer
        )

    thread = threading.Thread(target=generate)
    thread.start()

    completion = ""
    for token in streamer:
        completion += token
        # Yield one value per output component (completion, score, reflection, total score).
        yield completion, "", "", state["score"]
    thread.join()

    score_full = compute_score(full_guess, completion)
    score_idea = compute_score(idea_guess, completion)
    round_points = score_full + score_idea
    state["score"] += round_points
    log_round(prompt, full_guess, idea_guess, completion, score_full, score_idea, round_points)

    score_text = f"Full Guess: {score_full} pts | Idea Guess: {score_idea} pts | Round Total: {round_points} pts"
    reflection = "🔮 The Oracle ponders your insights..."
    if state["round"] >= 5 and state["score"] >= 15:
        unused = [p for p in PROMPTS if p not in state["used"]]
        if unused:
            secret = random.choice(unused)
            reflection += f"\n\n✨ **Secret Oracle Prompt:** {secret}"
    yield completion, score_text, reflection, state["score"]
def next_round_fn(state):
    """Draw a new prompt and clear the guess/result fields for the next round."""
    prompt = get_next_prompt(state)
    state["current_prompt"] = prompt
    return prompt, "", "", "", "", "", state["score"]
# Gradio UI
demo = gr.Blocks()
with demo:
    state = gr.State({"prompts": [], "used": [], "round": 0, "score": 0, "current_prompt": ""})
    gr.Markdown("⚠️ **Your input and the Oracle’s response will be stored for AI training and research. By playing, you consent to this.**")
    prompt_display = gr.Markdown("", elem_id="prompt_display")
    with gr.Row():
        full_guess = gr.Textbox(label="🧠 Exact Full Completion Guess")
        idea_guess = gr.Textbox(label="💡 General Idea Guess")
    submit = gr.Button("Submit Guess")
    completion_box = gr.Textbox(label="Oracle's Completion", interactive=False)
    score_box = gr.Textbox(label="Score", interactive=False)
    reflection_box = gr.Textbox(label="Mystical Reflection", interactive=False)
    next_btn = gr.Button("Next Round")
    total_score_display = gr.Textbox(label="Total Score", interactive=False)

    next_btn.click(
        next_round_fn,
        inputs=state,
        outputs=[prompt_display, full_guess, idea_guess, completion_box, score_box, reflection_box, total_score_display]
    )
    submit.click(
        play_round,
        inputs=[full_guess, idea_guess, state],
        outputs=[completion_box, score_box, reflection_box, total_score_display]
    )
if __name__ == "__main__":
    # Enable the queue so the streaming (generator) handler works reliably.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)