darabos commited on
Commit
1db4274
·
1 Parent(s): 09b8d25

High level boxes for Neural ODE + GNN demo.

Browse files
examples/ODE-GNN ADDED
@@ -0,0 +1,771 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "edges": [
3
+ {
4
+ "id": "Input: embedding 1 Graph conv 1",
5
+ "source": "Input: embedding 1",
6
+ "sourceHandle": "x",
7
+ "target": "Graph conv 1",
8
+ "targetHandle": "x"
9
+ },
10
+ {
11
+ "id": "Input: graph edges 1 Graph conv 1",
12
+ "source": "Input: graph edges 1",
13
+ "sourceHandle": "edges",
14
+ "target": "Graph conv 1",
15
+ "targetHandle": "edges"
16
+ },
17
+ {
18
+ "id": "Graph conv 1 Activation 1",
19
+ "source": "Graph conv 1",
20
+ "sourceHandle": "x",
21
+ "target": "Activation 1",
22
+ "targetHandle": "x"
23
+ },
24
+ {
25
+ "id": "Activation 1 Repeat 1",
26
+ "source": "Activation 1",
27
+ "sourceHandle": "x",
28
+ "target": "Repeat 1",
29
+ "targetHandle": "input"
30
+ },
31
+ {
32
+ "id": "Repeat 1 Graph conv 1",
33
+ "source": "Repeat 1",
34
+ "sourceHandle": "output",
35
+ "target": "Graph conv 1",
36
+ "targetHandle": "x"
37
+ },
38
+ {
39
+ "id": "Input: sequential 1 LSTM 1",
40
+ "source": "Input: sequential 1",
41
+ "sourceHandle": "y",
42
+ "target": "LSTM 1",
43
+ "targetHandle": "x"
44
+ },
45
+ {
46
+ "id": "Input: zeros 1 LSTM 1",
47
+ "source": "Input: zeros 1",
48
+ "sourceHandle": "x",
49
+ "target": "LSTM 1",
50
+ "targetHandle": "h"
51
+ },
52
+ {
53
+ "id": "Recurrent chain 1 LSTM 1",
54
+ "source": "Recurrent chain 1",
55
+ "sourceHandle": "output",
56
+ "target": "LSTM 1",
57
+ "targetHandle": "h"
58
+ },
59
+ {
60
+ "id": "LSTM 1 Recurrent chain 1",
61
+ "source": "LSTM 1",
62
+ "sourceHandle": "h",
63
+ "target": "Recurrent chain 1",
64
+ "targetHandle": "input"
65
+ },
66
+ {
67
+ "id": "Activation 1 Concatenate 1",
68
+ "source": "Activation 1",
69
+ "sourceHandle": "x",
70
+ "target": "Concatenate 1",
71
+ "targetHandle": "a"
72
+ },
73
+ {
74
+ "id": "LSTM 1 Concatenate 1",
75
+ "source": "LSTM 1",
76
+ "sourceHandle": "x",
77
+ "target": "Concatenate 1",
78
+ "targetHandle": "b"
79
+ },
80
+ {
81
+ "id": "Concatenate 1 Neural ODE 1",
82
+ "source": "Concatenate 1",
83
+ "sourceHandle": "x",
84
+ "target": "Neural ODE 1",
85
+ "targetHandle": "x"
86
+ },
87
+ {
88
+ "id": "Neural ODE 1 MSE loss 1",
89
+ "source": "Neural ODE 1",
90
+ "sourceHandle": "x",
91
+ "target": "MSE loss 1",
92
+ "targetHandle": "x"
93
+ },
94
+ {
95
+ "id": "Input: label 1 MSE loss 1",
96
+ "source": "Input: label 1",
97
+ "sourceHandle": "y",
98
+ "target": "MSE loss 1",
99
+ "targetHandle": "y"
100
+ },
101
+ {
102
+ "id": "MSE loss 1 Optimizer 1",
103
+ "source": "MSE loss 1",
104
+ "sourceHandle": "loss",
105
+ "target": "Optimizer 1",
106
+ "targetHandle": "loss"
107
+ }
108
+ ],
109
+ "env": "PyTorch model",
110
+ "nodes": [
111
+ {
112
+ "data": {
113
+ "display": null,
114
+ "error": null,
115
+ "meta": {
116
+ "inputs": {
117
+ "edges": {
118
+ "name": "edges",
119
+ "position": "bottom",
120
+ "type": {
121
+ "type": "tensor"
122
+ }
123
+ },
124
+ "x": {
125
+ "name": "x",
126
+ "position": "bottom",
127
+ "type": {
128
+ "type": "tensor"
129
+ }
130
+ }
131
+ },
132
+ "name": "Graph conv",
133
+ "outputs": {
134
+ "x": {
135
+ "name": "x",
136
+ "position": "top",
137
+ "type": {
138
+ "type": "tensor"
139
+ }
140
+ }
141
+ },
142
+ "params": {
143
+ "type": {
144
+ "default": "1",
145
+ "name": "type",
146
+ "type": {
147
+ "enum": [
148
+ "GCNConv",
149
+ "GATConv",
150
+ "GATv2Conv",
151
+ "SAGEConv"
152
+ ]
153
+ }
154
+ }
155
+ },
156
+ "type": "basic"
157
+ },
158
+ "params": {
159
+ "type": 1.0
160
+ },
161
+ "status": "planned",
162
+ "title": "Graph conv"
163
+ },
164
+ "dragHandle": ".bg-primary",
165
+ "height": 200.0,
166
+ "id": "Graph conv 1",
167
+ "position": {
168
+ "x": 360.0,
169
+ "y": 195.0
170
+ },
171
+ "type": "basic",
172
+ "width": 200.0
173
+ },
174
+ {
175
+ "data": {
176
+ "display": null,
177
+ "error": null,
178
+ "meta": {
179
+ "inputs": {
180
+ "input": {
181
+ "name": "input",
182
+ "position": "top",
183
+ "type": {
184
+ "type": "tensor"
185
+ }
186
+ }
187
+ },
188
+ "name": "Repeat",
189
+ "outputs": {
190
+ "output": {
191
+ "name": "output",
192
+ "position": "bottom",
193
+ "type": {
194
+ "type": "tensor"
195
+ }
196
+ }
197
+ },
198
+ "params": {
199
+ "times": {
200
+ "default": 1.0,
201
+ "name": "times",
202
+ "type": {
203
+ "type": "<class 'int'>"
204
+ }
205
+ }
206
+ },
207
+ "type": "basic"
208
+ },
209
+ "params": {
210
+ "times": 1.0
211
+ },
212
+ "status": "planned",
213
+ "title": "Repeat"
214
+ },
215
+ "dragHandle": ".bg-primary",
216
+ "height": 200.0,
217
+ "id": "Repeat 1",
218
+ "position": {
219
+ "x": -94.15168677219138,
220
+ "y": 14.525356969883305
221
+ },
222
+ "type": "basic",
223
+ "width": 200.0
224
+ },
225
+ {
226
+ "data": {
227
+ "__execution_delay": null,
228
+ "collapsed": true,
229
+ "display": null,
230
+ "error": null,
231
+ "meta": {
232
+ "inputs": {
233
+ "a": {
234
+ "name": "a",
235
+ "position": "bottom",
236
+ "type": {
237
+ "type": "tensor"
238
+ }
239
+ },
240
+ "b": {
241
+ "name": "b",
242
+ "position": "bottom",
243
+ "type": {
244
+ "type": "tensor"
245
+ }
246
+ }
247
+ },
248
+ "name": "Concatenate",
249
+ "outputs": {
250
+ "x": {
251
+ "name": "x",
252
+ "position": "top",
253
+ "type": {
254
+ "type": "tensor"
255
+ }
256
+ }
257
+ },
258
+ "params": {},
259
+ "type": "basic"
260
+ },
261
+ "params": {},
262
+ "status": "planned",
263
+ "title": "Concatenate"
264
+ },
265
+ "dragHandle": ".bg-primary",
266
+ "height": 200.0,
267
+ "id": "Concatenate 1",
268
+ "position": {
269
+ "x": 477.88148637482334,
270
+ "y": -372.62774030487003
271
+ },
272
+ "type": "basic",
273
+ "width": 200.0
274
+ },
275
+ {
276
+ "data": {
277
+ "__execution_delay": null,
278
+ "collapsed": true,
279
+ "display": null,
280
+ "error": null,
281
+ "meta": {
282
+ "inputs": {},
283
+ "name": "Input: graph edges",
284
+ "outputs": {
285
+ "edges": {
286
+ "name": "edges",
287
+ "position": "top",
288
+ "type": {
289
+ "type": "tensor"
290
+ }
291
+ }
292
+ },
293
+ "params": {},
294
+ "type": "basic"
295
+ },
296
+ "params": {},
297
+ "status": "planned",
298
+ "title": "Input: graph edges"
299
+ },
300
+ "dragHandle": ".bg-primary",
301
+ "height": 200.0,
302
+ "id": "Input: graph edges 1",
303
+ "position": {
304
+ "x": 515.6535517374441,
305
+ "y": 545.4709559884296
306
+ },
307
+ "type": "basic",
308
+ "width": 200.0
309
+ },
310
+ {
311
+ "data": {
312
+ "__execution_delay": null,
313
+ "collapsed": true,
314
+ "display": null,
315
+ "error": null,
316
+ "meta": {
317
+ "inputs": {},
318
+ "name": "Input: embedding",
319
+ "outputs": {
320
+ "x": {
321
+ "name": "x",
322
+ "position": "top",
323
+ "type": {
324
+ "type": "tensor"
325
+ }
326
+ }
327
+ },
328
+ "params": {},
329
+ "type": "basic"
330
+ },
331
+ "params": {},
332
+ "status": "planned",
333
+ "title": "Input: embedding"
334
+ },
335
+ "dragHandle": ".bg-primary",
336
+ "height": 200.0,
337
+ "id": "Input: embedding 1",
338
+ "position": {
339
+ "x": 246.6527948448857,
340
+ "y": 551.6313504198322
341
+ },
342
+ "type": "basic",
343
+ "width": 200.0
344
+ },
345
+ {
346
+ "data": {
347
+ "display": null,
348
+ "error": null,
349
+ "meta": {
350
+ "inputs": {
351
+ "x": {
352
+ "name": "x",
353
+ "position": "bottom",
354
+ "type": {
355
+ "type": "tensor"
356
+ }
357
+ }
358
+ },
359
+ "name": "Activation",
360
+ "outputs": {
361
+ "x": {
362
+ "name": "x",
363
+ "position": "top",
364
+ "type": {
365
+ "type": "tensor"
366
+ }
367
+ }
368
+ },
369
+ "params": {
370
+ "type": {
371
+ "default": "1",
372
+ "name": "type",
373
+ "type": {
374
+ "enum": [
375
+ "ReLU",
376
+ "LeakyReLU",
377
+ "Tanh",
378
+ "Mish"
379
+ ]
380
+ }
381
+ }
382
+ },
383
+ "type": "basic"
384
+ },
385
+ "params": {
386
+ "type": 1.0
387
+ },
388
+ "status": "planned",
389
+ "title": "Activation"
390
+ },
391
+ "dragHandle": ".bg-primary",
392
+ "height": 200.0,
393
+ "id": "Activation 1",
394
+ "position": {
395
+ "x": 354.3731834561054,
396
+ "y": -73.74768512965228
397
+ },
398
+ "type": "basic",
399
+ "width": 200.0
400
+ },
401
+ {
402
+ "data": {
403
+ "__execution_delay": null,
404
+ "collapsed": true,
405
+ "display": null,
406
+ "error": null,
407
+ "meta": {
408
+ "inputs": {
409
+ "h": {
410
+ "name": "h",
411
+ "position": "bottom",
412
+ "type": {
413
+ "type": "tensor"
414
+ }
415
+ },
416
+ "x": {
417
+ "name": "x",
418
+ "position": "bottom",
419
+ "type": {
420
+ "type": "tensor"
421
+ }
422
+ }
423
+ },
424
+ "name": "LSTM",
425
+ "outputs": {
426
+ "h": {
427
+ "name": "h",
428
+ "position": "top",
429
+ "type": {
430
+ "type": "tensor"
431
+ }
432
+ },
433
+ "x": {
434
+ "name": "x",
435
+ "position": "top",
436
+ "type": {
437
+ "type": "tensor"
438
+ }
439
+ }
440
+ },
441
+ "params": {},
442
+ "type": "basic"
443
+ },
444
+ "params": {},
445
+ "status": "planned",
446
+ "title": "LSTM"
447
+ },
448
+ "dragHandle": ".bg-primary",
449
+ "height": 200.0,
450
+ "id": "LSTM 1",
451
+ "position": {
452
+ "x": 960.0,
453
+ "y": 135.0
454
+ },
455
+ "type": "basic",
456
+ "width": 200.0
457
+ },
458
+ {
459
+ "data": {
460
+ "__execution_delay": null,
461
+ "collapsed": true,
462
+ "display": null,
463
+ "error": null,
464
+ "meta": {
465
+ "inputs": {},
466
+ "name": "Input: sequential",
467
+ "outputs": {
468
+ "y": {
469
+ "name": "y",
470
+ "position": "top",
471
+ "type": {
472
+ "type": "tensor"
473
+ }
474
+ }
475
+ },
476
+ "params": {},
477
+ "type": "basic"
478
+ },
479
+ "params": {},
480
+ "status": "planned",
481
+ "title": "Input: sequential"
482
+ },
483
+ "dragHandle": ".bg-primary",
484
+ "height": 200.0,
485
+ "id": "Input: sequential 1",
486
+ "position": {
487
+ "x": 1005.0,
488
+ "y": 510.0
489
+ },
490
+ "type": "basic",
491
+ "width": 200.0
492
+ },
493
+ {
494
+ "data": {
495
+ "__execution_delay": null,
496
+ "collapsed": true,
497
+ "display": null,
498
+ "error": null,
499
+ "meta": {
500
+ "inputs": {},
501
+ "name": "Input: zeros",
502
+ "outputs": {
503
+ "x": {
504
+ "name": "x",
505
+ "position": "top",
506
+ "type": {
507
+ "type": "tensor"
508
+ }
509
+ }
510
+ },
511
+ "params": {},
512
+ "type": "basic"
513
+ },
514
+ "params": {},
515
+ "status": "planned",
516
+ "title": "Input: zeros"
517
+ },
518
+ "dragHandle": ".bg-primary",
519
+ "height": 200.0,
520
+ "id": "Input: zeros 1",
521
+ "position": {
522
+ "x": 1290.0,
523
+ "y": 405.0
524
+ },
525
+ "type": "basic",
526
+ "width": 200.0
527
+ },
528
+ {
529
+ "data": {
530
+ "__execution_delay": null,
531
+ "collapsed": true,
532
+ "display": null,
533
+ "error": null,
534
+ "meta": {
535
+ "inputs": {
536
+ "input": {
537
+ "name": "input",
538
+ "position": "top",
539
+ "type": {
540
+ "type": "tensor"
541
+ }
542
+ }
543
+ },
544
+ "name": "Recurrent chain",
545
+ "outputs": {
546
+ "output": {
547
+ "name": "output",
548
+ "position": "bottom",
549
+ "type": {
550
+ "type": "tensor"
551
+ }
552
+ }
553
+ },
554
+ "params": {},
555
+ "type": "basic"
556
+ },
557
+ "params": {},
558
+ "status": "planned",
559
+ "title": "Recurrent chain"
560
+ },
561
+ "dragHandle": ".bg-primary",
562
+ "height": 200.0,
563
+ "id": "Recurrent chain 1",
564
+ "position": {
565
+ "x": 1224.6603040746108,
566
+ "y": 135.44839862151363
567
+ },
568
+ "type": "basic",
569
+ "width": 200.0
570
+ },
571
+ {
572
+ "data": {
573
+ "__execution_delay": null,
574
+ "collapsed": true,
575
+ "display": null,
576
+ "error": null,
577
+ "meta": {
578
+ "inputs": {
579
+ "x": {
580
+ "name": "x",
581
+ "position": "bottom",
582
+ "type": {
583
+ "type": "tensor"
584
+ }
585
+ }
586
+ },
587
+ "name": "Neural ODE",
588
+ "outputs": {
589
+ "x": {
590
+ "name": "x",
591
+ "position": "top",
592
+ "type": {
593
+ "type": "tensor"
594
+ }
595
+ }
596
+ },
597
+ "params": {},
598
+ "type": "basic"
599
+ },
600
+ "params": {},
601
+ "status": "planned",
602
+ "title": "Neural ODE"
603
+ },
604
+ "dragHandle": ".bg-primary",
605
+ "height": 200.0,
606
+ "id": "Neural ODE 1",
607
+ "position": {
608
+ "x": 475.11029083619064,
609
+ "y": -633.1862788850791
610
+ },
611
+ "type": "basic",
612
+ "width": 200.0
613
+ },
614
+ {
615
+ "data": {
616
+ "__execution_delay": null,
617
+ "collapsed": true,
618
+ "display": null,
619
+ "error": null,
620
+ "meta": {
621
+ "inputs": {
622
+ "x": {
623
+ "name": "x",
624
+ "position": "bottom",
625
+ "type": {
626
+ "type": "tensor"
627
+ }
628
+ },
629
+ "y": {
630
+ "name": "y",
631
+ "position": "bottom",
632
+ "type": {
633
+ "type": "tensor"
634
+ }
635
+ }
636
+ },
637
+ "name": "MSE loss",
638
+ "outputs": {
639
+ "loss": {
640
+ "name": "loss",
641
+ "position": "top",
642
+ "type": {
643
+ "type": "tensor"
644
+ }
645
+ }
646
+ },
647
+ "params": {},
648
+ "position": {
649
+ "x": 783.0,
650
+ "y": 101.0
651
+ },
652
+ "type": "basic"
653
+ },
654
+ "params": {},
655
+ "status": "planned",
656
+ "title": "MSE loss"
657
+ },
658
+ "dragHandle": ".bg-primary",
659
+ "height": 200.0,
660
+ "id": "MSE loss 1",
661
+ "position": {
662
+ "x": 915.0,
663
+ "y": -900.0
664
+ },
665
+ "type": "basic",
666
+ "width": 200.0
667
+ },
668
+ {
669
+ "data": {
670
+ "__execution_delay": null,
671
+ "collapsed": true,
672
+ "display": null,
673
+ "error": null,
674
+ "meta": {
675
+ "inputs": {},
676
+ "name": "Input: label",
677
+ "outputs": {
678
+ "y": {
679
+ "name": "y",
680
+ "position": "top",
681
+ "type": {
682
+ "type": "tensor"
683
+ }
684
+ }
685
+ },
686
+ "params": {},
687
+ "position": {
688
+ "x": 893.0,
689
+ "y": 369.0
690
+ },
691
+ "type": "basic"
692
+ },
693
+ "params": {},
694
+ "status": "planned",
695
+ "title": "Input: label"
696
+ },
697
+ "dragHandle": ".bg-primary",
698
+ "height": 200.0,
699
+ "id": "Input: label 1",
700
+ "position": {
701
+ "x": 1095.0,
702
+ "y": -450.0
703
+ },
704
+ "type": "basic",
705
+ "width": 200.0
706
+ },
707
+ {
708
+ "data": {
709
+ "display": null,
710
+ "error": null,
711
+ "meta": {
712
+ "inputs": {
713
+ "loss": {
714
+ "name": "loss",
715
+ "position": "bottom",
716
+ "type": {
717
+ "type": "tensor"
718
+ }
719
+ }
720
+ },
721
+ "name": "Optimizer",
722
+ "outputs": {},
723
+ "params": {
724
+ "lr": {
725
+ "default": 0.001,
726
+ "name": "lr",
727
+ "type": {
728
+ "type": "<class 'float'>"
729
+ }
730
+ },
731
+ "type": {
732
+ "default": 1.0,
733
+ "name": "type",
734
+ "type": {
735
+ "enum": [
736
+ "AdamW",
737
+ "Adafactor",
738
+ "Adagrad",
739
+ "SGD",
740
+ "Lion",
741
+ "Paged AdamW",
742
+ "Galore AdamW"
743
+ ]
744
+ }
745
+ }
746
+ },
747
+ "position": {
748
+ "x": 986.0,
749
+ "y": 165.0
750
+ },
751
+ "type": "basic"
752
+ },
753
+ "params": {
754
+ "lr": 0.001,
755
+ "type": 1.0
756
+ },
757
+ "status": "planned",
758
+ "title": "Optimizer"
759
+ },
760
+ "dragHandle": ".bg-primary",
761
+ "height": 247.0,
762
+ "id": "Optimizer 1",
763
+ "position": {
764
+ "x": 915.3430278730226,
765
+ "y": -1268.0577550022126
766
+ },
767
+ "type": "basic",
768
+ "width": 190.0
769
+ }
770
+ ]
771
+ }
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -29,7 +29,11 @@ reg("Input: graph edges", outputs=["edges"])
29
  reg("Input: label", outputs=["y"])
