|
from typing import List, Optional |
|
|
|
import pydantic |
|
import numpy as np |
|
|
|
|
|
|
|
|
|
# DIHEDRAL_INVERSE[t] is the transform id that undoes `dihedral_transform(..., t)`.
# Only the two quarter-turns swap with each other; every other transform is its
# own inverse.
DIHEDRAL_INVERSE = [
    0,  # identity           -> identity
    3,  # rot90 (k=1)        -> rot90 (k=3)
    2,  # rot90 (k=2)        -> rot90 (k=2)
    1,  # rot90 (k=3)        -> rot90 (k=1)
    4,  # fliplr             -> fliplr
    5,  # flipud             -> flipud
    6,  # transpose          -> transpose
    7,  # fliplr(rot90(k=1)) -> itself (anti-transpose)
]
|
|
|
|
|
class PuzzleDatasetMetadata(pydantic.BaseModel):
    # Metadata record describing a puzzle dataset; pydantic validates the types
    # of all fields on construction.
    # NOTE(review): the field meanings below are inferred from names only —
    # confirm against the code that writes/reads this metadata.

    # Token id presumably used to pad sequences up to `seq_len`.
    pad_id: int

    # Label id presumably excluded from the loss; may be None.
    # NOTE: `Optional[int]` with no default is a *required* field in pydantic v2
    # (None is accepted but must be passed explicitly); pydantic v1 implicitly
    # defaulted it to None.
    ignore_label_id: Optional[int]

    # Puzzle-identifier id presumably reserved for "blank"/no identifier.
    blank_identifier_id: int

    # Number of distinct token ids (presumably including `pad_id`; TODO confirm).
    vocab_size: int

    # Length of each example sequence (presumably after padding).
    seq_len: int

    # Number of distinct puzzle identifiers (presumably including
    # `blank_identifier_id`; TODO confirm).
    num_puzzle_identifiers: int

    # Total number of groups in the dataset; what a "group" is is not visible
    # in this file.
    total_groups: int

    # Mean number of examples per puzzle (hence a float).
    mean_puzzle_examples: float

    # Names of the subsets/splits in the dataset (presumably; verify with caller).
    sets: List[str]
|
|
|
|
|
def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    """Apply one of the 8 dihedral (D4) symmetries to the first two axes of ``arr``.

    Transform ids:
        0: identity
        1: rotate 90 degrees counter-clockwise (``np.rot90``, k=1)
        2: rotate 180 degrees (``np.rot90``, k=2)
        3: rotate 270 degrees counter-clockwise (``np.rot90``, k=3)
        4: flip left-right (reverse axis 1)
        5: flip up-down (reverse axis 0)
        6: transpose (swap axes 0 and 1)
        7: anti-transpose (``fliplr(rot90(arr, k=1))``, reflection across the
           anti-diagonal)

    Every transform acts only on axes 0 and 1; any trailing axes (e.g.
    channels) are left in place.

    Args:
        arr: Input array. Ids 1-4 and 7 need at least two dimensions; NumPy
            raises otherwise.
        tid: Transform id in ``range(8)``.

    Returns:
        The transformed array. NumPy's rot90/flip/swapaxes return views, so the
        result may share memory with ``arr``; copy it before mutating. For
        ``tid == 0`` and for any id outside ``range(8)``, ``arr`` itself is
        returned unchanged (no error is raised for unknown ids).
    """
    if tid == 0:
        return arr
    elif tid == 1:
        return np.rot90(arr, k=1)
    elif tid == 2:
        return np.rot90(arr, k=2)
    elif tid == 3:
        return np.rot90(arr, k=3)
    elif tid == 4:
        return np.fliplr(arr)
    elif tid == 5:
        return np.flipud(arr)
    elif tid == 6:
        # Swap only axes 0 and 1 so this agrees with rot90/fliplr/flipud on
        # N-d input; `arr.T` would reverse *all* axes when ndim > 2. For
        # ndim < 2 keep the original `arr.T` behaviour (a no-op reshape-wise).
        return np.swapaxes(arr, 0, 1) if np.ndim(arr) >= 2 else arr.T
    elif tid == 7:
        return np.fliplr(np.rot90(arr, k=1))
    # Unknown ids deliberately fall back to the identity (pre-existing behaviour
    # that callers may rely on).
    return arr
|
|
|
|
|
def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    """Undo ``dihedral_transform(arr, tid)``.

    Looks up the inverse transform id in ``DIHEDRAL_INVERSE`` and applies it, so
    that ``inverse_dihedral_transform(dihedral_transform(x, t), t)`` equals ``x``
    for every integer ``t``.

    Args:
        arr: Array previously produced by ``dihedral_transform``.
        tid: The transform id that was applied.

    Returns:
        The array with the transform undone (possibly a view of ``arr``). For
        ids outside ``range(len(DIHEDRAL_INVERSE))`` ``arr`` is returned unchanged.
    """
    if not 0 <= tid < len(DIHEDRAL_INVERSE):
        # dihedral_transform treats unknown ids as the identity, so the inverse
        # is the identity too. Without this guard negative ids would wrap around
        # (e.g. -1 -> DIHEDRAL_INVERSE[-1] == 7, an anti-transpose, not an
        # inverse) and ids >= 8 would raise IndexError.
        return arr
    return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
|
|