본문 바로가기
기술 이야기/논문 리뷰

Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture (I-JEPA) 논문 리뷰

by 넌 꿈이 뭐야? 2023. 6. 19.

안녕하세요, 오늘은 I-JEPA라는 Self-Supervised Learning 기법을 소개합니다. 이 방법은 휘황찬란한 Augmentation 기법에 구애받지 않으면서 이미지의 픽셀을 들여다보는 것이 아닌 기존 보다 의미론적 관점에서 조금 더 직관적인 이미지 이해를 위한 설계를 통해 훌륭한 결과를 달성했다고 합니다. 이번 설명은 결론부터 먼저 말씀드리는 방법으로 진행해보겠습니다.


I-JEPA 정리

이 논문은 기존 Self-supervised Learning 방법들이 픽셀 수준에서 학습했던 것과는 달리, 모델이 좀 더 고차원적인 이해를 하도록 Latent space 수준의 학습을 통해

  • 특정 Downstream tasks에만 편향되지 않는 Image augmentation을 생략 가능하게 하고
  • 기존 방법 대비 빠른 학습

을 가능하게 했다고 합니다.

I-JEPA Pipeline

왜 이런 방법을 채택했는지, 구체적으로 어떤 과정이 포함되어 있는지, 결과는 어떠한지 하나씩 살펴보도록 하겠습니다.

Introduction

현재까지의 이미지 Self-supervised Learning은 크게 Invariance-based, Generative-based 두 가지 방법으로 나눌 수 있습니다.

Invariance-based Pretraining Method

이 방법은 같은 대상을 이렇게도 보고, 저렇게 봐도 결국은 같은 대상임을 학습하는 방법입니다. 아래 그림의 리트리버를 보시면 여러 각도와 자세를 취하더라도 결국 동일한 리트리버라는 사실은 변하지 않습니다.

Invariance - 출처: https://www.javatpoint.com/pytorch-data-augmentation-process

이 밖에도 색상을 바꾸거나 일부분을 자르고, 종횡비를 변형시키는 등의 Augmentation 작업이 포함되는데 CutMixSimCLR 같은 연구가 매우 유명합니다.

단점

하지만 이 방법은 High semantic representation을 만들 수 있지만 학습 데이터에 많이 편향되어서(biased) Downstream tasks에 대해 기대했던 만큼 좋은 성능을 거두지 못했습니다. 예를 들어 Image classification은 high semantic이 중요하다고 생각되지만 Image segmentation은 아무래도 경계를 잘 구분하는 일도 중요하기 때문에 low semantic 또한 많이 고려되야 하지 않을까요?

Generative Pretraining Method

그러면 Generative Model은 어떨까요? Masked Autoencoder(MAE)가 매우 유명한 방법 중 하나입니다. 이미지 곳곳을 마스크로 가려놓고 가려놓은 부분을 복원하는 방법은 다양한 Downstream tasks에서 훌륭한 성능을 보여줬습니다. 특히 Object detection, Semantic segmentation에서의 성능 향상이 인상적이었는데요,

단점

하지만 이 방법도 곰곰히 생각해보면 픽셀 수준에서의 복원을 목표로 하기 때문에 linear-probing과 같은 high semantic을 요하는 태스크에서 살짝 아쉬웠습니다. 또한 Masked Siamese Networks 논문에 의하면 적은 수의 데이터로 finetune 하려고 하면 overfitting이 발생하기 쉽다고 합니다. 아무래도 edge, color와 같은 low semantic을 학습하기 때문에 문제의 다양성이 떨어져서 쉽게 overfitting 되는걸까요?

그래서 I-JEPA의 장점은

위의 그림처럼 Invariance-based(a), Generative(b)의 장점을 합쳐 Joint-Embedding Predictive Architecture(JEPA)를 구성했더니

  • Strong off-the-shelf representation: 복잡한 Augmentation 없이도 특정 태스크에 편향되지 않는 아주 일반적인 representation을 얻을 수 있었다고 합니다.
  • Fast Training: 아무래도 Latent space를 학습을 loss function으로 하기 때문에 굉장히 빠르다고 합니다. iBOT보다 2.5배 빠르고, MAE보다 10배나 빠르게 학습 가능하다고 합니다.
