cat과 stack은 PyTorch의 내장함수로써, 여러 텐서들을 합치는 데에 사용된다. 두 함수의 차이점을 알아보자!
TORCH.CAT
input parameter : tensor, dim(default값은 0)
output : tensor
ex) torch.cat((x, y, z), 1)
alias) torch.concat, torch.concatenate
input으로는 tensor(합쳐야 할 텐서들을 튜플 형태로 작성)과 dim(텐서들을 합칠 차원)이 들어간다.
합치는 tensor들의 크기는 dim을 제외하고 모두 같아야 한다.
x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 3)
torch.cat((x, y), dim=0) # error
torch.cat((x, y), dim=1) # error
torch.cat((x, y), dim=2) # return (2, 3, 7) shape tensor
TORCH.STACK
input parameter : tensor, dim(default값은 0)
output : tensor
ex) torch.stack((x, y, z), 1)
사용법이 cat 함수와 동일하다. 그러나 stack 함수는 합치는 tensor의 크기가 dim을 포함하여 모두 같아야 한다.
또한 dim의 범위가 cat 함수가 (-차원~차원-1)이었다면, stack 함수는 (-차원-1~차원)이다.
이유는 아래 코드를 보면 이해할 수 있다.
x = torch.randn(2, 3, 4)
print(torch.cat((x, x, x, x, x), 0).shape) # torch.Size([10, 3, 4])
print(torch.stack((x, x, x, x, x), 0).shape) # torch.Size([5, 2, 3, 4])
print(torch.stack((x, x, x, x, x), 1).shape) # torch.Size([2, 5, 3, 4])
print(torch.stack((x, x, x, x, x), 2).shape) # torch.Size([2, 3, 5, 4])
print(torch.stack((x, x, x, x, x), 3).shape) # torch.Size([2, 3, 4, 5])
cat 함수가 이미 있는 차원에 tensor들을 쌓는 방식이라면, stack 함수는 차원을 새로 만들어 tensor를 쌓는다.
TORCH.VSTACK / TORCH.HSTACK / TORCH.DSTACK
torch.cat으로 구현이 가능한 함수들이다.
x = torch.randn(2, 3, 4)
print(torch.vstack((x, x, x)).shape) # torch.Size([6, 3, 4])
print(torch.hstack((x, x, x)).shape) # torch.Size([2, 9, 4])
print(torch.dstack((x, x, x)).shape) # torch.Size([2, 3, 12])