本文轉(zhuǎn)自徐飛翔的“一文理解Ranking Loss/Contrastive Loss/Margin Loss/Triplet Loss/Hinge Loss”
版權(quán)聲明:本文為博主原創(chuàng)文章,遵循 CC 4.0 BY-SA 版權(quán)協(xié)議,轉(zhuǎn)載請附上原文出處鏈接和本聲明。
ranking loss函數(shù):度量學習
不像其他損失函數(shù),比如交叉熵損失和均方差損失函數(shù),這些損失的設(shè)計目的就是學習如何去直接地預(yù)測標簽,或者回歸出一個值,又或者是在給定輸入的情況下預(yù)測出一組值,這是在傳統(tǒng)的分類任務(wù)和回歸任務(wù)中常用的。ranking loss的目的是去預(yù)測輸入樣本之間的相對距離。這個任務(wù)經(jīng)常也被稱之為度量學習(metric learning)。
在訓練集上使用ranking loss函數(shù)是非常靈活的,我們只需要一個可以衡量數(shù)據(jù)點之間的相似度度量就可以使用這個損失函數(shù)了。這個度量可以是二值的(相似/不相似)。比如,在一個人臉驗證數(shù)據(jù)集上,我們可以度量某個兩張臉是否屬于同一個人(相似)或者不屬于同一個人(不相似)。通過使用ranking loss函數(shù),我們可以訓練一個CNN網(wǎng)絡(luò)去對這兩張臉是否屬于同一個人進行推斷。(當然,這個度量也可以是連續(xù)的,比如余弦相似度。)
在使用ranking loss的過程中,我們首先從這兩張(或者三張,見下文)輸入數(shù)據(jù)中提取出特征,并且得到其各自的嵌入表達(embedded representation,譯者:見[1]中關(guān)于數(shù)據(jù)嵌入的理解)。然后,我們定義一個距離度量函數(shù)用以度量這些表達之間的相似度,比如說歐式距離。最終,我們訓練這個特征提取器,以對于特定的樣本對(sample pair)產(chǎn)生特定的相似度度量。
盡管我們并不需要關(guān)心這些表達的具體值是多少,只需要關(guān)心樣本之間的距離是否足夠接近或者足夠遠離,但是這種訓練方法已經(jīng)被證明是可以在不同的任務(wù)中都產(chǎn)生出足夠強大的表征的。
ranking loss的表達式
正如我們一開始所說的,ranking loss有著很多不同的別名,但是他們的表達式卻是在眾多設(shè)置或者場景中都是相同的并且是簡單的。我們主要針對以下兩種不同的設(shè)置,進行兩種類型的ranking loss的辨析
- 使用一對的訓練數(shù)據(jù)點(即是兩個一組)
- 使用三元組的訓練數(shù)據(jù)點(即是三個數(shù)據(jù)點一組)
這兩種設(shè)置都是在訓練數(shù)據(jù)樣本中進行距離度量比較。
成對樣本的ranking loss
在這個設(shè)置中,由訓練樣本中采樣到的正樣本和負樣本組成的兩種樣本對作為訓練輸入使用。正樣本對(?,
?)由兩部分組成,一個錨點樣本
?和 另一個和之標簽相同的樣本
,這個樣本
與錨點樣本在我們需要評價的度量指標上應(yīng)該是相似的(經(jīng)常體現(xiàn)在標簽一樣);負樣本對
由一個錨點樣本
?和一個標簽不同的樣本
組成,
?在度量上應(yīng)該和
不同。(體現(xiàn)在標簽不一致)
現(xiàn)在,我們的目標就是學習出一個特征表征,這個表征使得正樣本對中的度量距離盡可能的小,而在負樣本對中,這個距離應(yīng)該要大于一個人為設(shè)定的超參數(shù)——閾值
。成對樣本的ranking loss強制樣本的表征在正樣本對中擁有趨向于0的度量距離,而在負樣本對中,這個距離則至少大于一個閾值。用
分別表示這些樣本的特征表征,我們可以有以下的式子:
對于正樣本對來說,這個loss隨著樣本對輸入到網(wǎng)絡(luò)生成的表征之間的距離的減小而減少,增大而增大,直至變成0為止。
對于負樣本來說,這個loss只有在所有負樣本對的元素之間的表征的距離都大于閾值 的時候才能變成0。當實際負樣本對的距離小于閾值的時候,這個loss就是個正值,因此網(wǎng)絡(luò)的參數(shù)能夠繼續(xù)更新優(yōu)化,以便產(chǎn)生更適合的表征。這個項的loss最大值不會超過
,在
的時候取得。這里設(shè)置閾值的目的是,當某個負樣本對中的表征足夠好,體現(xiàn)在其距離足夠遠的時候,就沒有必要在該負樣本對中浪費時間去增大這個距離了,因此進一步的訓練將會關(guān)注在其他更加難分別的樣本對中。
假設(shè)用分別表示樣本對兩個元素的表征,
是一個二值的數(shù)值,在輸入的是負樣本對時為0,正樣本對時為1,距離
是歐式距離,我們就能有最終的loss函數(shù)表達式:
三元組樣本對的ranking loss
三元組樣本對的ranking loss稱之為triplet loss。在這個設(shè)置中,與二元組不同的是,輸入樣本對是一個從訓練集中采樣得到的三元組。這個三元組 由一個錨點樣本
?,一個正樣本
?,一個負樣本
組成。其目標是錨點樣本與負樣本之間的距離
與錨點樣本和正樣本之間的距離
之差大于一個閾值
,可以表示為:
在訓練過程中,對于一個可能的三元組,我們的triplet loss可能有三種情況:
- “簡單樣本”的三元組(easy triplet):
。在這種情況中,在嵌入空間(譯者:指的是以嵌入特征作為表征的歐幾里德空間,空間的每個基底都是一個特征維,一般是賦范空間)中,對比起正樣本來說,負樣本和錨點樣本已經(jīng)有足夠的距離了(即是大于
)。此時loss為0,網(wǎng)絡(luò)參數(shù)將不會繼續(xù)更新。
- “難樣本”的三元組(hard triplet):
。在這種情況中,負樣本比起正樣本,更接近錨點樣本,此時loss為正值(并且比
大),網(wǎng)絡(luò)可以繼續(xù)更新。
- “半難樣本”的三元組(semi-hard triplet):
。在這種情況下,負樣本到錨點樣本的距離比起正樣本來說,雖然是大于后者,但是并沒有大于設(shè)定的閾值
,此時loss仍然為正值,但是小于
,此時網(wǎng)絡(luò)可以繼續(xù)更新。
負樣本的挑選
在訓練中使用Triplet loss的一個重要選擇就是我們需要對負樣本進行挑選,稱之為負樣本選擇(negative selection)或者三元組采集(triplet mining)。選擇的策略會對訓練效率和最終性能結(jié)果有著重要的影響。一個明顯的策略就是:簡單的三元組應(yīng)該盡可能被避免采樣到,因為其loss為0,對優(yōu)化并沒有任何幫助。
第一個可供選擇的策略是離線三元組采集(offline triplet mining),這意味著在訓練的一開始或者是在每個世代(epoch)之前,就得對每個三元組進行定義(也即是采樣)。第二種策略是在線三元組采集(online triplet mining),這種方案意味著在訓練中的每個批次(batch)中,都得對三元組進行動態(tài)地采樣,這種方法經(jīng)常具有更高的效率和更好的表現(xiàn)。
然而,最佳的負樣本采集方案是高度依賴于任務(wù)特性的。但是在本篇文中不會在此深入討論,因為本文只是對ranking loss的不同別名的綜述并且討論而已。可以參考[2]以對負樣本采樣進行更深的了解。
ranking loss的別名們~名兒可真多啊
ranking loss家族正如以上介紹的,在不同的應(yīng)用中都有廣泛應(yīng)用,然而其表達式都是大同小異的,雖然他們在不同的工作中名兒并不一致,這可真讓人頭疼。在這里,我嘗試對為什么采用不同的別名,進行解釋:
- ranking loss:這個名字來自于信息檢索領(lǐng)域,在這個應(yīng)用中,我們期望訓練一個模型對項目(items)進行特定的排序。比如文件檢索中,對某個檢索項目的排序等。
- Margin loss:這個名字來自于一個事實——我們介紹的這些loss都使用了邊界去比較衡量樣本之間的嵌入表征距離,見Fig 2.3
- Contrastive loss:我們介紹的loss都是在計算類別不同的兩個(或者多個)數(shù)據(jù)點的特征嵌入表征。這個名字經(jīng)常在成對樣本的ranking loss中使用。但是我從沒有在以三元組為基礎(chǔ)的工作中使用這個術(shù)語去進行表達。
- Triplet loss:這個是在三元組采樣被使用的時候,經(jīng)常被使用的名字。
- Hinge loss:也被稱之為max-margin objective。通常在分類任務(wù)中訓練SVM的時候使用。他有著和SVM目標相似的表達式和目的:都是一直優(yōu)化直到到達預(yù)定的邊界為止。
Siamese 網(wǎng)絡(luò)和 Triplet網(wǎng)絡(luò)
Siamese網(wǎng)絡(luò)(Siamese Net)和Triplet網(wǎng)絡(luò)(Triplet Net)分別是在成對樣本和三元組樣本 ranking loss采用的情況下訓練模型。
Siamese網(wǎng)絡(luò)
這個網(wǎng)絡(luò)由兩個相同并且共享參數(shù)的CNN網(wǎng)絡(luò)(兩個網(wǎng)絡(luò)都有相同的參數(shù))組成。這些網(wǎng)絡(luò)中的每一個都處理著一個圖像并且產(chǎn)生對于的特征表達。這兩個表達之間會進行比較,并且計算他們之間的距離。然后,一個成對樣本的ranking loss將會作為損失函數(shù)進行訓練模型。
我們用 表示這個CNN網(wǎng)絡(luò),我們有Siamese網(wǎng)絡(luò)的損失函數(shù)如:
Triplet網(wǎng)絡(luò)
這個基本上和Siamese網(wǎng)絡(luò)的思想相似,但是損失函數(shù)采用了Triplet loss,因此整個網(wǎng)絡(luò)有三個分支,每個分支都是一個相同的,并且共享參數(shù)的CNN網(wǎng)絡(luò)。同樣的,我們能有Triplet網(wǎng)絡(luò)的損失函數(shù)表達為:
在多模態(tài)檢索中使用ranking loss
根據(jù)我的研究,在涉及到圖片和文本的多模態(tài)檢索任務(wù)中,使用了Triplet ranking loss。訓練數(shù)據(jù)由若干有著相應(yīng)文本標注的圖片組成。任務(wù)目的是學習出一個特征空間,模型嘗試將圖片特征和相對應(yīng)的文本特征都嵌入到這個特征空間中,使得可以將彼此的特征用于在跨模態(tài)檢索任務(wù)中(舉個例子,檢索任務(wù)可以是給定了圖片,去檢索出相對應(yīng)的文字描述,那么既然在這個特征空間里面文本和圖片的特征都是相近的,體現(xiàn)在距離近上,那么就可以直接將圖片特征作為文本特征啦~當然實際情況沒有那么簡單)。為了實現(xiàn)這個,我們首先從孤立的文本語料庫中,學習到文本嵌入信息(word embeddings),可以使用如同Word2Vec或者GloVe之類的算法實現(xiàn)。隨后,我們針對性地訓練一個CNN網(wǎng)絡(luò),用于在與文本信息的同一個特征空間中,嵌入圖片特征信息。
具體來說,實現(xiàn)這個的第一種方法可以是:使用交叉熵損失,訓練一個CNN去直接從圖片中預(yù)測其對應(yīng)的文本嵌入。結(jié)果還不錯,但是使用Triplet ranking loss能有更好的結(jié)果。
使用Triplet ranking loss的設(shè)置如下:我們使用已經(jīng)學習好了文本嵌入(比如GloVe模型),我們只是需要學習出圖片表達。因此錨點樣本是一個圖片,正樣本
是一個圖片對應(yīng)的文本嵌入,負樣本
是其他無關(guān)圖片樣本的對應(yīng)的文本嵌入。為了選擇文本嵌入的負樣本,我們探索了不同的在線負樣本采集策略。在多模態(tài)檢索這個問題上使用三元組樣本采集而不是成對樣本采集,顯得更加合乎情理,因為我們可以不建立顯式的類別區(qū)分(比如沒有l(wèi)abel信息)就可以達到目的。在給定了不同的圖片后,我們可能會有需要簡單三元組樣本,但是我們必須留意與難樣本的采樣,因為采集到的難負樣本有可能對于當前的錨點樣本,也是成立的(雖然標簽的確不同,但是可能很相似。)
在該實驗設(shè)置中,我們只訓練了圖像特征表征。對于第個圖片樣本,我們用
表示這個CNN網(wǎng)絡(luò)提取出的圖像表征,然后用
分別表示正文本樣本和負文本樣本的GloVe嵌入特征表達,我們有:
在這種實驗設(shè)置下,我們對比了triplet ranking loss和交叉熵損失的一些實驗的量化結(jié)果。我不打算在此對實驗細節(jié)寫過多的筆墨,其實驗細節(jié)設(shè)置和[4,5]一樣。基本來說,我們對文本輸入進行了一定的查詢,輸出是對應(yīng)的圖像。當我們在社交網(wǎng)絡(luò)數(shù)據(jù)上進行半監(jiān)督學習的時候,我們對通過文本檢索得到的圖片進行某種形式的評估。采用了Triplet ranking loss的結(jié)果遠比采用交叉熵損失的結(jié)果好。
深度學習框架中的ranking loss層
Caffe
- Constrastive loss layer
- pycaffe triplet ranking loss layer
PyTorch
- CosineEmbeddingLoss
- MarginRankingLoss
- TripletMarginLoss
TensorFlow
- contrastive_loss
- triplet_semihard_loss
Reference
[1]. https://blog.csdn.net/LoseInVain/article/details/88373506
[2]. https://omoindrot.github.io/triplet-loss
[3]. https://github.com/adambielski/siamese-triplet
[4]. https://arxiv.org/abs/1901.02004
[5]. https://gombru.github.io/2018/08/01/learning_from_web_data/