한마디로 Image가 아니고 Latent space를 복원하는 학습을 하겠다

Methods

제가 보기엔 그리 특별한 아이디어는 아니지만 모름지기 논문은 논리와 결과로 승부하는 법입니다. 결과가 궁금하기 때문에 빠르게 방법을 설명하도록 하겠습니다.

 

I-JEPA의 구조에는 3가지 핵심이 존재합니다.

  • Targets: 어느 부분을 복원 학습할 것인가 (s가 붙은 이유는 그 부분이 여러 군데라서)
  • Context: 모델에게 사전정보로써 어느 부분을 보여줄 것인가
  • Prediction: 복원 학습을 하는 네트워크

Targets

Target은 이미지의 어느 부분을 복원하는 학습을 진행할 지 결정합니다.

위의 그림은 Image Patch를 25개 갖는 예시입니다 (위는 예시일 뿐, 일반적으로 ViT-H/14 이런 식으로 \(224\times224\)이미지의 가로세로를 14등분하여 총 \(16\times16=256\)개의 Patch를 갖습니다).

  1. 이미지를 \(N\)등분하여 \(N\)개의 Patch로 나눔
  2. Target Encoder(ViT)에 입력으로 넣어 Patch-level representation \(\mathbf{s}_{y}\)을 얻음
  3. 그리고 \(M\)개의 target blocks \(\mathit{B}_{i}\)를 임의로 생성

이렇게 되면 그림과 같이 \(M=4\)개의 블록이 있다고 할 때 각 블록이 포함하는 패치의 갯수나 위치를 표현할 수 있습니다.

1번 블록(빨간색)은 2, 3, 7, 8번째 패치를 포함하고, 2번 블록(주황색)은 9, 10번 패치를 포함하게 됩니다.

 

논문에서는 4개의 블록을 만들었으며 Target encoder \(\mathit{f}_{\bar{\theta}}\)는 직접 학습하지 않고 후술할 Context encoder의 Exponential Moving Average(EMA, 설명 블로그)를 활용합니다. 아래는 I-JEPA 코드 일부입니다.

with torch.no_grad():
    m = next(momentum_scheduler)
    for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()):
        param_k.data.mul_(m).add_((1.-m) * param_q.detach().data)

(참고로 EMA는 모델이 특정 mini-batch에 큰 영향을 받지 않도록 Moving Average를 적용하는 방식인데, 앞에 Exponential이 붙음으로 인해 최근 mini-batch에 많은 가중치를 주는 전략입니다)

Context

Context는 Target을 복원하기에 앞서 어떤 정보를 줄 것인가를 결정합니다. 방법은 아래와 같습니다

  1. 이미지에 대해 block \(x\)를 하나 설정하여 해당 block 이외의 영역을 마스킹
  2. Target blocks \(\mathit{B}_{i}\) 영역도 모두 마스킹
  3. 남은 부분을 Context encoder \(\mathit{f}_{\theta}\)에 입력으로 넣음
  4. Patch-level context representation \(\mathbf{s}_{x}\)를 얻음

위 그림의 맨 오른쪽을 보면, 검은색 마스킹을 제외한 구역이 context가 됩니다.

Prediction

이제 우리는 어디를 복원할 지(Target)와 어떤 사전정보를 줄 지(Context) 모두를 얻었습니다. 이제 Predictor를 통해서 학습을 수행하면 되는데요, 위의 그림은 1, 2번 target을 학습하는 과정을 각각 표현한 것이고 3, 4번(파랑, 초록) target 학습과정은 쉽게 상상하실 수 있을 것 같습니다.

