AlexNet

Concept and Principle

  • 在深度学习之前:
    • 核方法
      • 特征提取
      • 选择核函数
      • 凸优化问题
      • 漂亮的定理
    • 几何学
      • 抽取特征
      • 将计算机视觉问题描述为几何问题(如多相机)
      • 凸优化
      • 漂亮的定理
      • 建立假设模型,若假设满足,效果会很好
    • 特征工程
      • 特征工程(人工特征提取)是关键,不太关心机器学习模型
      • 特征描述子(SIFT、SURF)
      • 视觉词袋(聚类)
      • 最后一般用SVM
  • AlexNet赢得了2012年ImageNet竞赛的冠军,引起了深度学习的热潮,其本质上是一个更深更大的LeNet
    • 主要改进:
      • 丢弃法
      • ReLU
      • MaxPooling
      • 数据增强(截取、调亮度、调色温等)
      • 更深更大
  • 网络结构与复杂度

Implementation

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from torch import nn,optim
import d2l

class AlexNet(nn.Module):

def __init__(self):
super().__init__()
self.conv1=nn.Conv2d(1,96,kernel_size=11,stride=4,padding=1)
self.conv2=nn.Conv2d(96,256,kernel_size=5,padding=2)
self.conv3=nn.Conv2d(256,384,kernel_size=3,padding=1)
self.conv4=nn.Conv2d(384,384,kernel_size=3,padding=1)
self.conv5=nn.Conv2d(384,256,kernel_size=3,padding=1)
self.liner1=nn.Linear(6400,4096)
self.liner2=nn.Linear(4096,4096)
self.liner3=nn.Linear(4096,10)

self.pool=nn.MaxPool2d(kernel_size=3,stride=2)
self.flatten=nn.Flatten()
self.relu=nn.ReLU()
self.dropout=nn.Dropout(p=0.5)


def forward(self,x):

x=self.relu(self.pool(self.conv1(x)))

x=self.relu(self.pool(self.conv2(x)))
x=self.relu(self.conv3(x))
x=self.relu(self.conv4(x))
x=self.relu(self.conv5(x))
x=self.pool(x)
x=self.flatten(x)
x=self.dropout((self.liner1(x)))
x=self.dropout((self.liner2(x)))
x=self.liner3(x)

return x

net=AlexNet().to(torch.device("cuda:0"))

# x=torch.rand((1,1,224,224)).cuda()
# print(net(x).size())

# 读取Fashion-MNIST,将图片直接拉伸为224x224
train_iter,test_iter=d2l.load_data_fashion_mnist(256,224)

# d2l.train(
# 25,nn.CrossEntropyLoss(),
# optim.Adam(net.parameters()),
# net,train_iter,save_name="AlexNet",
# device=torch.device("cuda:0"))
d2l.evaluate(
net,test_iter,nn.CrossEntropyLoss(),
param_path="D:\code\machine_learning\limu_d2l\params\AlexNet_5",
device=torch.device("cuda:0")
)