File size: 2,360 Bytes
c7995e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np

import networkx as nx
from networkx.utils import UnionFind

from typing import Optional
import torch
from torch import Tensor

from torch_sparse import SparseTensor
from scipy.sparse import csr_matrix
from math import pi as PI
import torch.nn.functional as F
def unique(sequence):
    """Return the items of *sequence* with duplicates removed, keeping first-seen order.

    Items must be hashable; equality (and hashing) decides what counts as a
    duplicate, and the first occurrence of each value is the one kept.
    """
    # dicts preserve insertion order and keep the first key object on re-insert.
    return list(dict.fromkeys(sequence))
def pos2key(pos):
    """Build a string key from the first two coordinates of *pos*.

    *pos* is flattened with ``reshape(-1)``; its first two entries are each
    rendered as zero-padded, width-8, 4-decimal fixed-point numbers and joined
    with an underscore, e.g. ``"001.5000_002.2500"``.
    """
    flat = pos.reshape(-1)
    return f"{flat[0]:08.4f}_{flat[1]:08.4f}"
def get_angle(v1: Tensor, v2: Tensor):
    """Return the row-wise angle (radians, in ``[0, pi]``) between *v1* and *v2*.

    Both inputs are indexed as ``(N, D)``. An input whose second dimension is 2
    is zero-padded to 3 components so the 3-D cross product can be used. The
    angle is ``atan2(|v1 x v2|, v1 . v2)``.
    """
    def _lift_to_3d(vec):
        # Append a zero z-component to planar vectors; leave others unchanged.
        return F.pad(vec, (0, 1)) if vec.shape[1] == 2 else vec

    a = _lift_to_3d(v1)
    b = _lift_to_3d(v2)
    sin_part = torch.cross(a, b, dim=1).norm(p=2, dim=1)
    cos_part = (a * b).sum(dim=1)
    return torch.atan2(sin_part, cos_part)
class GaussianSmearing(torch.nn.Module):
    """Expand scalar values onto a bank of evenly spaced Gaussian basis functions.

    ``num_gaussians`` centres are placed uniformly on ``[start, stop]`` (by
    default ``-PI`` to ``PI``). All Gaussians share one width, derived from the
    spacing between neighbouring centres.
    """

    def __init__(self, start=-PI, stop=PI, num_gaussians=12):
        super().__init__()
        centres = torch.linspace(start, stop, num_gaussians)
        spacing = (centres[1] - centres[0]).item()
        # Exponent factor for exp(coeff * d**2), i.e. coeff = -1 / (2 * spacing**2).
        self.coeff = -0.5 / spacing ** 2
        # Buffer (not a parameter): moves with the module and is not trained.
        self.register_buffer("offset", centres)

    def forward(self, dist):
        """Return ``exp(coeff * (dist - offset)**2)`` of shape ``(dist.numel(), num_gaussians)``."""
        column = dist.view(-1, 1)
        diff = column - self.offset.view(1, -1)
        return torch.exp(self.coeff * diff.pow(2))

def triplets(edge_index, num_nodes):
    """Enumerate two-hop edge chains ``i -> j -> k`` over a directed edge list.

    For each edge ``e1`` with ``row[e1] = i`` and ``col[e1] = j``, every edge
    ``e2`` starting at ``j`` (``j -> k``) forms a candidate triplet. A candidate
    is dropped when any of the following holds:
      * the chain backtracks (``k == i`` while ``i != j``);
      * only the first edge is a self-loop (``i == j != k``);
      * only the second edge is a self-loop (``j == k != i``).
    A chain of two self-loops (``i == j == k``) is kept.

    Args:
        edge_index: tensor unpacked along its first dimension into ``row``
            (sources ``i``) and ``col`` (targets ``j``), one entry per edge.
        num_nodes: number of nodes; the adjacency is ``(num_nodes, num_nodes)``.

    Returns:
        ``(idx, num_triplets_real, edx_1st, edx_2nd)`` where
          * ``idx`` is ``torch.stack([idx_i, idx_j, idx_k])`` for the kept
            triplets, grouped by first edge in edge order;
          * ``num_triplets_real`` is a long tensor with one entry per edge.
            Entry ``e`` is the cumulative (running) number of kept triplets
            whose first edge is one of edges ``0..e``. It is a prefix sum, not
            a per-edge count;
          * ``edx_1st`` is the edge id (position in ``edge_index``) of ``i -> j``;
          * ``edx_2nd`` is the edge id of ``j -> k``.
    """
    row, col = edge_index  # edge e goes row[e] (i) -> col[e] (j)

    # Use each edge's position as its sparse value so edge ids travel with the
    # entries through the SparseTensor storage and the row selection below.
    value = torch.arange(row.size(0), device=row.device)
    adj_t = SparseTensor(row=row, col=col, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    # Row e of adj_t_row holds the adjacency row of node col[e], i.e. the
    # edges j -> k that can continue edge e.
    adj_t_row = adj_t[col]
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    idx_i = row.repeat_interleave(num_triplets)
    idx_j = col.repeat_interleave(num_triplets)
    edx_1st = value.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    edx_2nd = adj_t_row.storage.value()

    backtrack = (idx_i == idx_k) & (idx_j != idx_i)
    first_self_loop = (idx_i == idx_j) & (idx_j != idx_k)
    second_self_loop = (idx_j == idx_k) & (idx_i != idx_k)
    mask = ~(backtrack | first_self_loop | second_self_loop)

    # Running total of kept triplets at the end of each edge's candidate
    # segment. kept_before[t] counts kept candidates among the first t
    # candidates. The leading zero keeps this correct when the first edges
    # have no candidates (an index of -1 would wrap to the last element) and
    # when there are no candidates at all.
    kept_before = torch.cat([
        mask.new_zeros(1, dtype=torch.long),
        torch.cumsum(mask.to(torch.long), dim=0),
    ])
    num_triplets_real = kept_before[torch.cumsum(num_triplets, dim=0)]

    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]
    edx_1st, edx_2nd = edx_1st[mask], edx_2nd[mask]

    return torch.stack([idx_i, idx_j, idx_k]), num_triplets_real, edx_1st, edx_2nd


if __name__ == '__main__':
    # Running this module directly performs no work; it only defines helpers.
    pass