ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [yongggg's] pytorch 딥러닝 GPU 메모리 관리 (IterableDataset)
    GPUTraining & View 2022. 5. 3. 16:08

    안녕하세요 이번 장에서는 모델을 학습하고 싶은데, 너무 큰 학습 데이터 때문에, dataloader가 GPU에 올라가지 않을 때 사용할 수 있는 방법을 설명드리겠습니다. 물론 batch를 낮춤으로써 이를 해결할 수 있다면, 이 내용이 필요없으실 수 있습니다. 하지만, 생각보다 너무 적은 배치를 사용하지 못할 경우, 혹은 GPU 서버를 다른 분들과 공유해서 사용할 때, 메모리 issue가 있으신 분들에게는 해당 내용이 도움이 될 것이라고 생각합니다. 이제 본격적으로 설명해 보겠습니다!

    1. Pytorch Dataset

    먼저 official pytorch 포럼에서는 두 가지 Dataset 모듈(Map-style Dataset, Iterable Dataset)이 있다.  

    • map-style Dataset: pytorch 1.2 버전 이하에서 사용되던 모듈이며, 지금도 가장 널리 쓰이는 방식이다. 
    • Iterable Dataset: 개발되어야 할 기능이 더 필요한 모듈이며, 사용하기 번거롭지만, 메모리 효율성이 높다.

    1.1 Map-Style Dataset

    위에서 간략하게 설명한 것과 같이 map-style Dataset은 지금 시점에도 널리 쓰이는 class이며, 모든 데이터를 memory에 올릴 수 있을 때 사용하는 일반적인 경우이다. 다음 코드와 같이 torch.utils.data.Dataset 클래스를 상속 받아 지정된 몇 개의 함수를 추가하여 구현하면 된다. 

     

    from torch.utils.data import Dataset
    
    class MyDataset(Dataset):
    
    	def __init__(self, data):
        	self.data =data
          
     	def __len__(self):
        	return(len(self.data))
        
        def __getitem__(self, index):
        	return self.data['text'][index]

     

    1.2 Iterable Dataset

    Map-style Dataset은 memory에 모든 데이터를 업로드 하는 방법이라고 소개했다. 이 Dataset을 사용하여 문제가 없다면 상관 없지만, 위에서 말한 것과 같이 학습 데이터가 memory에 다 올라가지 않는 경우가 발생할 수 있다. 이러한 문제를 해결할 수 있는 방법 중에 하나로 같은 torch.utils.data.IterableDataset을 사용하는 방법이 있다. 아래 코드와 같이 torch.utils.data.IterableDataset 클래스를 상속받은 후, iter 함수를 정의해주면 된다. 아래 코드는 pretrained Aibrilbert (SK C&C 자연어 처리 모델)에 들어가기 전의 Dataset을 직접 custom한 코드의 일부이다.

    batch를 생성할 때, Dataset 클래스는 __getitem__ 의 index를 사용한다면, IterableDataset은 next() 인자로 __iter__를 처리하기 때문에, sampler 사용이 어렵다. 따라서, shuffling을 Dataloader에서 하지 못하고, dataset을 미리 shuffling 해주어야한다.

     

    from torch.utils.data import IterableDataset
    
    class MYIterableDataset(IterableDataset):
    	
        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels =labels
    		
            ...
            
            if self.labels is not None:
                items = {key: torch.tensor(val) for key, val in self.encodings.items()}
                items['labels'] = torch.tensor(self.labels) # .unsqueeze(-1).repeat(1,256)
                self.data = torch.cat((items['input_ids'], items['token_type_ids'], items['attention_mask'], items['labels'].unsqueeze(-1)), axis=1)
    			
                ...
                
        def __iter__(self):
            for d in self.data:
    
                worker = torch.utils.data.get_worker_info()
                worker_id = worker.id if worker is not None else -1
                start = time.time()
                time.sleep(0.1)
                end = time.time()
                
                yield d, worker_id, start, end
        
        def __len__(self):
            return self.data.shape[0]

     

    official pytorch 포럼에서는 num_workers>0 조건에서 Iterable Dataset을 사용할 때,  worker_init_fn option을 사용하는 것을 제안했다.

    먼저, num_workers는 dataset을 불러올 때, 몇 개의 subprocess를 사용할 지에 대한 인자이다. num_workers가 0 또는 1은 main process에서 데이터를 불러오는 것과 모델에 pass 하는 작업을 모두 수행한다는 뜻이고 num_workers가 2라면 subprocess 2개에서 데이터를 불러오며, main process에서는 subprocess에서 불러온 데이터를 모델에 pass하는 역할만을 담당한다. 따라서 num_workers가 > 1일 때, data loading 시 병목현상이 줄어들며, GPU이용량을 끌어올릴 수 있다.

    official pytorch 포럼의 worker_init_fn 함수는 다음과 같다.

     

    def worker_init_fn(_):
        worker_info = torch.utils.data.get_worker_info()
        
        dataset = worker_info.dataset
        worker_id = worker_info.id
        split_size = len(dataset.data) // worker_info.num_workers
        
        dataset.data = dataset.data[worker_id * split_size: (worker_id + 1) * split_size]

     

    Dataset을 DataLoader로 불러올 때, 다음과 같이 위 함수를 옵션으로 같이 넣어주고 loader에서 batch 등의 인자를 가져와 사용하면 된다. 만약 num_workers > 1 이고 worker_init_fn 함수를 같이 사용하지 않을 경우에는 각 subprocess에서 중복되는 data를 불러오는 현상이 발생하기 때문에 꼭 이 함수를 같이 써야한다는 것을 명심해야한다.

     

    loader = DataLoader(iterable_dataset, batch_size=4, num_workers=2, worker_init_fn=worker_init_fn)
    for d in loader:
        batch, worker_ids, starts, ends = d
        print(batch, worker_ids)

     

    이렇게 메모리가 터지거나, 다른 사람과 함께 서버를 쓰는 경우에 GPU 용량을 관리할 수 있는 스킬 중에 하나를 소개했습니다. 제가 쓴 글 중에, multi gpu를 활용하는 내용과 이를 함께 적용한다면, 자신이 학습하고자 하는 모델을 더욱 수월하게 학습할 것이라고 기대합니다. 많은 도움 되셨길 바랍니다^^!!

Designed by Tistory.