In [None]:
import os
import json
from glob import glob
import hashlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import torch
import torch.nn.functional as F
import numpy as np
from numba import njit

from dataset.common import inverse_dihedral_transform


# Directory that holds the dataset's identifiers.json (read by load_identifiers_and_preds).
# Swap the commented line in to evaluate on ARC-2 instead of ARC-1.
DATASET_PATH = "data/arc-aug-1000" # ARC-1
# DATASET_PATH = "data/arc-2-aug-1000" # ARC-2

# Path prefix of the saved prediction shards: every file matching
# f"{CHECKPOINT_PATH}_all_preds.*" is loaded and concatenated.
CHECKPOINT_PATH = "checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456"


# Rows whose puzzle identifier equals this value are padding and are dropped after loading.
PAD_PUZZLE_IDENTIFIER = 0

# Visualization
# Entry i is intended as the colour of cell value i (0-9) in a cropped grid.
# imshow only maps values this way when given a fixed norm (e.g. vmin=0, vmax=9);
# otherwise it rescales each grid to its own min/max.
ARC_COLOR_MAP = mcolors.ListedColormap([
    "#000000", # symbol_0: black
    "#0074D9", # symbol_1: blue
    "#FF4136", # symbol_2: red
    "#2ECC40", # symbol_3: green
    "#FFDC00", # symbol_4: yellow
    "#AAAAAA", # symbol_5: grey
    "#F012BE", # symbol_6: fuchsia
    "#FF851B", # symbol_7: orange
    "#7FDBFF", # symbol_8: teal
    "#870C25" # symbol_9: brown
])

In [None]:
def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):
    """Load the puzzle-identifier table and the concatenated saved predictions.

    Args:
        dataset_path: directory containing ``identifiers.json``.
        checkpoint_path: literal path prefix; every file matching
            ``f"{checkpoint_path}_all_preds.*"`` is loaded with ``torch.load``.
            Each shard is presumably a dict of tensors sharing a first (row)
            dimension -- TODO confirm against the code that writes them.

    Returns:
        ``(identifier_map, all_preds)``: the parsed ``identifiers.json`` and a
        dict mapping each key to the dim-0 concatenation of that key across all
        shards, with rows whose puzzle identifier equals
        ``PAD_PUZZLE_IDENTIFIER`` removed.

    Raises:
        FileNotFoundError: if no prediction shard matches the prefix.
    """
    from glob import escape as glob_escape

    # Load puzzle identifiers
    with open(os.path.join(dataset_path, "identifiers.json"), "r") as f:
        identifier_map = json.load(f)

    # Escape the prefix (it is a literal path and may contain glob metacharacters
    # such as "["). Sort the matches so that shard order, and therefore candidate
    # order and vote tie-breaking downstream, does not depend on the filesystem.
    shard_paths = sorted(glob(glob_escape(checkpoint_path) + "_all_preds.*"))
    if not shard_paths:
        raise FileNotFoundError(
            f"No prediction shards match {checkpoint_path}_all_preds.*"
        )

    # Load preds, grouping each key's per-shard tensors in shard order
    all_preds = {}
    for filename in shard_paths:
        # map_location="cpu": callers convert rows with .numpy(), which needs CPU tensors.
        preds = torch.load(filename, map_location="cpu")
        for k, v in preds.items():
            all_preds.setdefault(k, []).append(v)

        del preds  # release this shard's dict before loading the next one

    all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}

    # Remove paddings
    mask = all_preds["puzzle_identifiers"] != PAD_PUZZLE_IDENTIFIER
    all_preds = {k: v[mask] for k, v in all_preds.items()}

    return identifier_map, all_preds


