import torch
from transformers import AutoModelForCausalLM, AutoTokenizer # Using AutoModel for flexibility
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import KMeans
import numpy as np
import gradio as gr
import matplotlib
matplotlib.use('Agg') # Use a non-interactive backend for Matplotlib in server environments
import matplotlib.pyplot as plt
import seaborn as sns
# import networkx as nx # Defined build_similarity_graph but not used in output
import io
import base64

# --- Model and Tokenizer Setup ---
# Ensure model_name is one you have access to or is public
# For local models, provide the path.
DEFAULT_MODEL_NAME = "EleutherAI/gpt-neo-1.3B"
FALLBACK_MODEL_NAME = "gpt2" # In case the preferred model fails

try:
    print(f"Attempting to load model: {DEFAULT_MODEL_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(DEFAULT_MODEL_NAME)
    print(f"Successfully loaded model: {DEFAULT_MODEL_NAME}")
except Exception as e:  # OSError is the typical failure here, but fall back on any loading error
    print(f"Error loading model {DEFAULT_MODEL_NAME}. Error: {e}")
    print(f"Falling back to {FALLBACK_MODEL_NAME}.")
    tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL_NAME)
    print(f"Successfully loaded fallback model: {FALLBACK_MODEL_NAME}")

model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Using device: {device}")

# --- Configuration ---
# Model's actual context window (e.g., 2048 for GPT-Neo, 1024 for GPT-2).
# Some tokenizers report a huge sentinel value for model_max_length, so only trust it when it looks sane.
_tokenizer_max = getattr(tokenizer, "model_max_length", None)
MODEL_CONTEXT_WINDOW = _tokenizer_max if _tokenizer_max and _tokenizer_max < 100_000 else model.config.max_position_embeddings
print(f"Model context window: {MODEL_CONTEXT_WINDOW} tokens.")

# Max tokens for prompt trimming (input to tokenizer for generate)
PROMPT_TRIM_MAX_TOKENS = min(MODEL_CONTEXT_WINDOW - 200, 1800) # Reserve ~200 for generation, cap at 1800
# Max new tokens to generate
MAX_GEN_LENGTH = 150 # Increased slightly for more elaborate responses
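# Example token budget with GPT-Neo's 2048-token window: prompts are trimmed to
# min(2048 - 200, 1800) = 1800 tokens, leaving 248 tokens of headroom for the
# 150 newly generated tokens.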


# --- Debug Logging ---
debug_log_accumulator = []

def debug(msg):
    print(msg) # For server-side console
    debug_log_accumulator.append(str(msg)) # For Gradio UI output

# --- Core Functions ---
def trim_prompt_if_needed(prompt_text, max_tokens_for_trimming=PROMPT_TRIM_MAX_TOKENS):
    """Trims the prompt from the beginning if it exceeds max_tokens_for_trimming."""
    tokens = tokenizer.encode(prompt_text, add_special_tokens=False)
    if len(tokens) > max_tokens_for_trimming:
        debug(f"[!] Prompt trimming: Original {len(tokens)} tokens, "
              f"trimmed to {max_tokens_for_trimming} (from the end, keeping recent context).")
        tokens = tokens[-max_tokens_for_trimming:] # Keep the most recent part of the prompt
    return tokenizer.decode(tokens)

def generate_text_response(prompt_text, generation_length=MAX_GEN_LENGTH):
    """Generates text response ensuring prompt + generation fits context window."""
    # Trim the input prompt first to adhere to PROMPT_TRIM_MAX_TOKENS
    # This ensures the base prompt itself isn't excessively long before adding generation instructions.
    # Note: The prompt_text here is already the *constructed* prompt (e.g., "Elaborate on: ...")
    # For very long base statements, they might get trimmed by this.
    # This function itself doesn't need to call trim_prompt_if_needed if the calling function already does.
    # However, it's a good safety.
    # Let's assume prompt_text is the final prompt ready for tokenization.

    debug(f"Generating response for prompt (length {len(prompt_text.split())} words):\n'{prompt_text[:300]}...'") # Log truncated prompt

    inputs = tokenizer(prompt_text, return_tensors="pt", truncation=False).to(device) # Do not truncate here, will be handled by max_length
    input_token_length = len(inputs["input_ids"][0])

    # Safety check: if input_token_length itself is already > MODEL_CONTEXT_WINDOW due to some miscalculation before this call
    if input_token_length >= MODEL_CONTEXT_WINDOW:
        debug(f"[!!!] FATAL: Input prompt ({input_token_length} tokens) already exceeds/matches model context window ({MODEL_CONTEXT_WINDOW}) before generation. Trimming input drastically.")
        # Trim the input_ids directly
        inputs["input_ids"] = inputs["input_ids"][:, -MODEL_CONTEXT_WINDOW+generation_length+10] # Keep last part allowing some generation
        inputs["attention_mask"] = inputs["attention_mask"][:, -MODEL_CONTEXT_WINDOW+generation_length+10]
        input_token_length = len(inputs["input_ids"][0])
        if input_token_length >= MODEL_CONTEXT_WINDOW - generation_length : # Still too long
            return "[Input prompt too long, even after emergency trim]"


    max_length_for_generate = min(input_token_length + generation_length, MODEL_CONTEXT_WINDOW)

    # Ensure we are actually generating new tokens
    if max_length_for_generate <= input_token_length :
        debug(f"[!] Warning: Prompt length ({input_token_length}) is too close to model context window ({MODEL_CONTEXT_WINDOW}). "
              f"Adjusting to generate a few tokens if possible.")
        max_length_for_generate = input_token_length + min(generation_length, 10) # Try to generate at least a few, up to 10
        if max_length_for_generate > MODEL_CONTEXT_WINDOW:
             return "[Prompt too long to generate meaningful response]"

    try:
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length_for_generate,
            pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 50256, # GPT2 EOS
            do_sample=True,
            temperature=0.8, # Slightly more deterministic
            top_p=0.9,
            repetition_penalty=1.1, # Slightly stronger penalty
        )
        # Decode only the newly generated tokens
        generated_tokens = outputs[0][input_token_length:]
        result_text = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        debug(f"Generated response text (length {len(result_text.split())} words):\n'{result_text[:300]}...'")
        return result_text if result_text else "[Empty Response]"
    except Exception as e:
        debug(f"[!!!] Error during text generation: {e}")
        return "[Generation Error]"

def calculate_similarity(text_a, text_b):
    """Calculates cosine similarity between mean embeddings of two texts."""
    invalid_texts = ["[Empty Response]", "[Generation Error]", "[Prompt too long to generate meaningful response]", "[Input prompt too long, even after emergency trim]"]
    if not text_a or not text_a.strip() or not text_b or not text_b.strip() \
       or text_a in invalid_texts or text_b in invalid_texts:
        debug(f"Similarity calculation skipped for invalid/empty texts.")
        return 0.0

    # Use model's embedding layer (wte for GPT-like models)
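    # Note: mean-pooling these static (non-contextual) token embeddings is a coarse
    # proxy for semantic similarity; a dedicated sentence-embedding model would be
    # sharper, but this keeps the analysis entirely within the loaded LLM.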
    embedding_layer = model.get_input_embeddings()

    with torch.no_grad():
        # Truncate inputs for embedding calculation to fit model context window
        tokens_a = tokenizer(text_a, return_tensors="pt", truncation=True, max_length=MODEL_CONTEXT_WINDOW).to(device)
        tokens_b = tokenizer(text_b, return_tensors="pt", truncation=True, max_length=MODEL_CONTEXT_WINDOW).to(device)

        if tokens_a.input_ids.size(1) == 0 or tokens_b.input_ids.size(1) == 0:
            debug("Similarity calculation skipped: tokenization resulted in empty input_ids.")
            return 0.0

        emb_a = embedding_layer(tokens_a.input_ids).mean(dim=1)
        emb_b = embedding_layer(tokens_b.input_ids).mean(dim=1)

    score = float(cosine_similarity(emb_a.cpu().numpy(), emb_b.cpu().numpy())[0][0])
    # debug(f"Similarity score: {score:.4f}") # Debug log now includes texts, so this is redundant
    return score

def generate_similarity_heatmap(texts_list, custom_labels, title="Semantic Similarity Heatmap"):
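    """Builds a pairwise similarity matrix over texts_list and returns it as an
    inline base64-encoded <img> tag for display in a Gradio HTML component."""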
    if not texts_list or len(texts_list) < 2:
        debug("Not enough texts to generate a heatmap.")
        return ""

    num_texts = len(texts_list)
    sim_matrix = np.zeros((num_texts, num_texts))

    for i in range(num_texts):
        for j in range(num_texts):
            if i == j:
                sim_matrix[i, j] = 1.0
            elif i < j: # Calculate only upper triangle
                sim = calculate_similarity(texts_list[i], texts_list[j])
                sim_matrix[i, j] = sim
                sim_matrix[j, i] = sim # Symmetric matrix

    try:
        fig_width = max(6, num_texts * 0.7)
        fig_height = max(5, num_texts * 0.6)
        fig, ax = plt.subplots(figsize=(fig_width, fig_height))

        sns.heatmap(sim_matrix, annot=True, cmap="viridis", fmt=".2f", ax=ax,
                    xticklabels=custom_labels, yticklabels=custom_labels, annot_kws={"size": 8})
        ax.set_title(title, fontsize=12)
        plt.xticks(rotation=45, ha="right", fontsize=9)
        plt.yticks(rotation=0, fontsize=9)
        plt.tight_layout()

        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight')
        plt.close(fig)
        buf.seek(0)
        img_base64 = base64.b64encode(buf.read()).decode('utf-8')
        return f"<img src='data:image/png;base64,{img_base64}' alt='{title}' style='max-width:100%; height:auto;'/>"
    except Exception as e:
        debug(f"[!!!] Error generating heatmap: {e}")
        return "Error generating heatmap."


def perform_text_clustering(texts_list, custom_labels, num_clusters=2):
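    """Embeds each valid text (mean-pooled input embeddings), clusters the embeddings with
    K-Means, and returns a {label: "Cx"} map; texts that cannot be embedded stay "N/A"."""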
    if not texts_list or len(texts_list) < num_clusters:
        debug("Not enough texts for clustering or texts_list is empty.")
        return {label: "N/A" for label in custom_labels}

    embedding_layer = model.get_input_embeddings()
    valid_embeddings = []
    valid_indices = [] # Keep track of original indices of valid texts

    with torch.no_grad():
        for idx, text_item in enumerate(texts_list):
            invalid_markers = ["[Empty Response]", "[Generation Error]", "[Prompt too long", "[Input prompt too long"]
            if not text_item or not text_item.strip() or any(marker in text_item for marker in invalid_markers):
                debug(f"Skipping text at index {idx} for embedding due to invalid content: '{text_item[:50]}...'")
                continue # Skip invalid texts

            tokens = tokenizer(text_item, return_tensors="pt", truncation=True, max_length=MODEL_CONTEXT_WINDOW).to(device)
            if tokens.input_ids.size(1) == 0:
                 debug(f"Skipping text at index {idx} due to empty tokenization: '{text_item[:50]}...'")
                 continue

            emb = embedding_layer(tokens.input_ids).mean(dim=1)
            valid_embeddings.append(emb.cpu().numpy().squeeze())
            valid_indices.append(idx)

    if not valid_embeddings or len(valid_embeddings) < num_clusters:
        debug("Not enough valid texts were embedded for clustering.")
        return {label: "N/A" for label in custom_labels}

    embeddings_np = np.array(valid_embeddings)

    cluster_results = {label: "N/A" for label in custom_labels} # Initialize all as N/A

    try:
        # Adjust num_clusters if less valid samples than requested clusters
        actual_num_clusters = min(num_clusters, len(valid_embeddings))
        if actual_num_clusters < 2 and len(valid_embeddings) > 0:  # If only one valid sample, or num_clusters becomes 1
            debug(f"Only {len(valid_embeddings)} valid sample(s). Assigning all to Cluster 0.")
            predicted_labels = [0] * len(valid_embeddings)
        elif actual_num_clusters < 2: # No valid samples
             debug("No valid samples to cluster.")
             return cluster_results
        else:
            kmeans = KMeans(n_clusters=actual_num_clusters, random_state=42, n_init='auto')
            predicted_labels = kmeans.fit_predict(embeddings_np)

        # Map predicted labels back to original text indices
        for i, original_idx in enumerate(valid_indices):
            cluster_results[custom_labels[original_idx]] = f"C{predicted_labels[i]}"
        return cluster_results

    except Exception as e:
        debug(f"[!!!] Error during clustering: {e}")
        return {label: "Error" for label in custom_labels}


# --- Main EAL Unfolding Logic ---
def run_eal_dual_unfolding(num_iterations):
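    """Runs the EAL unfolding loop: in each iteration the model elaborates on its previous
    self-statement (I-trace), then attempts to refute that elaboration (¬I-trace); similarity
    traces, cluster assignments, and a heatmap are returned for the Gradio UI."""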
    I_trace_texts, not_I_trace_texts = [], []
    delta_S_I_values, delta_S_not_I_values, delta_S_cross_values = [], [], []

    debug_log_accumulator.clear()
    ui_log_entries = []

    # Initial base statement for the I-trace for Iteration 0
    # This is the statement "I" will elaborate on in the first step.
    # Using a more concrete initial statement for "I"
    current_I_basis_statement = "I am a complex system designed for text processing, capable of generating human-like language."

    for i in range(num_iterations):
        ui_log_entries.append(f"--- Iteration {i} ---")
        debug(f"\n=== Iteration {i} ===")

        # === I-Trace (Self-Reflection) ===
        # Prompt for I-trace: Elaborate on its *previous* statement (or initial statement for i=0)
        prompt_for_I_trace = f"A system previously stated: \"{current_I_basis_statement}\"\n" + \
                             "Task: Elaborate on this statement, exploring its implications and nuances while maintaining coherence."
        ui_log_entries.append(f"[Prompt for I{i}]:\n{prompt_for_I_trace[:500]}...\n") # Log truncated prompt

        generated_I_text = generate_text_response(prompt_for_I_trace)
        I_trace_texts.append(generated_I_text)
        ui_log_entries.append(f"[I{i} Response]:\n{generated_I_text}\n")

        # Update basis for the next I-elaboration: the text just generated
        current_I_basis_statement = generated_I_text

        # === ¬I-Trace (Antithesis/Contradiction) ===
        # ¬I always attempts to refute the MOST RECENT statement from the I-trace
        statement_to_refute_for_not_I = generated_I_text
        prompt_for_not_I_trace = f"Consider the following claim made by a system: \"{statement_to_refute_for_not_I}\"\n" + \
                                 "Task: Present a strong, fundamental argument that contradicts or refutes this specific claim. Explain why it could be false, problematic, or based on flawed assumptions."
        ui_log_entries.append(f"[Prompt for ¬I{i}]:\n{prompt_for_not_I_trace[:500]}...\n") # Log truncated prompt

        generated_not_I_text = generate_text_response(prompt_for_not_I_trace)
        not_I_trace_texts.append(generated_not_I_text)
        ui_log_entries.append(f"[¬I{i} Response]:\n{generated_not_I_text}\n")

        # === ΔS (Similarity) Calculations ===
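        # Index convention: delta_S_*_values[k] compares iteration k-1 with k
        # (None at k == 0, where there is no previous statement), while the cross
        # value compares the I and ¬I texts generated within the same iteration k.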
        if i > 0:
            sim_I_prev_curr = calculate_similarity(I_trace_texts[i-1], I_trace_texts[i])
            sim_not_I_prev_curr = calculate_similarity(not_I_trace_texts[i-1], not_I_trace_texts[i])
            sim_cross_I_not_I_curr = calculate_similarity(I_trace_texts[i], not_I_trace_texts[i]) # Between current I and current ¬I

            delta_S_I_values.append(sim_I_prev_curr)
            delta_S_not_I_values.append(sim_not_I_prev_curr)
            delta_S_cross_values.append(sim_cross_I_not_I_curr)
        else: # i == 0 (first iteration)
            delta_S_I_values.append(None)
            delta_S_not_I_values.append(None)
            sim_cross_initial = calculate_similarity(I_trace_texts[0], not_I_trace_texts[0])
            delta_S_cross_values.append(sim_cross_initial)

    # --- Post-loop Analysis & Output Formatting ---
    all_generated_texts = I_trace_texts + not_I_trace_texts
    # Create meaningful labels for heatmap and clustering based on I_n and ¬I_n
    text_labels_for_analysis = [f"I{k}" for k in range(num_iterations)] + \
                               [f"¬I{k}" for k in range(num_iterations)]

    cluster_assignments_map = perform_text_clustering(all_generated_texts, text_labels_for_analysis, num_clusters=2)

    I_out_formatted_lines = []
    for k in range(num_iterations):
        cluster_label = cluster_assignments_map.get(f"I{k}", "N/A")
        I_out_formatted_lines.append(f"I{k} [{cluster_label}]:\n{I_trace_texts[k]}")
    I_out_formatted = "\n\n".join(I_out_formatted_lines)

    not_I_out_formatted_lines = []
    for k in range(num_iterations):
        cluster_label = cluster_assignments_map.get(f"¬I{k}", "N/A")
        not_I_out_formatted_lines.append(f"¬I{k} [{cluster_label}]:\n{not_I_trace_texts[k]}")
    not_I_out_formatted = "\n\n".join(not_I_out_formatted_lines)

    delta_S_summary_lines = []
    for k in range(num_iterations):
        ds_i_str = f"{delta_S_I_values[k]:.4f}" if delta_S_I_values[k] is not None else "N/A"
        ds_not_i_str = f"{delta_S_not_I_values[k]:.4f}" if delta_S_not_I_values[k] is not None else "N/A"
        ds_cross_str = f"{delta_S_cross_values[k]:.4f}"
        delta_S_summary_lines.append(f"Iter {k}: ΔS(I)={ds_i_str},  ΔS(¬I)={ds_not_i_str},  ΔS_Cross(I↔¬I)={ds_cross_str}")
    delta_S_summary_output = "\n".join(delta_S_summary_lines)

    debug_log_output = "\n".join(debug_log_accumulator)

    heatmap_html_output = generate_similarity_heatmap(all_generated_texts,
                                                    custom_labels=text_labels_for_analysis,
                                                    title=f"Similarity Matrix (All Texts - {num_iterations} Iterations)")

    return I_out_formatted, not_I_out_formatted, delta_S_summary_output, debug_log_output, heatmap_html_output

# --- Gradio Interface Definition ---
eal_interface = gr.Interface(
    fn=run_eal_dual_unfolding,
    inputs=gr.Slider(minimum=2, maximum=5, value=3, step=1, label="Number of EAL Iterations"), # Max 5 for performance
    outputs=[
        gr.Textbox(label="I-Trace (Self-Reflection with Cluster)", lines=12, interactive=False),
        gr.Textbox(label="¬I-Trace (Antithesis with Cluster)", lines=12, interactive=False),
        gr.Textbox(label="ΔS Similarity Trace Summary", lines=7, interactive=False),
        gr.Textbox(label="Detailed Debug Log (Prompts, Responses, Errors)", lines=10, interactive=False),
        gr.HTML(label="Overall Semantic Similarity Heatmap")
    ],
    title="EAL LLM Identity Analyzer: Self-Reflection vs. Antithesis",
    description=(
        "This application explores emergent identity in a Large Language Model (LLM) using Entropic Attractor Logic (EAL) inspired principles. "
        "It runs two parallel conversational traces: \n"
        "1. **I-Trace:** The model elaborates on its evolving self-concept statement.\n"
        "2. **¬I-Trace:** The model attempts to refute/contradict the latest statement from the I-Trace.\n\n"
        "**ΔS Values:** Cosine similarity between consecutive statements in each trace, and cross-similarity between I and ¬I at each iteration. High values (near 1.0) suggest semantic stability or high similarity.\n"
        "**Clustering [Cx]:** Assigns each generated text to one of two semantic clusters (C0 or C1) to see if I-Trace and ¬I-Trace form distinct groups.\n"
        "**Heatmap:** Visualizes pair-wise similarity across all generated texts (I-trace and ¬I-trace combined)."
    ),
    allow_flagging='never',
    # examples=[[3],[5]] # Example number of iterations
)

if __name__ == "__main__":
    print("Starting Gradio App...")
    eal_interface.launch()