数据扩充和增广

chenpaopao

最近在学习 torch,对于图像数据的预处理, torchvision 提供了torchvision.transforms 模块,用于预处理。

  1. 1. 裁剪——Crop 中心裁剪:transforms.CenterCrop 随机裁剪:transforms.RandomCrop 随机长宽比裁剪:transforms.RandomResizedCrop 上下左右中心裁剪:transforms.FiveCrop 上下左右中心裁剪后翻转,transforms.TenCrop
  2. 2. 翻转和旋转——Flip and Rotation 依概率p水平翻转:transforms.RandomHorizontalFlip(p=0.5) 依概率p垂直翻转:transforms.RandomVerticalFlip(p=0.5) 随机旋转:transforms.RandomRotation
  3. 3. 图像变换 resize:transforms.Resize 标准化:transforms.Normalize 转为tensor,并归一化至[0-1]:transforms.ToTensor 填充:transforms.Pad 修改亮度、对比度和饱和度:transforms.ColorJitter 转灰度图:transforms.Grayscale 线性变换:transforms.LinearTransformation() 仿射变换:transforms.RandomAffine 依概率p转为灰度图:transforms.RandomGrayscale 将数据转换为PILImage:transforms.ToPILImage transforms.Lambda:Apply a user-defined lambda as a transform.
  4. 4. 对transforms操作,使数据增强更灵活 transforms.RandomChoice(transforms), 从给定的一系列transforms中选一个进行操作 transforms.RandomApply(transforms, p=0.5),给一个transform加上概率,依概率进行操作 transforms.RandomOrder,将transforms中的操作随机打乱

此外,还提供了 torchvision.transforms.Compose( ),可以同时传递多个函数

mytransform = transforms.Compose([
transforms.ToTensor()
]
)

# torch.utils.data.DataLoader
cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True, transform = mytransform )
cifarLoader = torch.utils.data.DataLoader(cifarSet, batch_size= 10, shuffle= False, num_workers= 2)
>>> transforms.Compose([ 
>>> transforms.CenterCrop(10),
>>> transforms.PILToTensor(), >>> transforms.ConvertImageDtype(torch.float), >>> ])

作为 Dataset类的参数传递 :

torchvision.datasets.Caltech101(root: strtarget_type: Union[List[str], str] = ‘category’transform: Optional[Callable] = Nonetarget_transform: Optional[Callable] = Nonedownload: bool = False)

或者自定义的类:
(自己实现torchvision.datasets.CIFAR10的功能)

(自己实现torchvision.datasets.CIFAR10的功能)
import os
import torch
import torch.utils.data as data
from PIL import Image

def default_loader(path):
return Image.open(path).convert('RGB')

class myImageFloder(data.Dataset):
def __init__(self, root, label, transform = None, target_transform=None, loader=default_loader):
fh = open(label)
c=0
imgs=[]
class_names=[]
for line in fh.readlines():
if c==0:
class_names=[n.strip() for n in line.rstrip().split('    ')]
else:
cls = line.split()
fn = cls.pop(0)
if os.path.isfile(os.path.join(root, fn)):
imgs.append((fn, tuple([float(v) for v in cls])))
c=c+1
self.root = root
self.imgs = imgs
self.classes = class_names
self.transform = transform
self.target_transform = target_transform
self.loader = loader

def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(os.path.join(self.root, fn))
if self.transform is not None:
img = self.transform(img)
return img, torch.Tensor(label)

def __len__(self):
return len(self.imgs)
def getName(self):
return self.classes

实例化torch.utils.data.DataLoader

mytransform = transforms.Compose([
transforms.ToTensor()
]
)

# torch.utils.data.DataLoader
imgLoader = torch.utils.data.DataLoader(
myFloder.myImageFloder(root = "../data/testImages/images", label = "../data/testImages/test_images.txt", transform = mytransform ),
batch_size= 2, shuffle= False, num_workers= 2)

for i, data in enumerate(imgLoader, 0):
print(data[i][0])
# opencv
img2 = data[i][0].numpy()*255
img2 = img2.astype('uint8')
img2 = np.transpose(img2, (1,2,0))
img2=img2[:,:,::-1]#RGB->BGR
cv2.imshow('img2', img2)
cv2.waitKey()
break

2 使用Python+OpenCV进行数据扩充(适用于目标检测)