def inverse_aug(name: str, grid: np.ndarray):
    """Map *grid* from an augmented puzzle's frame back to the original frame.

    Names without "_" are treated as un-augmented and returned unchanged.
    Otherwise the last two "_"-separated fields are read as a transform tag
    (one leading letter, e.g. "t", followed by the dihedral transform index)
    and a colour-permutation string. The grid is passed through
    ``inverse_dihedral_transform`` and its cell values are then remapped through
    the inverse of that permutation.
    """
    if "_" not in name:
        return grid

    *_, transform_tag, color_perm = name.split("_")
    transform_index = int(transform_tag[1:])  # drop the leading "t"

    # argsort of a permutation is its inverse. This assumes color_perm is a
    # string of single digits, so the characters sort in numeric order.
    undo_colors = np.argsort(list(color_perm))
    restored = inverse_dihedral_transform(grid, transform_index)
    return undo_colors[restored]


def grid_hash(grid: np.ndarray):
    """Return a hashable key identifying *grid* by its shape and cell values.

    Integer grids are widened to int64 before hashing. Without this, identical
    cells stored with different integer dtypes (for example, int64 argmax
    predictions or int64 ``argsort`` lookups versus labels saved in a narrower
    dtype) would produce different keys and never compare equal. ``tobytes``
    emits C-order bytes, so non-contiguous views hash the same as their copies.
    Non-integer grids are hashed as-is.
    """
    if np.issubdtype(grid.dtype, np.integer):
        grid = grid.astype(np.int64, copy=False)
    return hash((grid.tobytes(), grid.shape))


@njit
def crop(grid: np.ndarray):
    """Return the largest valid top-left rectangle of a 30x30 grid, shifted to 0-based values.

    *grid* must hold exactly 900 elements; it is reshaped to 30x30. A cell is
    valid when its value lies in [2, 11]; any other value acts as a boundary.
    For each height, the usable width is the shortest valid prefix among the
    rows so far. The largest-area rectangle wins, and the smallest height wins
    ties. Its values are returned minus 2.
    """
    board = grid.reshape(30, 30)
    n_rows, n_cols = board.shape

    best_area = 0
    best_rows = 0
    best_cols = 0

    # Valid-prefix width shared by every row scanned so far; it can only shrink.
    width = n_cols
    for r in range(n_rows):
        c = 0
        while c < width:
            value = board[r, c]
            if value < 2 or value > 11:
                width = c
                break
            c += 1

        area = (r + 1) * width
        if area > best_area:
            best_area = area
            best_rows = r + 1
            best_cols = width

    return board[:best_rows, :best_cols] - 2


def _collect_puzzle_labels(identifier_map, all_preds, global_hmap):
    """Map each un-augmented puzzle name to ``{input_hash: label_hash}`` for its rows.

    Cropped input and label grids are stored in *global_hmap* under their hashes.
    """
    puzzle_labels = {}
    rows = zip(all_preds["puzzle_identifiers"], all_preds["inputs"], all_preds["labels"])
    for identifier, input_grid, label_grid in rows:
        name = identifier_map[identifier]
        if "_" in name:  # augmented copy; the reference pair comes from the un-augmented row
            continue

        input_grid = crop(input_grid.numpy())
        label_grid = crop(label_grid.numpy())
        input_hash = grid_hash(input_grid)
        label_hash = grid_hash(label_grid)
        global_hmap[input_hash] = input_grid
        global_hmap[label_hash] = label_grid

        tests = puzzle_labels.setdefault(name, {})
        assert input_hash not in tests
        tests[input_hash] = label_hash
    return puzzle_labels


def _collect_predictions(identifier_map, all_preds, puzzle_labels, global_hmap):
    """Group de-augmented argmax predictions by puzzle and test input.

    Returns ``{puzzle_name: {input_hash: [(pred_hash, halt_prob), ...]}}`` with one
    entry per saved row (augmented or not). Predicted grids go into *global_hmap*.
    """
    preds = all_preds["logits"].argmax(-1)
    halt_probs = all_preds["q_halt_logits"].sigmoid()

    pred_answers = {}
    rows = zip(all_preds["puzzle_identifiers"], all_preds["inputs"], preds, halt_probs)
    for identifier, input_grid, pred, q in rows:
        name = identifier_map[identifier]
        orig_name = name.split("_")[0]

        # Undo the augmentation so this row lines up with its un-augmented test pair.
        input_hash = grid_hash(inverse_aug(name, crop(input_grid.numpy())))
        assert input_hash in puzzle_labels[orig_name]

        pred_grid = inverse_aug(name, crop(pred.numpy()))
        pred_hash = grid_hash(pred_grid)
        global_hmap[pred_hash] = pred_grid

        pred_answers.setdefault(orig_name, {}).setdefault(input_hash, []).append(
            (pred_hash, q.item())
        )
    return pred_answers


