본문 바로가기

카테고리 없음

[PyTorch] torch.cat VS torch.stack

cat과 stack은 PyTorch의 내장함수로써, 여러 텐서들을 합치는 데에 사용된다. 두 함수의 차이점을 알아보자!

 

TORCH.CAT

input parameter : tensor, dim(default값은 0)

output : tensor

ex) torch.cat((x, y, z), 1)

alias) torch.concat, torch.concatenate

 

input으로는 tensor(합쳐야 할 텐서들을 튜플 형태로 작성)과 dim(텐서들을 합칠 차원)이 들어간다.

합치는 tensor들의 크기는 dim을 제외하고 모두 같아야 한다.

 

x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 3)

torch.cat((x, y), dim=0) # error
torch.cat((x, y), dim=1) # error
torch.cat((x, y), dim=2) # return (2, 3, 7) shape tensor

 

TORCH.STACK

input parameter : tensor, dim(default값은 0)

output : tensor

ex) torch.stack((x, y, z), 1)

 

사용법이 cat 함수와 동일하다. 그러나 stack 함수는 합치는 tensor의 크기가 dim을 포함하여 모두 같아야 한다.

또한 dim의 범위가 cat 함수가 (-차원~차원-1)이었다면, stack 함수는 (-차원-1~차원)이다.

이유는 아래 코드를 보면 이해할 수 있다.

 

x = torch.randn(2, 3, 4)

print(torch.cat((x, x, x, x, x), 0).shape)   # torch.Size([10, 3, 4])
print(torch.stack((x, x, x, x, x), 0).shape) # torch.Size([5, 2, 3, 4])
print(torch.stack((x, x, x, x, x), 1).shape) # torch.Size([2, 5, 3, 4])
print(torch.stack((x, x, x, x, x), 2).shape) # torch.Size([2, 3, 5, 4])
print(torch.stack((x, x, x, x, x), 3).shape) # torch.Size([2, 3, 4, 5])

 

cat 함수가 이미 있는 차원에 tensor들을 쌓는 방식이라면, stack 함수는 차원을 새로 만들어 tensor를 쌓는다.

 

TORCH.VSTACK / TORCH.HSTACK / TORCH.DSTACK

 

torch.cat으로 구현이 가능한 함수들이다.

 

x = torch.randn(2, 3, 4)

print(torch.vstack((x, x, x)).shape) # torch.Size([6, 3, 4])
print(torch.hstack((x, x, x)).shape) # torch.Size([2, 9, 4])
print(torch.dstack((x, x, x)).shape) # torch.Size([2, 3, 12])