torch.cat()과 torch.stack()은 어떻게 다른가요?

공식 홈페이지StackOverflow 등에서 자주 보이는 질문과 답변을 번역하고 있습니다.

다음 링크에서 원문을 함께 찾아보실 수 있습니다.


질문

  • torch.cat()torch.stack()은 어떤 차이점이 있나요?
1개의 좋아요

답변

  • torch.cat()주어진 차원을 기준으로 주어진 텐서들을 붙입(concatenate)니다.
    torch.stack()새로운 차원으로 주어진 텐서들을 붙입니다.
  • 따라서, (3, 4)의 크기(shape)를 갖는 2개의 텐서 AB를 붙이는 경우,
    torch.cat([A, B], dim=0)의 결과는 (6, 4)의 크기(shape)를 갖고,
    torch.stack([A, B], dim=0)의 결과는 (2, 3, 4)의 크기를 갖습니다.

다른 답변

  • 예를 들어 설명하기 위해, 아래 두 개의 텐서 t1, t2를 예시로 선언해보겠습니다.
    t1 = torch.tensor([[1, 2],
                       [3, 4]])
    t2 = torch.tensor([[5, 6],
                       [7, 8]])
    
  • 이 때, torch.cat()의 동작은 다음과 같습니다.
    >>> torch.cat((t1, t2), dim=0) # dim=0인 경우
    tensor([[1, 2],
            [3, 4],
            [5, 6],
            [7, 8]])
    
    >>> torch.cat((t1, t2), dim=1) # dim=1인 경우
    tensor([[1, 2, 5, 6],
            [3, 4, 7, 8]])
    
  • torch.stack()은 다음과 같습니다.
    >>> torch.stack((t1, t2))
    tensor([[[1, 2],
             [3, 4]],
     
            [[5, 6],
             [7, 8]]])
    

더 알아보기