WSI影像分類模型介紹

介紹 5 minutes

隨著AI技術的進步,越來越多的領域結合AI以達成更加精準、有效率的研究分析效果,其中一門新興的領域便是結合了AI運算和病理學的計算病理學(computational pathology)。

在病理學中,大量的數據如基因數據、臨床報告、組織切片圖像等能利用機器學習的方式來提取其中對未來研究有價值的相關訊息。而當醫療專家解讀AI處理過的數據時,將能更輕易且明確的解讀醫學數據。有AI加入到病理學研究中協助處理龐大的資料量,將能幫助醫療團隊發現一些可能不易察覺的模式,從而使診斷過程更快速且更準確。

在此將介紹一個用於分析病理影像的AI模型,clustering-constrained-attention multiple-instance learning(CLAM),此模型能從全幅切片影像(Whole slide image, WSI)中自動辨識出具有重要病理特徵的區域,並能夠將對象 WSI 進行自動化的分類。以腫瘤影像來說,CLAM可以簡單的分辨出正常和病變組織,並能夠告訴你模型判斷的依據,這樣的功能除了能夠得到分類的結果以外,還可以進一步的分析判斷依據是否合理,是透過已知的知識?或是根據潛在的病理特徵?

過程中醫師端只需提供WSI檔案(.svs, ,ndpi, .tiff…)或是活體組織檢驗(Biopsies)甚至是顯微影像及其對應的標註即可,而CLAM在不同的病人群體上依然能達到很高的準確度,對於醫師、醫療院所進行研究或臨床試驗也很有幫助。

CLAM

CLAM,全稱clustering-constrained-attention multiple-instance learning。在了解各個步驟的細節之前,可以先從名稱來知道整個模型的概念:

Clustering-constrained,條件性叢集。叢集模型在機器學習中很常見,是一種能自動將輸入的資料進行分類的模型,比如我們將各式各樣的WSI影像輸入模型,模型可能根據了組織的形態差異將其分為一般組織和腫瘤組織,也可能根據染劑色調而將不同染色實驗數據分類。而條件性的叢集模型則能讓我們透過人工標註讓機器學習特定的類別,當應用於病理影像時便能讓醫師團隊根據想研究的特徵自行進行標註病變特徵、組織結構或細胞活動等。

Attention,注意力模型。就像人類在看圖像時,大腦會自動集中注意力在重要的地方,注意力模型可以自動分析輸入影像,並在訓練時著重加強「關鍵」區域。這種模型能夠標註影像中最重要的部分,特別適用於WSI影像的訓練。由於WSI影像中組織的分佈不均且結構複雜,注意力模型可以幫助模型聚焦於真正關鍵的組織或細胞,提升影像分析的準確性。

Multiple-instance learning (MIL),多實例學習。MIL 是一種常用於病理影響分析的技術,由於全幅切片影像(WSI)的高解析度和巨大尺寸,直接使用整個影像進行訓練非常困難。因此在訓練模型時,通常會將WSI切分成大量的小區塊再輸入模型中,而將這些切片中的各個細胞單獨人工標記是項浩大的工程。透過MIL,醫師將不需要針對單一細胞進行標記,可以直接給予整個WSI標記即可開始訓練,這將大幅的減少醫師端所需要付出的時間及勞力。除了省時,MIL對於WSI中特徵組織的不確定性也有所幫助,由於同一張WSI的所有區塊共享相同的標記,這能夠提升對非典型組織的判斷準確度。而不同於一般只能做二分法的MIL模型,CLAM則開發出了多類別分類的MIL,這使得模型的用更加的廣泛。

模型運作流程

CLAM的整體運作流程如下圖: download

如上圖所示,在進入到訓練流程之前,最重要的其實是對影像的處理,尤其是針對高解析度和複雜度的WSI影像。

