공식 홈페이지와 StackOverflow 등에서 자주 보이는 질문과 답변을 번역하고 있습니다.
다음 링크에서 원문을 함께 찾아보실 수 있습니다.
질문
torch.cat()
과torch.stack()
은 어떤 차이점이 있나요?
공식 홈페이지와 StackOverflow 등에서 자주 보이는 질문과 답변을 번역하고 있습니다.
다음 링크에서 원문을 함께 찾아보실 수 있습니다.
torch.cat()
과 torch.stack()
은 어떤 차이점이 있나요?torch.cat()
은 주어진 차원을 기준으로 주어진 텐서들을 붙입(concatenate)니다.torch.stack()
은 새로운 차원으로 주어진 텐서들을 붙입니다.(3, 4)
의 크기(shape)를 갖는 2개의 텐서 A
와 B
를 붙이는 경우,torch.cat([A, B], dim=0)
의 결과는 (6, 4)
의 크기(shape)를 갖고,torch.stack([A, B], dim=0)
의 결과는 (2, 3, 4)
의 크기를 갖습니다.t1
, t2
를 예시로 선언해보겠습니다.t1 = torch.tensor([[1, 2],
[3, 4]])
t2 = torch.tensor([[5, 6],
[7, 8]])
torch.cat()
의 동작은 다음과 같습니다.>>> torch.cat((t1, t2), dim=0) # dim=0인 경우
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
>>> torch.cat((t1, t2), dim=1) # dim=1인 경우
tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])
torch.stack()
은 다음과 같습니다.>>> torch.stack((t1, t2))
tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])