📜 Paper/DeepLearning

Deep metric learning using Triplet network

Ju_pyter 2024. 2. 21. 10:58

논문 링크https://arxiv.org/abs/1412.6622

 

Triplet Network

Triplet Network에서는 facenet과 같이 3개의 데이터가 입력으로 주어짐.

X : anchor

X- : anchor와 일치하지 않는 입력

X+ : anchor와 같은 클래스에 해당하는 입력 값

이 세개의 인스턴스는 각각 동일한 weight를 공유하고 있는 feed forward 네트워크(샴 네트워크)에 들어가게 된다.

 

여기서 잠깐, Siamese Network(샴 네트워크)에서 “샴’은 샴쌍둥이에서 유래된 것으로, 즉, weight를 공유하는 두 네트워크로 이루어진 네트워크를 의미한다.

왼쪽 사진처럼 Weight를 공유하는 두개의 네크워크를 샴 네트워크라고 하는데 어차피 weight를 공유하기 때문에, 한 네트워크라고 봐도 무방함.

 

다시 Triplet Network로 넘어와서..

X를 기준으로 X와 일치하는 클래스를 가진 X+와 일치하지 않는 클래스를 가진 X-가 각각의 네트워크를 통과하면 하나의 임베딩된 중간 결과가 나오는데, 각각은 anchor에 해당하는 X와의 L2 distance를 의미한다. 즉, 위 수식에서 각각의 값들은 X와의 L2 distance에 해당한다. (위에 있는 값이 X와 X-의 L2 distance, 아래가 X+와의 L2 distance를 의미하며, 그 distance들을 가진 [[2], [5]]와 같은 형식의 벡터 값이 나옴)

 

Training

각각의 임베딩 값은 Softmax function을 거치게 됨.

기존의 CNN과 비슷하게 학습은 간단한 이진 분류의 negative-log-likelihood loss를 사용하였으며, 이를 최소하기 위한 기법으로는 SGD를 사용.

그러나, 추후 softmax function 결과값에 대한 Loss함수를 MSE로 대체하면서 더 좋은 결과를 얻게됨.

이를 바탕으로 아래 수식을 살펴보면..

d+, d-값은 L2 distance에 대해서 softmax function을 취한 값을 의미하게 되고, Loss는 이 둘을 MSE로 구한 값이 된다.

 

Results

정리하자면 해당 논문에서는 Siamese network와 같이 Euclidean distance 방법을 이용하여 유사도를 측정함.

하지만 두 개의 데이터를 비교해 얻은 contrastive loss로 학습하는 Siamese network와 달리, 주어진 데이터를 같은 class에 속하는 데이터, 다른 class에 속하는 데이터와 함께 세 쌍(triplet)으로 묶어 학습했다는 점에서 차이를 보임

그 결과 Mnist 데이터셋에서 Siamese Network보다 1.5%정도 더 높은 성능을 보였을 뿐만 아니라 나머지 3개의 데이터셋에서 유의미한 결과를 얻지 못한 Siamese Network와 달리 TripletNet은 각각 87%, 95%, 70%라는 준수한 수치를 보임.