본문 바로가기

컴퓨터 과학/인공지능

PyTorch DataLoader: 실무에 가까운 고급 사용법과 예제

반응형

DataLoader는 단순히 데이터를 배치 단위로 나눠주는 것 이상으로,
데이터 전처리부터 성능 최적화까지 다양한 역할을 할 수 있어요.

이번 글에서는 다음과 같은 심화 기능들을 다뤄볼 거예요:

 

🧠 다룰 내용

  1. collate_fn으로 배치 구성 방식 커스터마이징
  2. num_workers로 데이터 로딩 속도 높이기
  3. 불균형 데이터에 대응하기 위한 WeightedRandomSampler
  4. 시퀀스 길이가 다른 텍스트 데이터 패딩 처리
  5. (보너스) persistent_workers, pin_memory 등 옵션 이해

1. 🧩 collate_fn: 배치 구성 방식을 직접 정의

collate_fn은 DataLoader가 개별 데이터를 배치로 묶을 때 묶는 방식을 커스터마이징할 수 있게 해줘요.

💡 예제: 텍스트 길이에 따라 패딩 추가하기 (시퀀스 길이 다를 때)

from torch.utils.data import Dataset, DataLoader
import torch
import random

class TextDataset(Dataset):
    def __init__(self):
        self.samples = [
            [1, 2, 3],
            [4, 5],
            [6, 7, 8, 9],
            [10]
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return torch.tensor(self.samples[idx])

# collate_fn 정의
def pad_collate(batch):
    batch = [x for x in batch]
    return torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0)

dataloader = DataLoader(TextDataset(), batch_size=2, collate_fn=pad_collate)

for batch in dataloader:
    print(batch)

 

2. ⚙️ num_workers: 멀티 프로세싱으로 데이터 로딩 가속화

  • 기본값은 0 (메인 프로세스에서 데이터 로딩)
  • num_workers > 0이면 백그라운드에서 병렬로 데이터 로딩해서 빠르게 처리할 수 있음
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

💡 CPU 코어 수를 고려해서 2~4 정도부터 실험해보는 게 좋아요.

 

 

3. ⚖️ WeightedRandomSampler: 불균형 데이터 문제 해결

클래스가 한쪽에 치우쳐 있는 경우, 균형 있게 샘플링하려면 WeightedRandomSampler를 써요.

from torch.utils.data import WeightedRandomSampler

labels = [0]*90 + [1]*10
class_counts = torch.bincount(torch.tensor(labels))
weights = 1. / class_counts.float()
sample_weights = [weights[label] for label in labels]

sampler = WeightedRandomSampler(sample_weights, num_samples=100, replacement=True)
dataloader = DataLoader(list(zip(range(100), labels)), sampler=sampler, batch_size=10)

for batch in dataloader:
    print([label for _, label in batch])
    break

🔎 더 고르게 0과 1이 섞이도록 샘플링돼요!

 

 

4. 🧵 시퀀스 패딩 + 정렬 (자연어처리 전용)

def pad_and_sort(batch):
    batch.sort(key=lambda x: len(x), reverse=True)
    padded = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True)
    lengths = [len(x) for x in batch]
    return padded, lengths

LSTM/RNN 사용 시 pack_padded_sequence와 함께 유용하게 써요.

 

 

5. 💡 기타 옵션들

옵션설명
pin_memory=True GPU로 빠르게 복사하기 위한 고정 메모리 사용 (속도 향상)
persistent_workers=True num_workers > 0일 때, epoch마다 worker를 재시작하지 않음
drop_last=True 마지막 배치가 작으면 버림 (모델 안정성에 유리할 수도 있음)

📌 요약 정리

기능설명실무 활용
collate_fn 배치 구성 방식 커스터마이징 텍스트 길이 정리, 딕셔너리 처리
num_workers 멀티프로세싱 데이터 로딩 큰 데이터셋 빠르게 처리
WeightedRandomSampler 클래스 불균형 해소 의료/금융/분류 모델
pin_memory GPU 속도 최적화 학습 속도 개선
drop_last 배치 크기 통일 LSTM/RNN에서 안정적 학습

🚀 마무리

PyTorch의 DataLoader는 단순한 반복문을 넘어서
실제 데이터 상황에 맞게 최적화된 학습 파이프라인을 구성할 수 있는 강력한 도구예요!

반응형