001. Torchvision 0.8, GPU에서 돌아가는 Scriptable Transforms 사용해서 데이터 로딩 속도 개선하기
Torchvision 0.8 에서 새롭게 나온 기능중 가장 유용하고 기다려왔던 기능입니다. 바로 Transform이 Tensor, Batch computation, GPU and TorchScript를 지원합니다!
상세한 릴리즈 노트는 여기를 확인해 주세요.
이 포스트에서는 해당 기능을 이용해서 CIFAR, ImageNet을 Loading 하는 코드를 작성하고, 실제로 얼마나 빨라지는지 실험을 해보겠습니다.
본 포스트에서 사용한 모든 코드는 여기에서 확인하세요. Star도 꾹 눌러주세요 ^^
사용법
기존의 Transform은 transform.compose를 이용해서 dataset의 transform 파라미터에 넣어주었습니다.
original_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
original_dataset = CIFAR100('/dataset/CIFAR', transform=original_transforms)
original_loader = torch.utils.data.DataLoader(original_dataset, batch_size=64, shuffle=True, num_workers=4)
새로운 Transform은 약간 방법이 바뀌었는데요.
new_transforms = nn.Sequential(
transforms.RandomHorizontalFlip(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
)
new_transforms = torch.jit.script(new_transforms)
new_dataset = CIFAR100('/dataset/CIFAR', transform=transforms.ToTensor())
new_loader = torch.utils.data.DataLoader(new_dataset, batch_size=64, shuffle=True, num_workers=4)
이 코드에서 큰 차이점은 세가지 입니다.
- nn.Sequential을 사용
- TorchScript로 변환
- dataset의 transform에는 ToTensor() 만 사용
먼저 transform.Compos 대신에 nn.Sequential을 사용합니다. 즉 transform코드가 모델의 모듈처럼 동작하고 모델에도 삽입 될수 있다는 뜻입니다.
두번째로 TorchScript로 변환하여 더 빠르게 인퍼런스 할 수 있도록합니다.
마지막으로 기존의 dataset의 transform에는 ToTensor() 만 사용합니다. 이미지를 로드하고 텐서로만 변경하면 되기 때문이죠.
그렇다면 실제 transform 코드는 어디서 돌아갈까요?
for i, (x, l) in enumerate(new_loader):
x = x.cuda()
new_transforms(x)
바로 모델과 같이 실제 학습 루프에서 돌아가게 됩니다. 이렇게 cuda tensor를 넘겨주면 앞에서 정의했던 transform이 GPU에서 동작합니다.
ImageNet에 적용
하지만 이 코드를 바로 ImageNet에 적용을 하기에 큰 문제가 있었습니다.
바로 ImageNet은 이미지 사이즈가 모두 달라서 ToTensor를 하면 오류가 나게 됩니다.
그렇기 때문에 ToTensor를 하기전에 Resize를 해야 합니다.
new_transforms = torch.nn.Sequential(
transforms.RandomHorizontalFlip(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
new_transforms = torch.jit.script(new_transforms)
new_dataset = ImageNet('/dataset/ImageNet', transform=transforms.Compose([transforms.RandomResizedCrop(224), transforms.ToTensor()]))
new_loader = torch.utils.data.DataLoader(new_dataset, batch_size=64, shuffle=True, num_workers=4)
새로운 transform은 위 예제와 같이 적용합니다. 다른점은 Dataset의 transform 파라미터에 RandomResizedCrop으로 미리 크기를 조정하는 것이죠.
효율성 측정
수정한 코드를 기반으로 얼마나 시간이 줄어드는지 효율성을 측정해 보았습니다.
본 테스트를 진행한 서버 스펙은 다음과 같습니다
- NVIDIA RTX-2080ti x 1
- Intel(R) Core(TM) i9-7900X @ 3.30GHz, vCPU 4 core
- Memory 16G
- NVIDIA-Docker
- Pytorch 1.7.1
- Torchvision 0.8.2
- nvidia Driver 440.100
- CUDA 10.1
- cuDNN 7.6
CIFAR
CIFAR에서는 Batch_size 64, num_workers 4로 총 200 epoch을 세번씩 돌려 평균적으로 걸리는 시간(초)을 측정하였습니다.
Original Transforms | Scriptable Transforms | |
1차 | 708.79s | 480.52s |
2차 | 702.59s | 477.51s |
3차 | 700.88s | 479.93s |
평균 | 704.09s | 479.32s |
Scriptable Transforms가 479.32초로 기존보다 224.77초나 빨랐으며, 약 68%의 로딩속도 향상이 있었습니다.
ImageNet
ImageNet에서는 Batch_size 64, num_workers 4로 총 200 epoch을 세번씩 돌려 평균적으로 걸리는 시간(초)을 측정하였습니다. 이미지넷은 시간이 오래걸려서 총 20019 배치 중 10 배치만 로드하고 스킵하였습니다.
Original Transforms | Scriptable Transforms | |
1차 | 397.07s | 382.62s |
2차 | 401.18s | 382.59s |
3차 | 396.35s | 383.49s |
평균 | 398.2s | 382.9s |
역시 Scriptable Transforms가 빠른것을 볼 수 있는데요. ImageNet실험의 경우 Resize하는 코드가 Scriptable Tansforms외부에 있기 때문에 CPU에서 실행되게 되어 생각보다 성능향상이 크지 않은 것을 볼 수 있습니다.
전체 실험에서 얼마나 시간이 줄어들지 계산하기위해 20019배치를 모두 돌았다고 가정하여 계산을 하면,
기존 코드는 797156.58초(221시간), 새 코드는 766527.51초(213시간)로 30629.07초 약 8.5시간 정도 차이가 나게 됩니다.
결론
Scriptable Transforms은 사용하는 것이 데이터 로딩 속도를 향상시킬 수 있다.
input size가 다른경우에 사용하기 어려운 경우가 있다.
Focus Only AI | Ocean: GPU 클러스터 관리 솔루션 | AI Ocean 대표 신은섭
Ocean 소개글: https://bongjasee.tistory.com/pages/Ocean
Ocean - GPU 클러스터 관리 솔루션
Ocean은 2020년 경희대학교에서 딥러닝 연구를 시작하던 때부터 만들어진 GPU 클러스터 관리 솔루션입니다. # 성과 개발을 시작한 이후 지금까지 MLVC 연구실에서는 Ocean을 통해 다양한 연구, 프로젝
bongjasee.tistory.com
Github:https://github.com/AI-Ocean
AI-Ocean
AI-Ocean has one repository available. Follow their code on GitHub.
github.com