danieldk HF Staff commited on
Commit
2913ead
·
1 Parent(s): 290dd18
build/torch-universal/kernels_test/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import layers
2
+
3
+ __all__ = ["layers"]
build/torch-universal/kernels_test/layers.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class LinearImplicitBackward(nn.Module):
6
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
7
+ return F.linear(input, self.weight, self.bias)
8
+
9
+
10
+ class LinearBackward(nn.Module):
11
+ has_backward = True
12
+
13
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
14
+ return F.linear(input, self.weight, self.bias)
15
+
16
+
17
+ class LinearNoBackward(nn.Module):
18
+ has_backward = False
19
+
20
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
21
+ return F.linear(input, self.weight, self.bias)
22
+
23
+
24
+ __all__ = ["LinearImplicitBackward", "LinearBackward", "LinearNoBackward"]