import torch | |
import torch.nn as nun | |
class SimpleModel(nun.Module):
    """A single fully connected layer mapping ``in_features`` inputs to ``out_features`` outputs.

    Args:
        in_features: Size of the last dimension of the input tensor. Defaults to 10.
        out_features: Size of the last dimension of the output tensor. Defaults to 1.
    """

    def __init__(self, in_features: int = 10, out_features: int = 1) -> None:
        super().__init__()
        # Defaults reproduce the original hard-coded Linear(10, 1) layer.
        self.linear = nun.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` of shape ``(*, in_features)``; return ``(*, out_features)``."""
        return self.linear(x)
# Build the model once and reuse the same instance for both inference and
# saving, so the checkpoint on disk holds the exact weights that produced the
# printed prediction. (Previously a second SimpleModel() was created right
# before saving, which saved unrelated, freshly initialized weights.)
model = SimpleModel()
# Switch to inference mode. This is a no-op for a bare Linear layer, but it
# stays correct if layers such as dropout or batch-norm are added later.
model.eval()

# One random sample with the 10 input features SimpleModel() expects.
x = torch.randn(1, 10)
t1 = x.to(torch.float)  # cast to float32 (torch.float), nn.Linear's default parameter dtype

# Disable autograd tracking: only the forward result is needed, not gradients.
with torch.no_grad():
    prediction = model(t1).tolist()
print(prediction)

# Persist only the learned parameters (state_dict), not the pickled module object.
torch.save(model.state_dict(), "model.pth")