Dataloader에서 특정 axis를 미니 배치로 나눌 수 있을까요?

Dataloader에서 특정 axis를 미니 배치로 나눌 수 있을까요?

예를 들어 shape이 (4, 100, 10) 인 텐서가 있다고 치고

미니배치 사이즈가 25 → 각 미니배치 별로 (4, 25, 10) 인 텐서를 4개를 만든다고 하는게 가능할까요?

채널이 4인 CNN을 만들고 싶은데, axis=0 이 채널이 되버리면서 미니배치가 꼬이게 되었습니다.

안녕하세요, @jyj7913 님.

사용자 정의 Dataloader를 만드시고, 내부에서 np.split() 또는 torch.split()을 사용하면 어떠실련지요?

아래와 같이 (3, 6, 4)의 크기를 갖는 ndarray가 있을 때,

>>> import numpy as np
>>> a = np.arange(72).reshape(3,6,4)
>>> a.shape
(3, 6, 4)
>>> a
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]],

       [[24, 25, 26, 27],
        [28, 29, 30, 31],
        [32, 33, 34, 35],
        [36, 37, 38, 39],
        [40, 41, 42, 43],
        [44, 45, 46, 47]],

       [[48, 49, 50, 51],
        [52, 53, 54, 55],
        [56, 57, 58, 59],
        [60, 61, 62, 63],
        [64, 65, 66, 67],
        [68, 69, 70, 71]]])
>>>

아래와 같이 np.split()axis 매개변수를 지정하여 (3, 2, 4)의 크기를 갖는 3개의 ndarray로 나눌 수 있습니다.

>>> b = np.split(a, 3, axis=1)
>>> len(b)
3
>>> b[0].shape
(3, 2, 4)
>>> b[1].shape
(3, 2, 4)
>>> b[2].shape
(3, 2, 4)
>>> b
[array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7]],

       [[24, 25, 26, 27],
        [28, 29, 30, 31]],

       [[48, 49, 50, 51],
        [52, 53, 54, 55]]]), array([[[ 8,  9, 10, 11],
        [12, 13, 14, 15]],

       [[32, 33, 34, 35],
        [36, 37, 38, 39]],

       [[56, 57, 58, 59],
        [60, 61, 62, 63]]]), array([[[16, 17, 18, 19],
        [20, 21, 22, 23]],

       [[40, 41, 42, 43],
        [44, 45, 46, 47]],

       [[64, 65, 66, 67],
        [68, 69, 70, 71]]])]
>>>

np.split()의 매개변수 관련 설명은 아래 링크에서 확인해보실 수 있는데요,
https://numpy.org/doc/stable/reference/generated/numpy.split.html

위에서 사용한 np.split(a, 3, axis=1)의 경우에는 주어진 a라는 ndarray2번째 axis에 대해서 3토막으로 나누는 것입니다.


PyTorch에도 비슷한 torch.split()이라는 함수가 있는데요, 매개변수는 살짝 다릅니다.

위에서 만든 동일한 a라는 (3, 6, 4) 크기를 갖는 ndarray를 사용하여 c라는 Tensor에 대해서,

>>> import torch
>>> c = torch.tensor(a)
>>> c.shape
torch.Size([3, 6, 4])
>>> c
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]],

        [[24, 25, 26, 27],
         [28, 29, 30, 31],
         [32, 33, 34, 35],
         [36, 37, 38, 39],
         [40, 41, 42, 43],
         [44, 45, 46, 47]],

        [[48, 49, 50, 51],
         [52, 53, 54, 55],
         [56, 57, 58, 59],
         [60, 61, 62, 63],
         [64, 65, 66, 67],
         [68, 69, 70, 71]]])
>>>

아래와 같이 torch.split()dim 매개변수를 지정하여 (3, 2, 4)의 크기를 갖는 3개의 텐서로 나눌 수 있습니다.

>>> d = torch.split(c, 2, dim=1)
>>> len(d)
3
>>> d[0].shape
torch.Size([3, 2, 4])
>>> d[1].shape
torch.Size([3, 2, 4])
>>> d[2].shape
torch.Size([3, 2, 4])
>>> d
(tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[24, 25, 26, 27],
         [28, 29, 30, 31]],

        [[48, 49, 50, 51],
         [52, 53, 54, 55]]]), tensor([[[ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[32, 33, 34, 35],
         [36, 37, 38, 39]],

        [[56, 57, 58, 59],
         [60, 61, 62, 63]]]), tensor([[[16, 17, 18, 19],
         [20, 21, 22, 23]],

        [[40, 41, 42, 43],
         [44, 45, 46, 47]],

        [[64, 65, 66, 67],
         [68, 69, 70, 71]]]))
>>>

torch.split()에 대한 설명은 아래 문서에서 확인해보실 수 있는데요,
https://pytorch.org/docs/stable/generated/torch.split.html

위에서 사용한 torch.split(c, 2, dim=1)의 경우에는 주어진 c라는 Tensor2번째 dim에 대해서 2개씩 나누는 것입니다.

(= 즉, np.split()에서는 나누고자 하는 개수를, torch.split()에서는 나눠진 각 토막(chunk)의 크기를 지정하는 부분이 다릅니다.)


그 외, Dataloader를 정의하는 방법은 아래 튜토리얼 문서를 참고해보시면 좋을 것 같습니다.

https://tutorials.pytorch.kr/beginner/basics/data_tutorial.html#id9

혹시 의도하셨던 것과 다른 내용이라면 알려주시기를 부탁드립니다. :slight_smile:

1개의 좋아요

답변 감사합니다!

예시까지 올려주셔서 쉽게 이해되었고, 제가 원하는 답이 맞는 것 같습니다.

좋은 하루 보내시길 바랍니다.

다시 한 번 감사드립니다.

1개의 좋아요

앗, 도움이 되셨다니 다행입니다 :smiley: