카테고리 없음

[오늘의 코딩] collate_fn 정의하기

바모이 2022. 11. 24. 02:09

torch.utils.data.Dataset에 data를 정의하여 넘기면, torch.utils.data.Dataloader의 batch_sampler 함수에서 data를 배치화시켜준다.

그러나 텍스트 데이터와 같이 batch 내의 data 길이가 다르다면 (즉, variable length data라면) padding을 통해 batch 내의 data 길이를 맞추어야 한다. 그렇지 않다면, 아래와 같은 오류가 발생할 것이다!

torch.utils.data.Dataloader에 있는 collate_fn 함수를 재정의하여 이러한 오류를 피할 수 있다.

 

RuntimeError: stack expects each tensor to be equal size, but got [593, 202] at entry 0 and [610, 202] at entry 1

 

 

def my_collate(samples):
    inputs = [sample['input'] for sample in samples]
    labels = [sample['label'] for sample in samples]
    pad_inputs = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True)
    pad_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-1)
    return pad_inputs, pad_labels
    

train_dataloader = DataLoader(train_dataset, batch_size=args.batch, collate_fn=my_collate)

 

NER task를 진행 중인 코드이기 때문에, input과 label 모두 길이가 가변적이다. 

label의 경우는 0부터 18번 라벨까지 쓰고 있기 때문에, -1로 padding하였다.

 

torch.nn.utils.rnn.pad_sequence(sequences, batch_first=False, padding_value=0.0)

가변 길이의 텐서를 동일한 길이로 맞추어 padding하는 함수이다. 

batch_first가 True가 된다면 텐서의 첫 차원에 batch가 위치한다.