import torch
from torch import nn
# Round-trip a single tensor through disk: write it to 'x_file', read it back,
# and print the element-wise comparison of the original and the reloaded copy.
x = torch.arange(4)
torch.save(x, 'x_file')
y = torch.load('x_file')
print(x == y)
# torch.save also accepts containers of tensors: store the same pair of
# tensors once as a list (file "list") and once as a dict (file "dict").
a = torch.arange(5)
b = torch.arange(5)
torch.save([a, b], "list")
torch.save({'a': a, 'b': b}, "dict")

# Read both files back and show that the container structure is preserved.
loaded_list = torch.load("list")
loaded_dict = torch.load("dict")
print(loaded_list)
print(loaded_dict)
# Small conv net used to demonstrate saving and restoring model parameters.
# Conv2d(1, 2, 5) turns a 1x16x16 input into 2x12x12, which Flatten reshapes
# to 288 features -- hence Linear(288, 64).
net = nn.Sequential(nn.Conv2d(1, 2, 5), nn.Flatten(), nn.Linear(288, 64))
print(net(torch.normal(0, 0.5, size=(1, 1, 16, 16))))
print(net.state_dict())

# Persist only the parameters (the state dict), not the module object itself.
torch.save(net.state_dict(), 'net.params')

# Restoring requires rebuilding an identical architecture first, then loading
# the saved parameters into it.
new_net = nn.Sequential(nn.Conv2d(1, 2, 5), nn.Flatten(), nn.Linear(288, 64))
new_net.load_state_dict(torch.load('net.params'))

# NOTE: this compares the two module objects themselves, not their weights,
# so it prints False even though the parameters were restored.
print(new_net == net)

# The meaningful check: both networks produce the same output for the same
# input, which shows the parameters were restored.
X = torch.normal(0, 0.5, size=(1, 1, 16, 16))
print(new_net(X) == net(X))