Spaces:
Runtime error
Runtime error
Update TabPFN/layer.py
Browse files- TabPFN/layer.py +10 -2
TabPFN/layer.py
CHANGED
|
@@ -1,8 +1,16 @@
|
|
| 1 |
from functools import partial
|
| 2 |
|
| 3 |
from torch import nn
|
| 4 |
-
from torch.nn.modules.transformer import
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from torch.utils.checkpoint import checkpoint
|
| 8 |
|
|
|
|
| 1 |
from functools import partial
|
| 2 |
|
| 3 |
from torch import nn
|
| 4 |
+
from torch.nn.modules.transformer import (
|
| 5 |
+
_get_activation_fn,
|
| 6 |
+
Module,
|
| 7 |
+
Tensor,
|
| 8 |
+
Optional,
|
| 9 |
+
MultiheadAttention,
|
| 10 |
+
Linear,
|
| 11 |
+
Dropout,
|
| 12 |
+
LayerNorm,
|
| 13 |
+
)
|
| 14 |
|
| 15 |
from torch.utils.checkpoint import checkpoint
|
| 16 |
|