File size: 1,800 Bytes
039647a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import unittest
import torch
import networkx as nx
from graph_construction import extract_patches, build_graph_from_patches, build_graph_data_from_patches

class TestGraphConstruction(unittest.TestCase):
    """Structural tests for the patch-extraction and graph-building helpers.

    Each test pushes a random ``torch.randn`` feature map (16 channels,
    32x32 spatial size) through the helpers with 4x4 patches and checks only
    shapes and counts, never tensor values.
    """

    def test_extract_patches_shape(self):
        """``extract_patches`` output is expected to be (B, patches, C, ph, pw)."""
        # Dummy feature map: B=2, C=16, H=32, W=32.
        feature_map = torch.randn(2, 16, 32, 32)
        patches = extract_patches(feature_map, patch_size=(4, 4))
        # (H/4) * (W/4) = 8 * 8 = 64 patches per image. The batch axis is kept
        # separate, so the leading dimensions are (2, 64), not a flat 128.
        self.assertEqual(patches.shape, (2, 64, 16, 4, 4))

    def test_build_graph_from_patches_graph_structure(self):
        """A single image yields one graph with a node per patch and some edges."""
        feature_map = torch.randn(1, 16, 32, 32)
        # The second return value (the patches) is not needed by this test.
        G_batch, _ = build_graph_from_patches(feature_map, patch_size=(4, 4))
        # 1 image => G_batch[0] is its graph.
        G = G_batch[0]
        # 8 * 8 = 64 patches, one node each.
        self.assertEqual(len(G.nodes), 64)
        # Edges are expected to follow an 8-neighborhood, so an interior node
        # should be connected. Index 9 is assumed to be (row=1, col=1) of the
        # 8x8 grid -- TODO confirm the node numbering used by the helper.
        node_index = 9
        neighbors = list(G.neighbors(node_index))
        # assertGreater reports the actual count on failure, unlike
        # assertTrue(len(...) > 0), which only says "False is not true".
        self.assertGreater(len(neighbors), 0)

    def test_build_graph_data_from_patches_conversion(self):
        """Graphs plus patches convert to one data object per image."""
        feature_map = torch.randn(2, 16, 32, 32)
        G_batch, patches = build_graph_from_patches(feature_map, patch_size=(4, 4))
        data_list = build_graph_data_from_patches(G_batch, patches)
        # One converted entry per image in the batch.
        self.assertEqual(len(data_list), 2)
        # Node features are flattened patches: C * patch_h * patch_w = 16*4*4 = 256.
        self.assertEqual(data_list[0].x.shape[1], 16 * 4 * 4)
        # At least one edge must survive the conversion (edge_index has one
        # column per edge).
        self.assertGreater(data_list[0].edge_index.shape[1], 0)

# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()