File size: 465 Bytes
c8b5862
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch 
import torch.nn as nun

class SimpleModel(nun.Module):
    """A model consisting of a single fully connected (linear) layer.

    The layer is stored as ``self.linear``, so the state_dict keys are
    ``linear.weight`` and ``linear.bias``.

    Args:
        in_features: Size of the last dimension of the input. Defaults to 10.
        out_features: Size of the last dimension of the output. Defaults to 1.
    """

    def __init__(self, in_features: int = 10, out_features: int = 1) -> None:
        super().__init__()
        self.linear = nun.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x``.

        The last dimension of ``x`` must equal ``self.linear.in_features``;
        the result has the same leading dimensions with the last one
        replaced by ``self.linear.out_features``.
        """
        return self.linear(x)


model = SimpleModel()
# Switch to inference mode. This is a no-op for a plain Linear layer, but it is
# the conventional step before prediction and keeps the script correct if
# dropout/batch-norm layers are ever added to the model.
model.eval()

# One random input sample, sized from the model's own layer rather than a
# hard-coded width. It is created in the default dtype, which is the same dtype
# nn.Linear uses for its parameters, so no explicit cast is needed (a cast to
# float32 would fail if the default dtype were float64).
x = torch.randn(1, model.linear.in_features)

with torch.no_grad():
    prediction = model(x).tolist()

print(prediction)

# Persist the weights of the same model that produced the prediction above.
# Re-instantiating SimpleModel() here would save new random weights that do
# not match the printed output.
torch.save(model.state_dict(), 'model.pth')