osbm committed on
Commit 38c83a0 · 1 Parent(s): ace5f83

Update layers.py

Files changed (1)
  1. layers.py +8 -178
layers.py CHANGED
@@ -1,13 +1,10 @@
 import torch
 import torch.nn as nn
-from torch.nn.modules.module import Module
 from torch.nn import functional as F
-from torch.nn import Embedding, ModuleList
-from torch_geometric.nn import PNAConv, global_add_pool, Set2Set, GraphMultisetTransformer
 import math

 class MLP(nn.Module):
-    def __init__(self, act, in_feat, hid_feat=None, out_feat=None,
+    def __init__(self, in_feat, hid_feat=None, out_feat=None,
                  dropout=0.):
         super().__init__()
         if not hid_feat:
@@ -26,7 +23,7 @@ class MLP(nn.Module):
         return self.droprateout(x)

 class Attention_new(nn.Module):
-    def __init__(self, dim, heads, act, attention_dropout=0., proj_dropout=0.):
+    def __init__(self, dim, heads, attention_dropout=0.):
         super().__init__()
         assert dim % heads == 0
         self.heads = heads
@@ -79,15 +76,15 @@ class Attention_new(nn.Module):
         return node, edge

 class Encoder_Block(nn.Module):
-    def __init__(self, dim, heads, act, mlp_ratio=4, drop_rate=0., ):
+    def __init__(self, dim, heads, act, mlp_ratio=4, drop_rate=0.):
         super().__init__()
         self.ln1 = nn.LayerNorm(dim)

-        self.attn = Attention_new(dim, heads, act, drop_rate, drop_rate)
+        self.attn = Attention_new(dim, heads, drop_rate)
         self.ln3 = nn.LayerNorm(dim)
         self.ln4 = nn.LayerNorm(dim)
-        self.mlp = MLP(act, dim, dim*mlp_ratio, dim, dropout=drop_rate)
-        self.mlp2 = MLP(act, dim, dim*mlp_ratio, dim, dropout=drop_rate)
+        self.mlp = MLP(dim, dim*mlp_ratio, dim, dropout=drop_rate)
+        self.mlp2 = MLP(dim, dim*mlp_ratio, dim, dropout=drop_rate)
         self.ln5 = nn.LayerNorm(dim)
         self.ln6 = nn.LayerNorm(dim)

@@ -199,7 +196,7 @@ class Decoder_Block(nn.Module):
         self.ln1_mx = nn.LayerNorm(dim)
         self.ln1_px = nn.LayerNorm(dim)

-        self.attn2 = Attention_new(dim, heads, drop_rate, drop_rate)
+        self.attn2 = Attention_new(dim, heads, drop_rate)

         self.ln2_pa = nn.LayerNorm(dim)
         self.ln2_px = nn.LayerNorm(dim)
@@ -265,171 +262,4 @@ class TransformerDecoder(nn.Module):
             mol_annot, prot_annot, mol_adj, prot_adj = Decoder_Block(mol_annot, prot_annot, mol_adj, prot_adj)

         return mol_annot, prot_annot, mol_adj, prot_adj
-
-
-
-"""class PNA(torch.nn.Module):
-    def __init__(self, deg, agg, sca, pna_in_ch, pna_out_ch, edge_dim, towers, pre_lay, post_lay, pna_layer_num, graph_add):
-        super(PNA, self).__init__()
-
-        self.node_emb = Embedding(30, pna_in_ch)
-        self.edge_emb = Embedding(30, edge_dim)
-        degree = deg
-        aggregators = agg.split(",")  # ["max"]  # 'sum', 'min', 'max', 'std', 'var', 'mean'  ## try changing these.
-        scalers = sca.split(",")  # ['amplification', 'attenuation']  # 'amplification', 'attenuation', 'linear', 'inverse_linear', 'identity'
-        self.graph_add = graph_add
-        self.convs = ModuleList()
-        self.batch_norms = ModuleList()
-
-        for _ in range(pna_layer_num):  ##### make the number of layers a hyperparameter??
-            conv = PNAConv(in_channels=pna_in_ch, out_channels=pna_out_ch,
-                           aggregators=aggregators, scalers=scalers, deg=degree,
-                           edge_dim=edge_dim, towers=towers, pre_layers=pre_lay, post_layers=post_lay,  ## try changing the number of towers, default - 1
-                           divide_input=True)
-            self.convs.append(conv)
-            self.batch_norms.append(nn.LayerNorm(pna_out_ch))
-
-        #self.graph_multitrans = GraphMultisetTransformer(in_channels=pna_out_ch, hidden_channels=200,
-        #                                                 out_channels=pna_out_ch, layer_norm=True)
-        if self.graph_add == "set2set":
-            self.s2s = Set2Set(in_channels=pna_out_ch, processing_steps=1, num_layers=1)
-
-        if self.graph_add == "set2set":
-            pna_out_ch = pna_out_ch*2
-        self.mlp = nn.Sequential(nn.Linear(pna_out_ch, pna_out_ch), nn.Tanh(), nn.Linear(pna_out_ch, 25), nn.Tanh(), nn.Linear(25, 1))
-
-    def forward(self, x, edge_index, edge_attr, batch):
-
-        x = self.node_emb(x.squeeze())
-
-        edge_attr = self.edge_emb(edge_attr)
-
-        for conv, batch_norm in zip(self.convs, self.batch_norms):
-            x = F.relu(batch_norm(conv(x, edge_index, edge_attr)))
-
-        if self.graph_add == "global_add":
-            x = global_add_pool(x, batch.squeeze())
-
-        elif self.graph_add == "set2set":
-
-            x = self.s2s(x, batch.squeeze())
-        #elif self.graph_add == "graph_multitrans":
-        #    x = self.graph_multitrans(x, batch.squeeze(), edge_index)
-        x = self.mlp(x)
-
-        return x"""
-
-
-
-
-"""class GraphConvolution(nn.Module):
-
-    def __init__(self, in_features, out_feature_list, b_dim, dropout, gcn_depth):
-        super(GraphConvolution, self).__init__()
-        self.in_features = in_features
-
-        self.gcn_depth = gcn_depth
-
-        self.out_feature_list = out_feature_list
-
-        self.gcn_in = nn.Sequential(nn.Linear(in_features, out_feature_list[0]), nn.Tanh(),
-                                    nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Tanh(),
-                                    nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Dropout(dropout))
-
-        self.gcn_convs = nn.ModuleList()
-
-        for _ in range(gcn_depth):
-
-            gcn_conv = nn.Sequential(nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Tanh(),
-                                     nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Tanh(),
-                                     nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Dropout(dropout))
-
-            self.gcn_convs.append(gcn_conv)
-
-        self.gcn_out = nn.Sequential(nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Tanh(),
-                                     nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Tanh(),
-                                     nn.Linear(out_feature_list[0], out_feature_list[1]), nn.Dropout(dropout))
-
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, input, adj, activation=None):
-        # input : 16x9x9
-        # adj : 16x4x9x9
-        hidden = torch.stack([self.gcn_in(input) for _ in range(adj.size(1))], 1)
-        hidden = torch.einsum('bijk,bikl->bijl', (adj, hidden))
-
-        hidden = torch.sum(hidden, 1) + self.gcn_in(input)
-        hidden = activation(hidden) if activation is not None else hidden
-
-        for gcn_conv in self.gcn_convs:
-            hidden1 = torch.stack([gcn_conv(hidden) for _ in range(adj.size(1))], 1)
-            hidden1 = torch.einsum('bijk,bikl->bijl', (adj, hidden1))
-            hidden = torch.sum(hidden1, 1) + gcn_conv(hidden)
-            hidden = activation(hidden) if activation is not None else hidden
-
-        output = torch.stack([self.gcn_out(hidden) for _ in range(adj.size(1))], 1)
-        output = torch.einsum('bijk,bikl->bijl', (adj, output))
-        output = torch.sum(output, 1) + self.gcn_out(hidden)
-        output = activation(output) if activation is not None else output
-
-
-        return output
-
-
-class GraphAggregation(Module):
-
-    def __init__(self, in_features, out_features, m_dim, dropout):
-        super(GraphAggregation, self).__init__()
-        self.sigmoid_linear = nn.Sequential(nn.Linear(in_features+m_dim, out_features), nn.Sigmoid())
-        self.tanh_linear = nn.Sequential(nn.Linear(in_features+m_dim, out_features), nn.Tanh())
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, input, activation):
-        i = self.sigmoid_linear(input)
-        j = self.tanh_linear(input)
-        output = torch.sum(torch.mul(i, j), 1)
-        output = activation(output) if activation is not None\
-            else output
-        output = self.dropout(output)
-
-        return output"""
-
-"""class Attention(nn.Module):
-    def __init__(self, dim, heads=4, attention_dropout=0., proj_dropout=0.):
-        super().__init__()
-        self.heads = heads
-        self.scale = 1./dim**0.5
-        #self.scale = torch.div(1, torch.pow(dim, 0.5))  #1./torch.pow(dim, 0.5)  #dim**0.5  torch.div(x, 0.5)
-
-        self.qkv = nn.Linear(dim, dim*3, bias=False)
-
-        self.attention_dropout = nn.Dropout(attention_dropout)
-        self.out = nn.Sequential(
-            nn.Linear(dim, dim),
-            nn.Dropout(proj_dropout)
-        )
-        #self.noise_strength_1 = torch.nn.Parameter(torch.zeros([]))
-
-    def forward(self, x):
-        b, n, c = x.shape
-
-        #x = x + torch.randn([x.size(0), x.size(1), 1], device=x.device) * self.noise_strength_1
-
-        qkv = self.qkv(x).reshape(b, n, 3, self.heads, c//self.heads)
-
-        q, k, v = qkv.permute(2, 0, 3, 1, 4)
-
-        dot = (q @ k.transpose(-2, -1)) * self.scale
-
-        attn = dot.softmax(dim=-1)
-        attn = self.attention_dropout(attn)
-
-
-        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
-
-        x = self.out(x)
-
-        return x, attn"""
-
-
-
+
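
For context on the narrowed constructor, below is a minimal, self-contained sketch of an MLP matching the post-commit signature MLP(in_feat, hid_feat=None, out_feat=None, dropout=0.). Only the signature, the `if not hid_feat` guard, and the final `return self.droprateout(x)` line appear in the diff above; the fallback defaults, the GELU activation, and the `fc1`/`fc2` layer names are assumptions for illustration, not the repository's actual implementation.

import torch
import torch.nn as nn

class MLP(nn.Module):
    # Hypothetical reconstruction: only the signature, the hid_feat guard,
    # and the final droprateout call are taken from the diff above.
    def __init__(self, in_feat, hid_feat=None, out_feat=None, dropout=0.):
        super().__init__()
        if not hid_feat:
            hid_feat = in_feat   # assumed fallback, not shown in the diff
        if not out_feat:
            out_feat = in_feat   # assumed fallback, not shown in the diff
        self.fc1 = nn.Linear(in_feat, hid_feat)
        self.act = nn.GELU()     # assumed fixed activation replacing the removed `act` argument
        self.fc2 = nn.Linear(hid_feat, out_feat)
        self.droprateout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return self.droprateout(x)   # matches the diff's last MLP line

# Mirrors how Encoder_Block now builds its feed-forward sub-blocks:
#   self.mlp = MLP(dim, dim*mlp_ratio, dim, dropout=drop_rate)
mlp = MLP(64, 64 * 4, 64, dropout=0.1)
out = mlp(torch.randn(2, 9, 64))     # -> torch.Size([2, 9, 64])

The attention constructor follows the same pattern after this commit: blocks call Attention_new(dim, heads, drop_rate) rather than passing a separate activation and projection-dropout argument.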