File size: 1,800 Bytes
039647a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import unittest
import torch
import networkx as nx
from graph_construction import extract_patches, build_graph_from_patches, build_graph_data_from_patches
class TestGraphConstruction(unittest.TestCase):
    """Shape and structure tests for the ``graph_construction`` helpers.

    Each test builds a random feature map of shape (B, C, H, W), splits it
    into PATCH_H x PATCH_W patches and checks the resulting tensors/graphs.
    Only shapes and connectivity are asserted, never feature values, so the
    random contents of the input do not influence the outcome.
    """

    # Feature-map geometry shared by every test.
    CHANNELS = 16
    HEIGHT = 32
    WIDTH = 32
    PATCH_H, PATCH_W = 4, 4
    PATCH_SIZE = (PATCH_H, PATCH_W)
    # Patch grid is (H / ph) x (W / pw) = 8 x 8, i.e. 64 patches per image
    # (the count the original assertions expect).
    GRID_H = HEIGHT // PATCH_H
    GRID_W = WIDTH // PATCH_W
    NUM_PATCHES = GRID_H * GRID_W

    def setUp(self):
        # Seed the RNG so any failure reproduces identically on re-run.
        torch.manual_seed(0)

    def _feature_map(self, batch_size):
        """Return a random feature map of shape (batch_size, C, H, W)."""
        return torch.randn(batch_size, self.CHANNELS, self.HEIGHT, self.WIDTH)

    def test_extract_patches_shape(self):
        feature_map = self._feature_map(2)
        patches = extract_patches(feature_map, patch_size=self.PATCH_SIZE)
        # Patches stay grouped per image: (B, num_patches, C, ph, pw) --
        # 64 patches for each of the 2 images, not a flat 128.
        self.assertEqual(
            patches.shape,
            (2, self.NUM_PATCHES, self.CHANNELS, self.PATCH_H, self.PATCH_W),
        )

    def test_build_graph_from_patches_graph_structure(self):
        feature_map = self._feature_map(1)
        G_batch, _ = build_graph_from_patches(
            feature_map, patch_size=self.PATCH_SIZE
        )
        # Only one image in the batch, so G_batch[0] is its graph.
        G = G_batch[0]
        # One node per patch.
        self.assertEqual(len(G.nodes), self.NUM_PATCHES)
        # Interior patches should be connected to their grid neighbours.
        # The patch at (row=1, col=1) is not on the border. This assumes
        # nodes are numbered row-major (row * GRID_W + col) -- TODO confirm
        # against build_graph_from_patches.
        row, col = 1, 1
        node_index = row * self.GRID_W + col  # == 9 on an 8x8 grid
        self.assertGreater(len(list(G.neighbors(node_index))), 0)

    def test_build_graph_data_from_patches_conversion(self):
        feature_map = self._feature_map(2)
        G_batch, patches = build_graph_from_patches(
            feature_map, patch_size=self.PATCH_SIZE
        )
        data_list = build_graph_data_from_patches(G_batch, patches)
        # One graph-data object per image in the batch.
        self.assertEqual(len(data_list), 2)
        # Node feature width is C * ph * pw = 16 * 4 * 4 = 256.
        self.assertEqual(
            data_list[0].x.shape[1],
            self.CHANNELS * self.PATCH_H * self.PATCH_W,
        )
        # The converted graph must carry at least one edge.
        self.assertGreater(data_list[0].edge_index.shape[1], 0)
if __name__ == "__main__":
    # Executed directly (not imported): run every test case in this module.
    unittest.main()
|