다시 이음

Pytorch Dataset / Dataloader 본문

Pre_Onboarding by Wanted(자연어 처리)

Pytorch Dataset / Dataloader

Taeho(Damon) 2022. 3. 2. 18:52

안녕하세요.

 

오늘은 Pytorch에서 데이터를 전처리하고 배치화 하는 클래스를 제공합니다.

 

Dataset 클래스는 데이터를 전처리하고 dictionary 또는 list 타입으로 변경할 수 있습니다.
DataLoader 클래스는 데이터 1. 셔플 2. 배치화 3. 멀티 프로세스 기능을 제공합니다.

 

Dataset

  • 모든 custom dataset 클래스는 Dataset() 클래스를 상속받아야 함.
  • __getitem__()와 __len__() 메소드를 반드시 오버라이딩해야 함.
  • DataLoader 클래스가 배치를 만들 때 Dataset 인스턴스의 __getitem__() 메소드를 사용해 데이터에 접근함
  • 해당 Dataset 클래스는 string sequence 데이터를 tokenize & tensorize한다.

 

from torch.utils.data import Dataset

class Custom_Dataset(Dataset):

    def __init__(self, data: Iterator):
    	#BERT모델을 토크나이저로 사용
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        self.target = []
        self.text = []
        for target, text in data:
            self.target.append(target)
            self.text.append(text)
  
    def __len__(self):
        return len(self.target)

    def __getitem__(self, index):
    	#인덱싱할 수 있도록 로직을 짜야합니다.
        # encode
        token_ids = self.tokenizer.encode(
        text = self.text[index],
        truncation = True,
        )
        # tensorize
        return torch.tensor(token_ids), torch.tensor([self.target[index]])

 

 

Dataloader

공식문서 : https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

 

파라미터

  • dataset : 입력값 데이터셋
    • map-style dataset
    • iterable style dataset : random으로 읽기에 어렵거나, data에 따라 batch_size가 달라지는 데이터(dynamic batch size)에 적합한 데이터셋  
      • __iter__() : 선언 필요
  • batch_size : 배치 사이즈
    • int
  • shuffle : 데이터를 Dataloader에 섞어서 사용할 것인지 설정하는 인수
    • bool
  • sampler : torch.utils.data.Sampler 객체를 사용합니다. // data index iterator
    sampler는 index를 컨트롤하는 방법입니다. 데이터의 index를 원하는 방식대로 조정합니다.

    즉 index를 컨트롤하기 때문에 설정하고 싶다면 shuffle 파라미터는 False(기본값)여야 합니다.
    • SequentialSampler : 항상 같은 순서
    • RandomSampler : 랜덤, replacemetn 여부 선택 가능, 개수 선택 가능
    • SubsetRandomSampler : 랜덤 리스트, 위와 두 조건 불가능
    • WeigthRandomSampler : 가중치에 따른 확률
    • BatchSampler : batch단위로 sampling 가능
    • DistributedSampler : 분산처리 (torch.nn.parallel.DistributedDataParallel과 함께 사용)
  • collate_fn : map-style 데이터셋에서 sample list를 batch 단위로 바꾸기 위해 필요한 기능입니다. zero-padding이나 Variable Size 데이터 등 데이터 사이즈를 맞추기 위해 많이 사용합니다.
#적용 방법
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

train_dataloader = DataLoader(dataset,batch_size=32, sampler=RandomSampler(dataset), collate_fn =None)