Data Augmentation
Concept and Principle
我们收集的训练数据通常很难覆盖到未来可能部署的全部场景(比如人脸识别的应用可能会部署到不同摄像头状况、天气、时间等的场景)。数据增强则在一个已有的数据上做数据变换,起到增大数据集的作用,使其有更好的多样性。
数据增强
- 数据增强只在训练时进行
- 一般采用在线生成的方式
- 数据增强假设测试环境中会出现增强后的数据,如果测试环境和训练集高度一致则没有必要做数据增强
数据增强方法
- 翻转(上下、左右翻转)
- 切割(随机高宽比、大小、位置切割一块,然后再变为固定大小)
- 颜色(色调、饱和度、亮度)
- 其他(高斯模糊、锐化、遮挡等)
Implementation
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
| from torchvision import transforms import torch from matplotlib import pyplot as plt import d2l
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): figsize = (num_cols * scale, num_rows * scale) _, axes = plt.subplots(num_rows, num_cols, figsize=figsize) axes = axes.flatten() for i, (ax, img) in enumerate(zip(axes, imgs)): if torch.is_tensor(img): ax.imshow(transforms.ToPILImage()(img)) else: ax.imshow(img) ax.axes.get_xaxis().set_visible(False) ax.axes.get_yaxis().set_visible(False) if titles: ax.set_title(titles[i]) plt.show() return axes
train_iter,test_iter=d2l.load_data_fashion_mnist(4,(224,224)) X,y=next(iter(test_iter))
show_images(X,2,2,scale=1)
trans=transforms.RandomErasing(1) show_images(trans(X),2,2,scale=1)
import torchvision from PIL import Image test_data=torchvision.datasets.FashionMNIST( root="./dataset",train=False, download=True,transform=transforms.ToTensor() )
print(test_data[0][0].shape)
image=transforms.ToPILImage()(test_data[0][0]) image.show() import torchvision test_data=torchvision.datasets.FashionMNIST( root="./dataset",train=False, download=True )
d2l.show_images([test_data[i][0] for i in range(32)], 4, 8, scale=0.8)
image = Image.open(r"./1.jpg") print(image)
print(transforms.ToTensor()(image))
|