안녕하세요, 오늘은 PyTorch를 쓰다보면 굉장히 많이 접하게 되는 nn.Linear 함수에 대해 설명하려고 합니다.
nn.Linear란
먼저, 논문을 읽거나 기타 자료를 볼 때 nn.Linear는 Fully Connected Layer(FC) 또는 Dense Layer라고 표현을 합니다. Fully Connected, Dense와 같은 단어에서 느낌이 오듯, Input/Output의 노드를 빼먹지 않고 촘촘하게 연결한 구조를 갖습니다. 그렇다고 Input 또는 Output 노드 사이의 연결이 존재하지는 않고, Bipartite Graph를 떠올리시면 쉽습니다 (참고).
nn.Linear란 어떤 역할을 할까요?
- Classification: 분류 문제의 경우, 보통 네트워크의 가장 마지막에 붙어 이전까지 얻었던 Feature를 바탕으로 분류를 수행합니다.
- Regression: 위와 사용 방법은 똑같은데 분류처럼 이산적인(discrete) 결과가 아닌 연속적인(continuous) 값을 얻고자 할 때 사용합니다.
- Feature Extraction: Input data의 특징을 얻을 때도 많이 씁니다. 요즘은 조금 덜하기는 하지만 클러스터링(clustering), 차원 축소(dimensionality reduction), 아니면 임베딩 시각화(visualization)를 위해 사용하기도 합니다.
- Transfer Learning: 어떤 모델이 이미 다른 데이터로 학습되어 있는데, 우리는 우리의 문제를 위해 새로이 학습하고 싶을 때가 있죠? 이럴 때 기존 모델의 맨 마지막 FC를 빼고 우리의 문제에 맞게 새로운 FC를 끼워넣어 학습하기도 합니다.
어찌 이것만 있겠냐마는 이런 상황에서 자주 사용합니다.
코드와 함께 살펴보는 nn.Linear
직접 코드와 함께 보도록 하겠습니다.
import torch
import torch.nn as nn
backbone = featureExtractor() # Feature를 얻기 위한 네트워크, 편의상 output dimension = (1, 1024)
fc = nn.Linear(1024, 10) # 1024개의 input 노드를 받아 10개의 output 노드를 얻기 위함
feature = backbone(input) # input의 feature를 얻고
output = fc(feature) # (1, 10) --> 10가지 class의 분류 문제를 풀겠다!
아주아주 간단한 예시인데요, 10가지 클래스를 가진 분류 문제를 풀기 위한 방법입니다.
실제로는 동작하는 코드가 아니지만 감만 익혀가시면 좋겠습니다.
여기서 가장 중요한 포인트는 왜 nn.Linear의 parameters로 (1024, 10)이 들어가는지 입니다.
학창 시절에 수학 행렬을 배우셨다면 이해가 빠를텐데요, 이 안에서는 결국 행렬 연산을 한다고 생각하시면 되겠습니다. 아주 단순한 행렬곱 연산입니다.
$$ (M, P) \times (P, N) = (M, N) $$
PyTorch Official Docs를 보시면 input / output 차원이 각각
- Input: \((*, \mathit{H_{in}})\)
- Output: \((*, \mathit{H_{out}})\)
여기서 \(*\)는 그 어떤 차원도 가능합니다. 예를 들어 Input 차원이 \((2, 3, 4, 5, 6, 7, 8, \mathit{H_{in}})\) 이라도 상관 없이 Output 차원은 \((2, 3, 4, 5, 6, 7, 8, \mathit{H_{out}})\) 이 됩니다.
출처: https://www.sharetechnote.com/html/Python_PyTorch_nn_Linear_01.html
'기술 이야기 > PyTorch' 카테고리의 다른 글
[PyTorch] TorchScript: Tracing vs. Scripting (0) | 2023.06.04 |
---|---|
PyTorch 2.0 vs ONNX vs TensorRT 비교 (4) | 2023.05.19 |
[PyTorch] nn.Module에 대한 이해 (0) | 2023.04.03 |
[PyTorch] 간략한 자기 소개 (0) | 2023.03.30 |
댓글