30
  reg("Input: positive sample", outputs=["x_pos"])
31
  reg("Input: negative sample", outputs=["x_neg"])
 
 
32
 
 
 
33
  reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"])
34
  reg("LayerNorm", inputs=["x"])
35
  reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
@@ -82,6 +86,14 @@ ops.register_passive_op(
82
  params=[ops.Parameter.basic("times", 1, int)],
83
  )
84
 
 
 
 
 
 
 
 
 
85
 
86
  def build_model(ws: workspace.Workspace, inputs: dict):
87
  """Builds the model described in the workspace."""
 
29
  reg("Input: label", outputs=["y"])
30
  reg("Input: positive sample", outputs=["x_pos"])
31
  reg("Input: negative sample", outputs=["x_neg"])
32
+ reg("Input: sequential", outputs=["y"])
33
+ reg("Input: zeros", outputs=["x"])
34
 
35
+ reg("LSTM", inputs=["x", "h"], outputs=["x", "h"])
36
+ reg("Neural ODE", inputs=["x"])
37
  reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"])
38
  reg("LayerNorm", inputs=["x"])
39
  reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
 
86
  params=[ops.Parameter.basic("times", 1, int)],
87
  )
88
 
89
+ ops.register_passive_op(
90
+ ENV,
91
+ "Recurrent chain",
92
+ inputs=[ops.Input(name="input", position="top", type="tensor")],
93
+ outputs=[ops.Output(name="output", position="bottom", type="tensor")],
94
+ params=[],
95
+ )
96
+
97
 
98
  def build_model(ws: workspace.Workspace, inputs: dict):
99
  """Builds the model described in the workspace."""