|
import unittest |
|
import torch |
|
import networkx as nx |
|
from graph_construction import extract_patches, build_graph_from_patches, build_graph_data_from_patches |
|
|
|
class TestGraphConstruction(unittest.TestCase):
    """Smoke tests for the patch and graph helpers imported from ``graph_construction``.

    Every test feeds a random ``torch.randn`` feature map with 16 channels
    and a 32x32 spatial size, split with ``patch_size=(4, 4)``.
    """

    def test_extract_patches_shape(self):
        """``extract_patches`` returns a tensor of shape ``(2, 64, 16, 4, 4)``."""
        feature_map = torch.randn(2, 16, 32, 32)

        patches = extract_patches(feature_map, patch_size=(4, 4))

        # 64 == (32 // 4) * (32 // 4); presumably a non-overlapping tiling,
        # laid out as (batch, patches, channels, patch_h, patch_w) --
        # TODO confirm against extract_patches.
        self.assertEqual(patches.shape, (2, 64, 16, 4, 4))

    def test_build_graph_from_patches_graph_structure(self):
        """The first graph in the batch has 64 nodes and node 9 has a neighbor."""
        feature_map = torch.randn(1, 16, 32, 32)

        graph_batch, _patches = build_graph_from_patches(
            feature_map, patch_size=(4, 4)
        )
        graph = graph_batch[0]

        # One node per patch is expected (64 patches for this input).
        self.assertEqual(len(graph.nodes), 64)

        # NOTE(review): node 9 would be an interior cell of an 8x8 grid under
        # row-major numbering -- confirm the node indexing scheme.
        node_index = 9
        neighbor_count = len(list(graph.neighbors(node_index)))
        # assertGreater reports the actual count on failure, unlike
        # assertTrue(count > 0).
        self.assertGreater(neighbor_count, 0)

    def test_build_graph_data_from_patches_conversion(self):
        """Conversion yields one data object per batch item with features and edges."""
        feature_map = torch.randn(2, 16, 32, 32)

        graph_batch, patches = build_graph_from_patches(
            feature_map, patch_size=(4, 4)
        )
        data_list = build_graph_data_from_patches(graph_batch, patches)

        # One entry per sample in the batch of 2.
        self.assertEqual(len(data_list), 2)

        first = data_list[0]

        # Each node feature row should hold a flattened patch: 16 * 4 * 4 values.
        self.assertEqual(first.x.shape[1], 16 * 4 * 4)

        # The second dimension of edge_index must be non-empty (presumably the
        # edge count in a (2, num_edges) layout -- TODO confirm).
        self.assertGreater(first.edge_index.shape[1], 0)
|
|
|
if __name__ == '__main__':
    # Run every TestCase in this module when the file is executed as a script.
    unittest.main()
|
|