from __future__ import print_function
import os
import os.path
import pickle
import time
from typing import Any

import numpy as np
import torch
import torch.utils.data as data
from torch_geometric.data import Data
from torch_geometric.data.collate import collate
from torch_geometric.data.data import BaseData
from torch_geometric.data.storage import BaseStorage
from torch_geometric.loader import DataLoader
from torch_geometric.nn import knn_graph

def mycollate(data_list):
    r"""Collates a Python list of :obj:`torch_geometric.data.Data` objects
    into the internal storage format of
    :class:`~torch_geometric.data.InMemoryDataset`."""
    if len(data_list) == 1:
        # A single graph needs no slice dictionary.
        return data_list[0], None
    data, slices, _ = collate(
        data_list[0].__class__,
        data_list=data_list,
        increment=False,
        add_batch=False,
    )
    return data, slices

def myseparate(cls, batch: BaseData, idx: int, slice_dict: Any) -> BaseData:
    """Reconstruct the :obj:`idx`-th graph from a collated batch, using the
    slice dictionary produced by :func:`mycollate`."""
    data = cls().stores_as(batch)
    # Iterate over each storage object and separate all of its attributes:
    for batch_store, data_store in zip(batch.stores, data.stores):
        for attr in set(batch_store.keys()):
            slices = slice_dict[attr]
            data_store[attr] = _separate(attr, batch_store[attr], idx, slices,
                                         batch, batch_store)
    return data

def _separate(
    key: str,
    value: Any,
    idx: int,
    slices: Any,
    batch: BaseData,
    store: BaseStorage,
) -> Any:
    # Narrow a `torch.Tensor` along its concatenation dimension based on `slices`.
    key = str(key)
    cat_dim = batch.__cat_dim__(key, value, store)
    start, end = int(slices[idx]), int(slices[idx + 1])
    return value.narrow(cat_dim or 0, start, end - start)
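
# A minimal round-trip sketch (not part of the original pipeline): `mycollate`
# packs a list of graphs into one storage object plus a slice dictionary, and
# `myseparate` recovers an individual graph by index. The function name and
# toy graphs below are illustrative only.
def _demo_collate_roundtrip():
    g1 = Data(x=torch.randn(3, 2), edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]))
    g2 = Data(x=torch.randn(4, 2), edge_index=torch.tensor([[0, 1], [1, 0]]))
    batch, slices = mycollate([g1, g2])
    first = myseparate(cls=batch.__class__, batch=batch, idx=0, slice_dict=slices)
    # `first` should match `g1`: 3 nodes and 3 edges.
    print(first.num_nodes, first.edge_index.shape)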

def load_point(datasetname="south", k=5, small=(False, 50, 100)):
    """
    Load points and build graph pairs.

    If ``small[0]`` is True, load a cached subset of the small "south"
    dataset (k=5) from disk: the first ``small[1]`` high-fidelity pairs and
    the first ``small[2]`` low-fidelity pairs. Otherwise, build all pairs
    from scratch with :func:`process`.
    """
    print("loading")
    time1 = time.time()
    if small[0]:
        print("small south dataset k=5")
        datasetname = "south"
        k = 5
        filename = os.path.join("data", datasetname, datasetname + f'_{k}.pt')
        [data_graphs1, slices_graphs1, data_graphs2, slices_graphs2] = torch.load(filename)
        # Each pair is stored as two consecutive graphs, so 2 * small[i] graphs
        # are separated from the collated storage and then regrouped into pairs.
        flattened_list_graphs1 = [myseparate(cls=data_graphs1.__class__, batch=data_graphs1, idx=i, slice_dict=slices_graphs1) for i in range(small[1] * 2)]
        flattened_list_graphs2 = [myseparate(cls=data_graphs2.__class__, batch=data_graphs2, idx=i, slice_dict=slices_graphs2) for i in range(small[2] * 2)]
        unflattened_list_graphs1 = [flattened_list_graphs1[n:n + 2] for n in range(0, len(flattened_list_graphs1), 2)]
        unflattened_list_graphs2 = [flattened_list_graphs2[n:n + 2] for n in range(0, len(flattened_list_graphs2), 2)]
        print(f"Load data used {time.time() - time1:.1f} seconds")
        return unflattened_list_graphs1, unflattened_list_graphs2
    return process(datasetname, k)
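
# A hedged sketch of how the cached `.pt` file consumed above could be
# produced: collate each pair list with `mycollate` and save the four pieces
# that `load_point` expects. This writer is an assumption; the original
# project may build the cache elsewhere.
def _demo_build_cache(datasetname="south", k=5):
    graphs1, graphs2 = process(datasetname, k)
    flat1 = [g for pair in graphs1 for g in pair]  # pairs stored consecutively
    flat2 = [g for pair in graphs2 for g in pair]
    data_graphs1, slices_graphs1 = mycollate(flat1)
    data_graphs2, slices_graphs2 = mycollate(flat2)
    filename = os.path.join("data", datasetname, datasetname + f'_{k}.pt')
    torch.save([data_graphs1, slices_graphs1, data_graphs2, slices_graphs2], filename)
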
def process(datasetname="south", k=5):
    """
    Build graph pairs from the raw point data.
    """
    time1 = time.time()
    point_path = os.path.join("data", datasetname, datasetname + ".pkl")
    with open(point_path, 'rb') as f:
        data = pickle.load(f)
    graphs1 = []
    graphs2 = []
    for day in data:
        day_d1 = day[0]  # high-fidelity points
        day_d2 = day[1]  # low-fidelity points
        assert len(day_d1) < len(day_d2)
        pos1 = day_d1[:, -2:]  # the last two columns hold the coordinates
        edge_index1 = knn_graph(pos1, k=k)
        pos2 = day_d2[:, -2:]
        edge_index2 = knn_graph(pos2, k=k)
        # Iteratively mask one point of day_d1, the high-fidelity data, to
        # build high-fidelity graphs; they all share the same structure.
        for i in range(day_d1.shape[0]):
            day_d1_copy = day_d1.clone().detach()
            target = day_d1[i, 0]
            day_d1_copy[i, 0] = 0
            target_index = torch.tensor(i, dtype=torch.long)
            is_source = torch.ones(day_d1.shape[0], dtype=torch.bool)
            is_source[i] = False
            graph1 = Data(x=day_d1_copy, pos=pos1, edge_index=edge_index1,
                          target=target[None], target_index=target_index[None],
                          is_source=is_source,
                          datasource=torch.tensor(0, dtype=torch.long)[None])
            # Build the paired low-fidelity graph, which prepends the masked
            # point of day_d1, so its structure changes with every i.
            day_plus2 = torch.cat([day_d1_copy[i][None, :], day_d2])
            pos_plus2 = day_plus2[:, -2:]
            edge_index_plus2 = knn_graph(pos_plus2, k=k)
            is_source = torch.ones(day_d2.shape[0] + 1, dtype=torch.bool)
            is_source[0] = False
            graph2 = Data(x=day_plus2, pos=pos_plus2, edge_index=edge_index_plus2,
                          target=target[None],
                          target_index=torch.tensor(0, dtype=torch.long)[None],
                          is_source=is_source,
                          datasource=torch.tensor(0, dtype=torch.long)[None])
            graphs1.append([graph1, graph2])
        # Iteratively mask one point of day_d2, the low-fidelity data, to
        # build low-fidelity graphs; they all share the same structure.
        for i in range(day_d2.shape[0]):
            day_d2_copy = day_d2.clone().detach()
            target = day_d2[i, 0]
            day_d2_copy[i, 0] = 0
            target_index = torch.tensor(i, dtype=torch.long)
            is_source = torch.ones(day_d2.shape[0], dtype=torch.bool)
            is_source[i] = False
            graph2 = Data(x=day_d2_copy, pos=pos2, edge_index=edge_index2,
                          target=target[None], target_index=target_index[None],
                          is_source=is_source,
                          datasource=torch.tensor(1, dtype=torch.long)[None])
            # Build the paired high-fidelity graph, which prepends the masked
            # point of day_d2, so its structure changes with every i.
            day_plus1 = torch.cat([day_d2_copy[i][None, :], day_d1])
            pos_plus1 = day_plus1[:, -2:]
            edge_index_plus1 = knn_graph(pos_plus1, k=k)
            is_source = torch.ones(day_d1.shape[0] + 1, dtype=torch.bool)
            is_source[0] = False
            graph1 = Data(x=day_plus1, pos=pos_plus1, edge_index=edge_index_plus1,
                          target=target[None],
                          target_index=torch.tensor(0, dtype=torch.long)[None],
                          is_source=is_source,
                          datasource=torch.tensor(1, dtype=torch.long)[None])
            graphs2.append([graph1, graph2])
    np.random.shuffle(graphs1)
    np.random.shuffle(graphs2)
    print(f"Process data used {time.time() - time1:.1f} seconds")
    return [graphs1, graphs2]
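
# A hedged usage sketch: each element of `graphs1` is a [high-fidelity,
# low-fidelity] pair, and torch_geometric's DataLoader transposes a batch of
# such pairs into a pair of `Batch` objects. Assumes data/south/south.pkl
# exists; the function name is illustrative only.
def _demo_process_loader():
    graphs1, graphs2 = process(datasetname="south", k=5)
    loader = DataLoader(graphs1, batch_size=4, shuffle=True)
    batch_high, batch_low = next(iter(loader))
    print(batch_high.num_graphs, batch_low.num_graphs)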

class MergeNeighborDataset(torch.utils.data.Dataset):
    """Customized dataset for each domain."""

    def __init__(self, X):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx]

def kneighbor_point(datasetname="south", k=1, daily=False):
    """
    Build k-nearest-neighbor pairings between the two fidelity levels.

    Every high-fidelity point is paired with the mean of its k nearest
    low-fidelity neighbors, read from the precomputed ranking file. If
    ``daily`` is True, the samples are grouped per day; otherwise one flat
    list is returned.
    """
    ranking_path = os.path.join("data", datasetname, datasetname + "_ranking.pkl")
    with open(ranking_path, 'rb') as f:
        rankings = pickle.load(f)
    point_path = os.path.join("data", datasetname, datasetname + ".pkl")
    with open(point_path, 'rb') as f:
        days = pickle.load(f)
    samples = []
    for i in range(len(days)):
        day_d1 = days[i][0]  # high-fidelity points
        day_d2 = days[i][1]  # low-fidelity points
        ranking = rankings[i]
        # Iterate over the points of day_d1, the high-fidelity data, and pair
        # each one with the average of its k nearest low-fidelity neighbors.
        sample1 = []
        for j in range(day_d1.shape[0]):
            point1 = day_d1[j]
            point1_neighbors = day_d2[ranking[j, :k]]
            point1_neighbor = torch.mean(point1_neighbors, dim=0)
            sample1.append([point1, point1_neighbor])
        if daily:
            samples.append(sample1)
        else:
            samples.extend(sample1)
    if not daily:
        return [samples]
    return samples
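
# A hedged sketch tying the two pieces above together: wrap the
# (point, averaged-neighbor) pairs from `kneighbor_point` in
# MergeNeighborDataset and iterate with a plain torch DataLoader. Assumes
# data/south/south.pkl and data/south/south_ranking.pkl exist; the function
# name is illustrative only.
def _demo_kneighbor_loader():
    samples = kneighbor_point(datasetname="south", k=3)[0]
    dataset = MergeNeighborDataset(samples)
    loader = data.DataLoader(dataset, batch_size=8, shuffle=True)
    point, neighbor = next(iter(loader))  # each is a [batch, features] tensor
    print(point.shape, neighbor.shape)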

if __name__ == '__main__':
    pass