CLAM 模型的第一步是對組織切片進行處理,去除背景保留組織,並依據模型的需求將WSI影像切分成256*256像素的小區塊。接著這些切片輸入卷積神經網路(CNN)中。CNN是一種常用於影像辨識的神經網路模型,在讀取影像時,CNN中的卷積核會以矩陣的形式小塊小塊的掃過整張影像並讀取其中的訊息,且根據型態、顏色等變化保留模型認為最重要的部分,也就是所謂的特徵。透過卷積核多次的提取特徵和對特徵矩陣的降維,我們可以把資訊量龐大的原始切片,轉換為較為低維的特徵向量。

在完成前置處理後,模型會開始訓練。如上述,每一張WSI已在前置作業中被切為k張切片,而每一張切片都有從CNN裡萃取出來的特徵向量zk。在每一回合訓練時,會隨機選取部分的切片作為訓練資料。在CLAM的第一層W1中,首先將zk降維至512維的向量hk,以連結後續的注意力網路。

在CLAM中,注意力網路由多個全連結層所構成,全連結層(Fully connected layer)是一種將高維輸入映射至低維輸出的神經網路層,其中輸入層的神經元會與輸出層的所有神經元相連結,因此可以想像,當輸出層的神經元數量小於輸入層時,輸入的特徵訊息即可被重新分類和整合至更低維的向量中。前述的CNN和zk→hk的過程皆由全連結層來達成。CLAM的注意力網路中,注意力網路會根據預期分類數目N將網路分為N條並行分支Wa,1, …, Wa, N。在這N 條分支中將有N個獨立的分類器Wc, 1…, Wc, N,這些分類器會根據前面全連結層的輸入計算出概率最高的類別,並給出分類結果。而在過程中,模型也會根據全連結層輸出的特徵跟分類的關鍵特徵進行比對並給出第k切片在i類別底下的注意力分數ai,k。ai, k的分數越高表示k切片中的特徵對於判斷類別的重要性越高。最後,在i類別下根據每張WSI所有ai, k的分佈預測出最終的類別。

CLAM 模型還具備強化學習功能,能處理分類表現差異較大的資料。對於每個類別i下擁有最高和最低注意力分數的切片會被用於額外的二元分類模型訓練,讓模型強化判斷屬於和不屬於i類別的切片。

為了評估模型的表現,CLAM會先將資料集分為驗證集和訓練集,訓練集用於模型的訓練過程,驗證集則用於衡量模型的表現。在預測驗證集資料時,CLAM利用損失函數來量化預測和真實資料的差距,並以此評估訓練成效。CLAM使用了 multi-class SVM loss (1) 作為損失指標, 是真實類別的預測分數, 則是其他所有類別分數的總和,式子 (1) 可以看到在CLAM中multi-class SVM loss 會讓 ->常數 時損失等於零,表示模型的預測成效很好,相反的當和的分數相差得不夠大時,則表示模型還可以改進。

由於傳統的SVM loss中在分類邊界的表現並不穩定,為了改進,CLAM在SVM之上又增加了溫度縮放參數,使函數趨於平滑和連續(稱為smooth SVM loss),能幫助優化神經網路。

在CLAM中,可以自定義訓練次數(epoch,訓練→驗證→計算損失和準確率的回數),在使用CLAM時,多會執行50~200次。而在每一個epoch間,模型會根據學習率(learning rate)來調整參數的變動幅度。CLAM使用了一種優化器稱為Adam optimizer來自動更新epoch間的學習率。Adam optimizer會根據歷史梯度自行調整學習率,這能讓擁有大數據的模型快速的收斂,對於資料量大的病理影像分析來說非常合適。

總結來說,整個CLAM的訓練流程是:將WSI影像切片→ 提取各切片特徵→ 進行注意力網路訓練→ 驗證訓練成果→挑出最好的參數。而在整個過程中,使用者只需提供完善的資料和對應的標註即可以得到高準確率的分類結果。

Next Post