File size: 2,897 Bytes
e0c069c
 
 
 
271fdae
e0c069c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71414fa
 
 
 
 
e0c069c
71414fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0c069c
 
 
 
 
 
71414fa
e0c069c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import logging
import requests
from time import perf_counter, sleep
from memory_manager import embed_and_store, retrieve_relevant

# System prompt text for each agent role, keyed by agent name.
PROMPTS = {
    "Initiator": "You are the Discussion Initiator...",
    "Responder": "You are the Critical Responder...",
    "Guardian": "You are the Depth Guardian...",
    "Provocateur": "You are the Cross-Disciplinary Provocateur...",
    "Cultural": "You are the Cultural Perspective...",
    "Judge": "You are the Impartial Judge...",
}

# Model id and API token are read from the environment once, at import time.
CHAT_MODEL = os.getenv("CHAT_MODEL", "HuggingFaceH4/zephyr-7b-beta")
HF_API_TOKEN = os.getenv("HF_API_TOKEN", "")

# Configure the root logger: INFO and above, with timestamp and level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)

def safe_chat(system_prompt: str, history: list, temperature: float = 0.7,
              max_retries: int = 5) -> str:
    """Call the HF inference API and return the generated reply text.

    The system prompt and the chat history are flattened into one plain-text
    prompt that ends with ``Assistant:`` and posted to the model named by
    ``CHAT_MODEL``.

    Args:
        system_prompt: Instruction text placed at the top of the prompt.
        history: Messages as dicts with ``'role'`` and ``'content'`` keys.
        temperature: Sampling temperature forwarded to the API.
        max_retries: How many times to retry after an HTTP 503 response
            before giving up. Negative values are treated as 0.

    Returns:
        The stripped generated text on success. On failure a string starting
        with ``⚠️`` is returned instead of raising: ``⚠️ API Error <status>``
        for a non-200 response, ``⚠️ System Error: <exc>`` for any exception
        raised while performing the request or decoding the response.

    Raises:
        KeyError: If a history message lacks ``'role'`` or ``'content'``
            (the prompt is built before the guarded section).
    """
    start = perf_counter()
    chat_history = "\n".join(
        f"{msg['role'].capitalize()}: {msg['content']}" for msg in history
    )
    full_prompt = f"{system_prompt}\n\n{chat_history}\n\nAssistant:"

    payload = {
        "inputs": full_prompt,
        "parameters": {"max_new_tokens": 300, "temperature": temperature}
    }
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
    attempts = max(0, max_retries) + 1

    text = "⚠️ API Error 503"
    try:
        # Bounded loop instead of unbounded recursion: a model that never
        # finishes loading can no longer retry forever or blow the stack.
        for attempt in range(attempts):
            resp = requests.post(
                f"https://api-inference.huggingface.co/models/{CHAT_MODEL}",
                json=payload,
                headers=headers,
                timeout=60
            )
            if resp.status_code == 503 and attempt < attempts - 1:
                logging.warning("Model loading… retrying after 15s.")
                sleep(15)
                continue

            if resp.status_code == 200:
                data = resp.json()
                # The API may answer with a list of results or a single dict.
                result = data[0] if isinstance(data, list) else data
                generated = result.get("generated_text", "")
                # Text-generation models echo the prompt by default
                # (return_full_text); drop it so only the new reply remains.
                if generated.startswith(full_prompt):
                    generated = generated[len(full_prompt):]
                text = generated.strip()
            else:
                logging.error("HF error %s: %s", resp.status_code, resp.text)
                text = f"⚠️ API Error {resp.status_code}"
            break
    except Exception as e:
        # Deliberate best-effort boundary: callers always get a string back.
        logging.error("safe_chat exception: %s", e)
        text = f"⚠️ System Error: {e}"

    elapsed = perf_counter() - start
    logging.info("safe_chat: %.3fs for prompt '%s…'", elapsed, system_prompt[:30])
    return text

def step_turn(conversation: list, turn: int, topic: str, params: dict) -> list:
    """Run a single agent turn and append its reply to the conversation.

    The speaking agent is chosen round-robin from a fixed five-agent order
    using ``turn``. Up to the last five conversation entries, excluding those
    whose agent is ``"System"``, are sent as context, each with role
    ``"user"``. The reply is passed to ``embed_and_store`` and then appended
    to ``conversation`` with a 1-based turn number.

    Args:
        conversation: Entries as dicts with ``'agent'`` and ``'text'`` keys;
            mutated in place.
        turn: Zero-based turn index.
        topic: Topic forwarded to ``embed_and_store``.
        params: Per-agent settings; ``params[agent]['creativity']`` is used
            as the sampling temperature.

    Returns:
        The same ``conversation`` list, with the new entry appended.
    """
    roster = ["Initiator", "Responder", "Guardian", "Provocateur", "Cultural"]
    speaker = roster[turn % len(roster)]
    system_prompt = PROMPTS.get(speaker, "")

    # Context window: the five most recent entries, minus System messages.
    recent = []
    for entry in conversation[-5:]:
        if entry['agent'] == "System":
            continue
        recent.append({"role": "user", "content": entry['text']})

    creativity = params[speaker]['creativity']
    reply = safe_chat(system_prompt, recent, temperature=creativity)
    embed_and_store(reply, speaker, topic)

    new_entry = {"agent": speaker, "text": reply, "turn": turn + 1}
    conversation.append(new_entry)
    return conversation