Real Late Starter

[PyTorch] 7. Custom Dataset 본문

PyTorch

[PyTorch] 7. Custom Dataset

조슈아박 2020. 2. 20. 00:31

 

이번 포스트에서는 저번에 올렸던 2020/02/11 - [PyTorch] - [PyTorch] 4-1. 나만의 이미지 데이터셋 만들기

에 이어서 Custom Dataset을 만드는 방법에 대해 알아보겠습니다. 

 

저번에 올렸던 포스트에서는 torch.dataset에서 제공하는 ImageFolder로 이미지데이터셋을 만드는 것을 알아보았습니다.

하지만 우리가 사용하는 데이터가 모두 이미지 데이터는 아니기 때문에 다른 형태의 데이터도 불러오기 위해서는 원하는 형태로 가공할 수 있는 방법을 알아야합니다.

 

커스텀 데이터셋을 어떻게 구성하는가에 대해 알아보겠습니다.

 

1. Dataset & Loader

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor(),transform.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])

trainset = torchvision.datasets.CIFAR(root = './data',train=True,
                                      download=True,
                                      transform=transform)
trainset = torchvision.datasets.CIFAR(root = './data',train=False,
                                      download=True,
                                      transform=transform)                                  

# torchvision을 통해 dataset을 구성하고 DataLoader를 통해 데이터를 불러오기 쉽게 감싸준다.
trainloader = DataLoader(trainset, batch_size=8, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=8, shuffle=False, num_workers=2)

torchvision.datasets을 이용해서 trainset과 testset을 만들어줍니다. 이것을 DataLoader을 통해 trainloader와 testloader로 전달합니다. 이 과정에서 batch_size와 shuffle 기능, cpu 사용 코어 갯수 등을 지정해줄 수 있습니다.

 

epoch_num = 3

for epoch in range(epoch_num):
	for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)
        
        optim.zero_grad()
        out = my_net(inputs)
        loss = loss_function(out, labels)
        loss.backward()
        optim.step()
        
        if i % 100 == 0:
        	print("%d => loss : %3.f" %(i,loss))
print('train over)

만들어논 trainloader와 testloader를 사용해서 반복문에만 넣어주면 간편하게 모델의 학습을 진행할 수 있습니다.

 

2. Dataset의 형태

 

그럼 다음으로 우리가 사용할 dataset의 내부 형태를 살펴보도록 하겠습니다.

# dataset 구성요소

class my_dataset(torch.utils.data.Dataset):

    def __init__(self, x, transforms=None):
    
    def __len__(self):
    
    def __getitem__(self, idx):
    

먼저 데이터셋을 구성할 때는 class 형태로 구성을 해줍니다. 이름을 지정해주는데 여기서는 'my_dataset'이라고 지정하였습니다.

그 다음 괄호 안에 torch.utils.data.Dataset을 입력해줍니다. 클래스 안에 3가지 요소를 넣어서 구성을 해줍니다.

 

각 요소들의 역할을 살펴보겠습니다.

 

1. __init__

- 데이터 셋을 가져와서 전처리를 해주는 기능, 함수들을 담아준다.

 

2. __len__

- 데이터 셋의 전체 길이를 반환한다.

 

3. __getitem__

- 데이터 셋에서 한 개의 데이터를 가져오는 함수를 정의한다.

 

3. 예제를 통해 데이터 셋 구성 확인하기

이번에는 우리가 주로 사용하는 Toy Data라고 할 수 있는 CIFAR-10 데이터셋 코드를 확인하면서 정말로 위의 dataset class와 마찬가지로 구성이 되어있는지 확인해보도록 하겠습니다.

class CIFAR10(VisionDataset):

    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def __getitem__(self, index):

        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

위의 코드는 torchvision 내에 있는 CIFAR-10 dataset을 로드하는 코드입니다. 클래스 안을 보면 위에서 살펴본 것처럼

__init__, __len__, __getitem__이 있는 것을 볼 수 있습니다. (핵심 요소만 남기고 나머지 코드는 삭제하였습니다.)

하나하나씩 살펴보도록 하겠습니다.

 

3-1) __init__

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

__init__의 파라미터를 보면 self, root 디렉토리를 정의를 해줍니다. 또한 데이터를 전처리하는 transform을 입력해줄 수 있고, target feature에 대한 transform도 따로 지정해줄 수 있습니다. 다운로드 같은 경우네는 인터넷을 통해 자동을 다운로드를 받을 것인지 아님 이미 받아놓은 데이터를 사용할 것인지에 대한 파라미터입니다.

 

밑을 보시면  train을 입력했을 경우 test를 입력했을 경우에 if문을 사용해서 다른 데이터를 불러올 수 있도록 구성을 해놓았습니다. 

그리고 38번 째 줄에 보면 reshape(-1, 3, 32, 32)으로 전체 데이터의 shape을 재정의한 것을 볼 수 있습니다. 이렇게 구성을 함으로써 좀 더 쉽게 불러올 수 있도록 장치를 해놓았습니다.

 

3-2) __getitem__

    def __getitem__(self, index):

        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

__getitem__은 데이터셋이 가지고있는 데이터를 리턴하는 기능을 합니다. 파라미터로 받는 index를 가지고 img, target을 바로 리턴해주는 모습을 볼 수 있습니다. 전체 데이터가 array형태로 되어있기 때문에 인덱스를 사용해서 바로 불러올 수 있습니다.

 

3-3) __len__

    def __len__(self):
        return len(self.data)

__len__의 경우에는 전체데이터의 길이를 반환해주는 역할을 하기 때문에 파이썬 함수 len()을 사용하여 데이터 길이, 총량을 반환해 줍니다.

 

이상으로 custom dataset을 구성하는 방법에 대해서 알아봤습니다.

 

이 포스트는 김군이(https://www.youtube.com/watch?v=KXiDzNai9tI&t=984s)님의 강의를 듣고 공부하며 정리한 내용을 올린 것입니다.

 

'PyTorch' 카테고리의 다른 글

[PyTorch] 5. Network의 정의  (0) 2020.02.12
[PyTorch] 4-1. 나만의 이미지 데이터셋 만들기  (22) 2020.02.11
[PyTorch] 4. Data Loader  (1) 2020.01.23
[PyTorch] 3. nn & nn.functional  (0) 2020.01.22
[PyTorch] 2. Autograd & Variable  (0) 2020.01.18