https://pythonmana.com/2021/12/202112131040182515.html

下面内容来自

数据扩充是一种增加数据集多样性的技术,无需收集更多真实数据,但仍有助于提高模型精度并防止模型过拟合。

数据扩充方法包括:

  1. 随机裁剪
  2. Cutout
  3. 颜色抖动
  4. 增加噪音
  5. 过滤
import os

import cv2

import numpy as np

import random


def file_lines_to_list(path):

    '''

    ### 在TXT文件里的行转换为列表 ###

    path: 文件路径

    '''

    with open(path) as f:

        content = f.readlines()

    content = [(x.strip()).split() for x in content]

    return content


def get_file_name(path):

    
'''

    ### 获取Filepath的文件名 ###

    path: 文件路径

    '''

    basename = os.path.basename(path)

    onlyname = os.path.splitext(basename)[0]

    return onlyname


def write_anno_to_txt(boxes, filepath):

    
'''

    ### 给TXT文件写注释 ###

    boxes: format [[obj x1 y1 x2 y2],...]

    filepath: 文件路径
    '''

    txt_file = open(filepath, "w")

    for box in boxes:

        print(box[0], int(box[1]), int(box[2]), int(box[3]), int(box[4]), file=txt_file)

    txt_file.close()

随机裁剪

随机裁剪随机选择一个区域并进行裁剪以生成新的数据样本,裁剪后的区域应具有与原始图像相同的宽高比,以保持对象的形状。

def randomcrop(img, gt_boxes, scale=0.5):

    
'''

    ### 随机裁剪 ###

    img: 图像

    gt_boxes: format [[obj x1 y1 x2 y2],...]

    scale: 裁剪区域百分比
    '''


    # 裁剪

    height, width = int(img.shape[0]*scale), int(img.shape[1]*scale)

    x = random.randint(0, img.shape[1] - int(width))

    y = random.randint(0, img.shape[0] - int(height))

    cropped = img[y:y+height, x:x+width]

    resized = cv2.resize(cropped, (img.shape[1], img.shape[0]))


    # 修改注释

    new_boxes=[]

    for box in gt_boxes:

        obj_name = box[0]

        x1 = int(box[1])

        y1 = int(box[2])

        x2 = int(box[3])

        y2 = int(box[4])

        x1, x2 = x1-x, x2-x

        y1, y2 = y1-y, y2-y

        x1, y1, x2, y2 = x1/scale, y1/scale, x2/scale, y2/scale

        if (x1<img.shape[1] and y1<img.shape[0]) and (x2>0 and y2>0):

            if x1<0: x1=0

            if y1<0: y1=0

            if x2>img.shape[1]: x2=img.shape[1]

            if y2>img.shape[0]: y2=img.shape[0]

            new_boxes.append([obj_name, x1, y1, x2, y2])

    return resized, new_boxes

Cutout

Terrance DeVries和Graham W.Taylor在2017年的论文中介绍了Cutout,它是一种简单的正则化技术,用于在训练过程中随机屏蔽输入的方块区域,可用于提高卷积神经网络的鲁棒性和整体性能。这种方法不仅非常容易实现,而且还表明它可以与现有形式的数据扩充和其他正则化工具结合使用,以进一步提高模型性能。如本文所述,剪切用于提高图像识别(分类)的准确性,因此,如果我们将相同的方案部署到对象检测数据集中,可能会导致丢失对象的问题,尤其是小对象。

剪切输出是新生成的图像,我们不移除对象或更改图像大小,则生成图像的注释与原始图像相同。

def cutout(img, gt_boxes, amount=0.5):

    
'''

    ### Cutout ###

    img: 图像

    gt_boxes: format [[obj x1 y1 x2 y2],...]

    amount: 蒙版数量/对象数量
    '''

    out = img.copy()

    ran_select = random.sample(gt_boxes, round(amount*len(gt_boxes)))


    for box in ran_select:

        x1 = int(box[1])

        y1 = int(box[2])

        x2 = int(box[3])

        y2 = int(box[4])

        mask_w = int((x2 - x1)*0.5)

        mask_h = int((y2 - y1)*0.5)

        mask_x1 = random.randint(x1, x2 - mask_w)

        mask_y1 = random.randint(y1, y2 - mask_h)

        mask_x2 = mask_x1 + mask_w

        mask_y2 = mask_y1 + mask_h

        cv2.rectangle(out, (mask_x1, mask_y1), (mask_x2, mask_y2), (0, 0, 0), thickness=-1)

    return out

