import torch
from torch import nn, optim
from torchvision import models, transforms

import d2l
# Fashion-MNIST train/test loaders with batch size 64. Loader construction is
# delegated to the `d2l` module imported at the top of the file.
train_iter, test_iter = d2l.load_data_fashion_mnist(64)

# Per-channel (R, G, B) mean / std used to normalise inputs for the
# ImageNet-pretrained ResNet-18 built below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _to_three_channels(img):
    """Replicate a single-channel image tensor to three channels.

    Accepts a single image ``(1, H, W)`` or a batch ``(N, 1, H, W)``; the
    channel axis is taken to be ``-3``. Tensors whose channel axis is not of
    size 1 are returned unchanged, so this is a no-op on RGB input.

    Needed because Fashion-MNIST is grayscale, while both the 3-element
    ImageNet ``Normalize`` below and ResNet-18's first convolution require
    3 channels. A named module-level function (rather than a lambda) keeps
    the transform picklable for multi-process data loading.
    """
    if img.shape[-3] == 1:
        return img.repeat_interleave(3, dim=-3)
    return img


# Training-time augmentation. Normalize only accepts tensors, so this pipeline
# expects tensor input, not PIL images.
train_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.Lambda(_to_three_channels),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Deterministic evaluation pipeline: same as above without the random flip.
# NOTE(review): not referenced anywhere in this script (only `train_aug` is
# passed to d2l.train) — confirm whether evaluation should use it.
test_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Lambda(_to_three_channels),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
# ImageNet-pretrained ResNet-18 backbone; progress=True shows a progress bar
# while the weights are downloaded.
# NOTE(review): `pretrained=` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet18_Weights.IMAGENET1K_V1`; kept as-is for
# compatibility with older torchvision releases.
finetune_res = models.resnet18(progress=True, pretrained=True)

# Swap the original classifier for a fresh 10-way linear head sized to the
# backbone's feature width, then Xavier-initialise its weight matrix (the bias
# keeps nn.Linear's default initialisation).
finetune_res.fc = nn.Linear(
    in_features=finetune_res.fc.in_features,
    out_features=10,
)
nn.init.xavier_uniform_(tensor=finetune_res.fc.weight)
# Multi-class classification loss over the model's 10 output logits.
loss_f = nn.CrossEntropyLoss()

# Adam over every parameter (backbone and new head alike) at one small
# learning rate, i.e. full fine-tuning rather than training only the head.
# NOTE(review): in torch.optim.Adam, weight_decay is a coupled L2 term added
# to the gradient; 0.1 is unusually strong for that form — consider
# optim.AdamW or a smaller value. Value intentionally left unchanged.
opt = optim.Adam(
    params=finetune_res.parameters(),
    lr=5e-5,
    weight_decay=0.1,
)
# Use the first CUDA GPU when one is available, otherwise fall back to CPU.
# The original hard-coded "cuda:0", which fails on CPU-only hosts as soon as
# anything is moved to that device (presumably inside d2l.train — confirm).
train_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fine-tune for 5 epochs with the training augmentation; the positional
# argument order and keyword names are those of the project's d2l.train.
d2l.train(
    5,
    loss_f,
    opt,
    finetune_res,
    train_iter,
    save_name="res18_pretrained",
    device=train_device,
    aug=train_aug,
)