그리고 학습 과정은 아래와 같습니다.

  1. 각 Target \(\mathbf{s}_{y}(i)\)에 대해 아래를 수행
    1. Context에다가 \(\mathbf{s}_{y}(i)\)에 해당하는 부분을 학습 가능한 mask token으로 하여 더함 \((\mathbf{s}_{x}, \left \{\mathbf{m}_{j}\right \}_{j\in B_{i}})\)
    2. 위에서 합쳐 만든 것을 Predictor \(g_{\phi}\)에 넣어 \(\mathbf{\hat{s}}_{y}(i)\)를 얻음
    3. Target에 해당 하는 패치만 이용해 모델을 학습. Latent vector 간의 \(L_{2}\) loss를 구한다.

Results

결과는 어떨까요? 

이 방법은 Introduction에서 소개했던 두가지 방법(Invariance-based, Generative)의 딱 중간지점에 있는 방법이라고 생각하는데, 결론은 아래 두가지 중 하나일 것으로 예상됩니다.

  • 대박: High semantic(Image classification) / Low semantic(Reconstruction) 모두 좋은 결과를 보여주거나
  • 쪽박: 양쪽 모두 이도 저도 아니게 되거나

Global Prediction Tasks - Linear Probing

I-JEPA Linear-evaluation(probing) 결과

아래는 I-JEPA의 기존 방법들 대비 갖는 장점과 제 개인적인 생각입니다.

  • Augmentation 사용하지 않은 집단: 그들과 비교했을 때 적은 Epochs으로도 높은 Top-1 성능을 보여줌
    • 하지만 CAE ViT-L/16과의 비교에서 600 Epochs 결과를 보여주지 않고 좀 더 많은 학습을 하더라도 78.1을 뛰어넘는 숫자를 표기했다면 어땠을까요?
  • Augmentation 사용한 집단: Augmentation을 사용하지 않고도 그들과 비견될 성능을 보여줌

결론적으로 High Semantic Learning이 잘 되었다는 것을 보여주었습니다.

Local Prediction Tasks - Object Counting / Depth Prediction

이 다음은 Pixel-level Understanding을 얼마나 잘하는지 살펴보겠습니다.

위는 CLEVR 데이터셋의 Object Counting(Count), Depth Prediction(Dist) 결과를 비교한 테이블입니다.

이번에는 Augmentation을 사용한 DINO, iBOT을 능가하는 성능을 보여준 것만 같습니다.

하지만 자세히 보면 두 결과는 I-JEPA가 사용한 ViT-H/14 보다 작은 모델(ViT-B/8, ViT-L/16)을 사용했기 때문에 아주 공정한 비교는 아닙니다. 왜 이런 결과를 보여줬는지 잘 모르겠습니다 (혹시 아는 분이 계시다면 댓글 부탁드립니다).

Less Training with Better Performance

어찌보면 적은 학습이 I-JEPA의 가장 큰 장점이 아닐까 싶습니다. 물론 위의 두 결과도 훌륭했지만 어느 면으로 보아도 이견이 없는 적은 학습시간은 이 논문이 갖는 중요한 장점입니다.

Visualization of I-JEPA Predictor

맨 왼쪽은 원본 이미지, 그 다음은 I-JEPA에 들어가는 context image, 그 다음의 4개는 각기 다른 Generative decoder 결과

위의 그림은 I-JEPA의 predictor에서 얻는 representation을 Generative Model의 Decoder에 입력으로 넣어 얻은 결과입니다. 자세히 보면 아래와 같은 결론을 얻을 수 있습니다.

  • Global Understanding: 위치나 포즈가 얼추 나오는 것으로 보아 큰 맥락에서 마스킹 된 부분이 어떻게 채워져야 할 지에 대한 단서를 잘 포함하고 있음
  • Local Understanding: 배경이나 색상의 자연스러움이 떨어지는 것으로 보아 아무래도 디테일을 제대로 포함하고 있는 것 같지는 않음

긴 글 읽어주셔서 감사합니다

반응형

댓글