Data Augmentation

Concept and Principle

我们收集的训练数据通常很难覆盖到未来可能部署的全部场景(比如人脸识别的应用可能会部署到不同摄像头状况、天气、时间等的场景)。数据增强则在一个已有的数据上做数据变换,起到增大数据集的作用,使其有更好的多样性。

  • 数据增强

    • 数据增强只在训练时进行
    • 一般采用在线生成的方式
    • 数据增强假设测试环境中会出现增强后的数据,如果测试环境和训练集高度一致则没有必要做数据增强
  • 数据增强方法

    • 翻转(上下、左右翻转)
    • 切割(随机高宽比、大小、位置切割一块,然后再变为固定大小)
    • 颜色(色调、饱和度、亮度)
    • 其他(高斯模糊、锐化、遮挡等)

Implementation

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# pytorch在提供transforms模块中提供了很多数据增广函数
from torchvision import transforms
import torch
from matplotlib import pyplot as plt
import d2l

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):

figsize = (num_cols * scale, num_rows * scale)
_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
ax.imshow(transforms.ToPILImage()(img))
else:
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
plt.show()
return axes

train_iter,test_iter=d2l.load_data_fashion_mnist(4,(224,224))
X,y=next(iter(test_iter))

# 显示tensor形式的图片
# X=X.view(4,224,224)

show_images(X,2,2,scale=1)

# 数据增广
# trans=transforms.RandomCrop((20,20))
# trans=transforms.RandomHorizontalFlip()
# trans=transforms.GaussianBlur(5)
trans=transforms.RandomErasing(1)
show_images(trans(X),2,2,scale=1)

# 显示使用PIL读进内存的图片
import torchvision
from PIL import Image
test_data=torchvision.datasets.FashionMNIST(
root="./dataset",train=False,
download=True,transform=transforms.ToTensor()
)

print(test_data[0][0].shape)
# 将Tensor形式转换为PIL形式
image=transforms.ToPILImage()(test_data[0][0])
image.show()
import torchvision
test_data=torchvision.datasets.FashionMNIST(
root="./dataset",train=False,
download=True
)
# 第i张图片test_data[i][0],test_data[i][1]是第一张图片的标签
d2l.show_images([test_data[i][0] for i in range(32)], 4, 8, scale=0.8)

image = Image.open(r"./1.jpg")
print(image)
# 将PIL形式转换为Tensor形式
print(transforms.ToTensor()(image))