(CNN) Pytorch Tensor - Indexing
- 최초 작성일: 2025년 6월 11일 (수)
목차
기본 indexing
파이토치 텐서의 indexing 방법은 NumPy 배열과 매우 유사하다. 단일 지정 인덱싱을 하면 원본 텐서의 차원이 하나 줄어든 텐서가 반환된다.
import torch
# 텐서 생성
ts_01 = torch.arange(0, 10).view(2, 5)
print(ts_01)
출력 결과:
tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
단일 지정 인덱싱 예시:
print('ts_01[0, 0]:', ts_01[0, 0], 'ts_01[0, 1]:', ts_01[0, 1])
print('ts_01[1, 0]:', ts_01[1, 0], 'ts_01[1, 2]:', ts_01[1, 2])
print(ts_01[0, 0].shape, ts_01[0, 0].ndim, ts_01[0, :].shape, ts_01[0, :].ndim)
출력 결과:
ts_01[0, 0]: tensor(0) ts_01[0, 1]: tensor(1)
ts_01[1, 0]: tensor(5) ts_01[1, 2]: tensor(7)
torch.Size([]) 0 torch.Size([5]) 1
슬라이싱(slicing) indexing
슬라이싱을 사용하면 원본 텐서의 차원이 유지된다.
print('ts_01[0, :]은', ts_01[0, :], '\nts_01[:, 0]은', ts_01[:, 0])
print('ts_01[0, 0:3]은', ts_01[0, 0:3], '\nts_01[1, 1:4]은', ts_01[1, 1:4])
print('ts_01[:, :]\n', ts_01[:, :])
출력 결과:
ts_01[0, :]은 tensor([0, 1, 2, 3, 4])
ts_01[:, 0]은 tensor([0, 5])
ts_01[0, 0:3]은 tensor([0, 1, 2])
ts_01[1, 1:4]은 tensor([6, 7, 8])
ts_01[:, :]
tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
Fancy(List) indexing
Fancy indexing은 지정한 인덱스 목록을 사용해 텐서의 특정 행을 추출하는 방법이다.
torch.manual_seed(2025)
random_indexes = torch.randint(0, 5, size=(4,))
print('random_indexes:', random_indexes)
ts_01 = torch.rand(size=(10, 5))
print('ts_01:\n', ts_01)
ts_01_1 = ts_01[random_indexes]
print('Fancy indexing 결과 ts_01_1:\n', ts_01_1)
출력 결과:
random_indexes: tensor([4, 2, 4, 0])
ts_01:
tensor([[...], [...], [...], [...], [...]])
Fancy indexing 결과 ts_01_1:
tensor([[...], [...], [...], [...]])
Boolean indexing
Boolean indexing은 조건에 따라 원소를 선택한다. NumPy 배열과 다르게 PyTorch는 Boolean indexing 결과가 1차원 텐서로 반환된다.
ts_01 = torch.arange(0, 10).view(2, 5)
print(ts_01)
mask = ts_01 > 4
print(mask)
print('Boolean indexing 결과:', ts_01[mask])
출력 결과:
tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
tensor([[False, False, False, False, False],
[ True, True, True, True, True]])
Boolean indexing 결과: tensor([5, 6, 7, 8, 9])
torch.where 활용
torch.where
를 사용하면 원본 텐서의 차원을 유지하며 조건에 따라 값을 치환할 수 있다.
print(torch.where(ts_01 > 4, input=ts_01, other=torch.tensor(999)))
출력 결과:
tensor([[999, 999, 999, 999, 999],
[ 5, 6, 7, 8, 9]])
이 문서를 통해 PyTorch 텐서의 다양한 indexing 방법과 NumPy 배열과의 차이를 명확하게 이해할 수 있다.