def _rank_candidates(candidates):
    """Merge duplicate predictions and rank them best-first.

    Each distinct ``pred_hash`` gets ``[vote_count, mean_halt_prob]``. Ranking is
    by vote count, then mean halting probability, descending.
    """
    stats = {}
    for pred_hash, q in candidates:
        entry = stats.setdefault(pred_hash, [0, 0])
        entry[0] += 1
        entry[1] += q
    for entry in stats.values():
        entry[1] /= entry[0]
    return sorted(stats.items(), key=lambda kv: kv[1], reverse=True)


def _plot_test_row(axes_row, name, global_hmap, input_hash, label_hash, ranked, num_trials=2):
    """Draw the input, the ground truth and the top *num_trials* candidates on one row of axes."""
    # vmin/vmax pin cell value v to ARC_COLOR_MAP entry v; without them imshow
    # rescales every grid to its own min/max and the colours shift.
    style = dict(cmap=ARC_COLOR_MAP, vmin=0, vmax=9)
    panels = [
        (global_hmap[input_hash], f"{name}\nInput"),
        (global_hmap[label_hash], f"{name}\nAnswer"),
    ]
    for trial, (pred_hash, _) in enumerate(ranked[:num_trials], start=1):
        panels.append((global_hmap[pred_hash], f"{name}\nTrial {trial}"))

    for ax, (grid, title) in zip(axes_row, panels):
        ax.imshow(grid, **style)
        ax.set_title(title)
        ax.axis("off")
    for ax in axes_row[len(panels):]:  # fewer distinct candidates than trial slots
        ax.axis("off")


def test(visualize, Ks=(1, 2, 10, 100, 1000)):
    """Score saved predictions with top-K voting and print the per-K accuracy.

    For every test input of every un-augmented puzzle, the predictions from all
    augmented copies are mapped back to the original frame, merged by identical
    grid, and ranked by (vote count, mean halting probability). A test input
    counts as solved at K if the label is among the top-K candidates. A puzzle
    counts as solved at K only if all of its test inputs are.

    Args:
        visualize: if true, plot input, answer and the top-2 candidates for
            every test input.
        Ks: the K values to report.

    Raises:
        ValueError: if no un-augmented puzzles are found.
    """
    identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)

    global_hmap = {}  # grid hash -> cropped grid, used for plotting
    puzzle_labels = _collect_puzzle_labels(identifier_map, all_preds, global_hmap)
    print("Number of puzzles", len(puzzle_labels))
    if not puzzle_labels:
        raise ValueError("No un-augmented puzzles found in the saved predictions")

    pred_answers = _collect_predictions(identifier_map, all_preds, puzzle_labels, global_hmap)

    axes = None
    fig_id = 0
    if visualize:
        num_figs = sum(len(tests) for tests in puzzle_labels.values())
        # squeeze=False keeps axes 2-D even when there is only one test row.
        _, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4), squeeze=False)

    correct = [0] * len(Ks)
    for name, tests in puzzle_labels.items():
        num_test_correct = [0] * len(Ks)
        for input_hash, label_hash in tests.items():
            ranked = _rank_candidates(pred_answers[name][input_hash])
            for i, k in enumerate(Ks):
                num_test_correct[i] += any(h == label_hash for h, _ in ranked[:k])

            if visualize:
                _plot_test_row(axes[fig_id], name, global_hmap, input_hash, label_hash, ranked)
                fig_id += 1

        # Total correctness: every test input of the puzzle must be solved.
        for i, num_ok in enumerate(num_test_correct):
            correct[i] += num_ok == len(tests)

    for i, k in enumerate(Ks):
        print(f"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%")


# Score every saved prediction at the standard K values; set visualize=True to also plot each test pair.
test(visualize=False, Ks=[1, 2, 10, 100, 1000])