개발자식

[Pytorch] torch.sum (Tensor 요소 합) 본문

AI/Pytorch

[Pytorch] torch.sum (Tensor 요소 합)

밍츠 2022. 10. 18. 19:25

torch.sum()

- input 텐서에 있는 모든 요소의 합계를 반환한다.

- dim 으로 차원을 압축할 수 있다.

 

a = torch.tensor([[ 1,  2,  3],[ 4,  5,  6]]) #(2,3)

#dim = 0
b = torch.sum(a, dim = 0)

#dim = 1
c = torch.sum(a, dim = 1)

print(b)
print(c)
Output:
tensor([5, 7, 9])
tensor([ 6, 15])

 

여기서 dim을 tuple 형태로 넣을 수 있다.

tuple로 넣으면 차례대로 계산한다.

x = torch.rand(256, 10, 8)

print(torch.sum(x, dim=(2)).shape)
print(torch.sum(x, dim=(2,1)).shape)
Output:
torch.Size([256, 10])
torch.Size([256])

 

https://pytorch.org/docs/stable/generated/torch.sum.html

 

torch.sum — PyTorch 1.12 documentation

Shortcuts

pytorch.org

 

Comments