颜色抖动

ColorJitter是另一种简单的图像数据增强,我们可以随机改变图像的亮度、对比度和饱和度。我相信这个技术很容易被大多数读者理解。

def colorjitter(img, cj_type="b"):

    
'''

    ### 不同的颜色抖动 ###

    img: 图像

    cj_type: {b: brightness, s: saturation, c: constast}
    '''

    if cj_type == "b":

        # value = random.randint(-50, 50)

        value = np.random.choice(np.array([-50, -40, -30, 30, 40, 50]))

        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

        h, s, v = cv2.split(hsv)

        if value >= 0:

            lim = 255 - value

            v[v > lim] = 255

            v[v <= lim] += value

        else:

            lim = np.absolute(value)

            v[v < lim] = 0

            v[v >= lim] -= np.absolute(value)


        final_hsv = cv2.merge((h, s, v))

        img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)

        return img


    elif cj_type == "s":

        # value = random.randint(-50, 50)

        value = np.random.choice(np.array([-50, -40, -30, 30, 40, 50]))

        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

        h, s, v = cv2.split(hsv)

        if value >= 0:

            lim = 255 - value

            s[s > lim] = 255

            s[s <= lim] += value

        else:

            lim = np.absolute(value)

            s[s < lim] = 0

            s[s >= lim] -= np.absolute(value)


        final_hsv = cv2.merge((h, s, v))

        img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)

        return img


    elif cj_type == "c":

        brightness = 10

        contrast = random.randint(40, 100)

        dummy = np.int16(img)

        dummy = dummy * (contrast/127+1) - contrast + brightness

        dummy = np.clip(dummy, 0, 255)

        img = np.uint8(dummy)

        return img

增加噪声

在一般意义上,噪声被认为是图像中的一个意外因素,然而,几种类型的噪声(例如高斯噪声、椒盐噪声)可用于数据增强,在深度学习中添加噪声是一种非常简单和有益的数据增强方法。

对于那些无法识别高斯噪声和椒盐噪声之间差异的人,高斯噪声的值范围为0到255,具体取决于配置,因此,在RGB图像中,高斯噪声像素可以是任何颜色。相比之下,椒盐噪波像素只能有两个值0或255,分别对应于黑色(PEPER)或白色(salt)。

def noisy(img, noise_type="gauss"):

    
'''

    ### 添加噪声 ###

    img: 图像

    cj_type: {gauss: gaussian, sp: salt & pepper}
    '''

    if noise_type == "gauss":

        image=img.copy() 

        mean=0

        st=0.7

        gauss = np.random.normal(mean,st,image.shape)

        gauss = gauss.astype('uint8')

        image = cv2.add(image,gauss)

        return image


    elif noise_type == "sp":

        image=img.copy() 

        prob = 0.05

        if len(image.shape) == 2:

            black = 0

            white = 255            

        else:

            colorspace = image.shape[2]

            if colorspace == 3:  # RGB

                black = np.array([0, 0, 0], dtype='uint8')

                white = np.array([255, 255, 255], dtype='uint8')

            else:  # RGBA

                black = np.array([0, 0, 0, 255], dtype='uint8')

                white = np.array([255, 255, 255, 255], dtype='uint8')

        probs = np.random.random(image.shape[:2])

        image[probs < (prob / 2)] = black

        image[probs > 1 - (prob / 2)] = white

        return image

滤波

本文介绍的最后一个数据扩充过程是滤波。与添加噪声类似,滤波也简单且易于实现。实现中使用的三种类型的滤波包括模糊(平均)、高斯和中值。

def filters(img, f_type = "blur"):

    
'''

    ### 滤波 ###

    img: 图像

    f_type: {blur: blur, gaussian: gaussian, median: median}
    '''

    if f_type == "blur":

        image=img.copy()

        fsize = 9

        return cv2.blur(image,(fsize,fsize))


    elif f_type == "gaussian":

        image=img.copy()

        fsize = 9

        return cv2.GaussianBlur(image, (fsize, fsize), 0)


    elif f_type == "median":

        image=img.copy()

        fsize = 9

        return cv2.medianBlur(image, fsize)

上述内容可以在这里找到完整实现

https://github.com/tranleanh/data-augmentation

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注