薬剤師のプログラミング学習日記

プログラミングやコンピュータに関する記事を書いていきます

ディープラーニングで錠剤・カプセル識別システムを自作する(3)ーDMLモデルの学習と評価ー

前回記事【ディープラーニングで錠剤・カプセル識別システムを自作する(2)ー抽出・切り出しー】からの続きです。今回はデータセットの作成~DMLモデルの学習・評価について書いていきたいと思います。

 
 

深層距離学習(Deep Metric Learning)とは

DMLはデータ(今回は画像)間での似ている・似ていないという類似度を距離として、多層ニューラルネットワークを使って学習する手法です。
CNNを特徴抽出器として使う点は通常の画像分類と同じですが、最終的な出力がクラス確率ではなく、ベクトル(埋め込み)であるという点が異なります。
同じ種類の錠剤の画像は埋め込み空間で近くになるよう、異なる種類は遠くに位置するよう学習します。

学習後は全種類の錠剤について、代表となるベクトル(各クラスの平均埋め込み)をDBとして保持しておき、新たに画像を見せたときにDBの中で最も距離の近いクラスを答えとして返します。
 
  

なぜ画像分類ではなく距離学習を選んだか

錠剤・カプセルの識別には、一般的なCNNによる画像分類(「これは何のクラスか」を推論する)でクラスを出力することもできます。しかし、距離学習を使う理由としては大きく2つあります。

  • 1つ目は新しい薬が追加されたときの対応です。画像分類では対応できるクラスはモデルを学習した時点で固定されるので、新しい薬を追加するたびにモデルを再学習する必要が出てきます。DMLでは学習済みモデルの埋め込みデータベース(DB)に新しい薬の画像を登録するだけで対応できます。
  • 2つ目は見たことのない薬に対して「知らない」と言えることです。画像分類では必ずいずれかのクラスとして答えを出します。*1学習したことのない薬を見せても、最も似たクラスの名前を返してしまいます。一方、DMLでは最近傍のクラスまでの距離が閾値を超えた場合に「一致するものなし」と返すことができ、未知の入力に対してより安全な挙動が期待できます。

CNNによるクラス分類も実際に試してみたので、後の記事で書きたいと思います。

データセット概要

最終的に132種類【円形106, 楕円形26種類】、各30画像、計3960枚のデータセットを作成しました。最初から一気に集めたのではなく、学習結果を確認しながら50種類→80種類→132種類と段階的に種類を増やしていきました。
ただ、1つの種類につき、ほとんどは同一個体品を使いました。医薬品を工業製品としてみると、個体間での誤差はかなり小さいはずですが、本来であればここは別個体を複数使うべきだと思います。
 
撮影手順は以下の流れで行いました。

  1. シャーレに適当に20種類ほどの錠剤・カプセルを載せる
  2. ボックスで撮影後、その都度ピンセットで適当にかき混ぜて位置や角度を変える
  3. これを10回繰り返したら、次の20種類に入れ替える
  4. 全種類一巡したら、組み合わせを変えるために全体をシャッフルして、1に戻る

これを3セット繰り返すことで、1つの薬剤につき異なる位置や角度での撮影がされて、光の当たり方や隣に置かれる錠剤の組み合わせが毎回変わるよう工夫しました。

切り出し処理まで完了した画像は、薬品名をフォルダ名にしたディレクトリに振り分けていきます。
全部で132フォルダ×30枚ですが、種類が多いと手作業の振り分けはかなり大変なので、事前に薬品名・形状・識別コードをまとめたCSVファイルを作っておくと整理が多少楽になりました。また、このCSVは後ほどDBを作るときにも使用します。
ここではフォルダ作成とファイル名変換のため、それ用のコードを作成しました。
 
 

モデルのアーキテクチャと損失関数

実装はPyTorchで行い、学習に用いるGPUは自分のPCのRTX3070を使いました。
入力画像から特徴を抽出するためのバックボーンにはResNet18を使用しました。ImageNetで事前学習済みの重みをそのまま使い、最終層だけ差し替えています。
元々の1000クラス分類用の全結合層を、以下のEmbedding Headに置き換えることで埋め込みのサイズが128で出力されるようにしました。

model.fc = nn.Sequential(
        nn.Linear(512, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Linear(256, 128)
)
L2正規化

この出力の直後にL2正規化(Normalization)を適用することで、出力ベクトルの長さを1に揃えます(単位球面上への射影)。正規化後のベクトルは半径1の球面上に乗っているため、2点間の距離の範囲は0~2.0になり、距離の値が持つ意味としては次のようになります。

  • 全く同じ方向を向いているとき(同一画像):0
  • 真逆を向いているとき:2.0
  • 直交しているとき(全く無関係な方向):√2≈1.414
# p=2: L2ノルム(ユークリッド距離)で正規化
# dim=1: バッチ内の各ベクトル(128次元)単位で正規化
embeddings = F.normalize(model(images), p=2, dim=1)

 

Triplet Loss

損失関数にはTripletMarginLoss(margin=0.2)を使用しました。
モデルが正しく特徴を捉えられていれば、ある画像(Anchor)に対して、正例(Positive)を与えたときは同じようなベクトルを、負例(Negative)を与えたときは大きく違ったベクトルを出力するため、それぞれの画像ペアの距離dを比較したときは下図のようになります。

しかし、このままではわずかでも距離に差があれば式を満たしてしまい、ニューラルネットワークの学習が進みません。少なくとも一定の距離はNegativeを離しておくという制約を与えるために、ここにパラメータmarginを加えます。

 \mathrm{d(A,P) + margin < d(A,N)}

このmarginがモデルにノイズや個体差に対するある程度のロバスト性(頑健性)を与えており、最終的にTriplet損失関数の定義は以下のようになります。

 \mathrm{L(A, P,N) = max(d(A,P) - d(A,N) + margin, 0)}

バッチ内の全ペアに重み付きで勾配を分散させるMultiSimilarityLossも試しましたが、ハイパーパラメータ調整が難しく、結果がTripletMarginLossよりむしろ悪くなってしまったため、シンプルなTripletMarginLossを採用しました。

# LossとMiner
miner = miners.MultiSimilarityMiner(epsilon=0.1)
loss_func = losses.TripletMarginLoss(margin=0.2)

 
MinerにはMultiSimilarityMinerを使用しています。Minerはバッチ内から学習に有益なペア/トリプレットを選び出して損失の計算に使うことで、学習効率を上げる役割を担います。TripletMarginMinerよりも多角的な難易度評価ができるとのことで、見た目の似た錠剤が多い今回のデータには向いていると判断して使ってみました。一応TripletMarginMinerも使ってみましたが、結果的には精度は誤差レベルの違いとなりました。
 
 

データ拡張

学習時には以下のデータ拡張を適用しています。

  • RandomRotation(±180°):シャーレ上の錠剤はランダムな向きで置かれるため、任意の角度に対応させます。
  • ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3):撮影ごとの照明のわずかな差異や、隣の錠剤やシャーレ縁による光の反射、錠剤の経年変化による色の差に対応させます。
  • RandomCrop(256→224px):切り出し画像の256×256からランダムに224×224を切り出します。位置ずれや余白の扱いでの耐性を持たせるため行います。
# データ拡張
train_transform = transforms.Compose([
    transforms.RandomCrop(224),         # ランダムに224x224を切り出し
    transforms.RandomRotation(180),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 事前学習済モデルなのでそのまま
])

Flip系の拡張(左右・上下反転)は行いません。錠剤・カプセルには数字、アルファベット、メーカーロゴが刻印または印字されているものがほとんどで、このような場合は反転させると現実には存在しないものになってしまうためです。

なお、楕円形薬剤は切り出し時に水平方向に正規化しているため本来はRandomRotationが不要ですが、楕円形・円形を分けずに同一の拡張を適用しました。この点については検証の結果、精度への影響が小さかったため今回は統一した設定としています。(前回記事【楕円形薬剤の回転について】参照)
 
 

学習設定と交差検証

学習設定はBATCH_SIZE=64、EPOCHS=50、optimizer=Adam(lr=0.0001)、学習率のschedulerはCosineAnnealingLRを使用しました。バッチサイズは16, 32, 64と試していったところで、64で安定した結果が得られたため採用しました。EPOCHSも最初は30で学習していましたが、学習曲線を確認したところまだlossが下降中だったため、収束が確認できた50まで延ばしました。
 
汎化性能の確認には5分割の層化交差検証(StratifiedKFold)を使用して、各Foldでクラスあたりの枚数がtrain/testで均等になるよう分割しています。1クラスあたり30枚しかないので1回あたりのtestは6枚/クラスと少ないですが、5Fold間のばらつきが小さければ汎化性能が安定していると判断しました。

# 設定
BATCH_SIZE = 64
EPOCHS     = 50
K_FOLDS    = 5

# --- transform・DataLoader定義省略 ---

for fold, (train_ids, val_ids) in enumerate(skf.split(full_dataset, labels)):
    
    # --- DataLoader生成・モデル初期化省略 ---
    
    miner     = miners.MultiSimilarityMiner(epsilon=0.1)
    loss_func = losses.TripletMarginLoss(margin=0.2)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    for epoch in range(EPOCHS):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            
            embeddings = F.normalize(model(images), p=2, dim=1)
            hard_pairs = miner(embeddings, labels)
            loss       = loss_func(embeddings, labels, hard_pairs)
            loss.backward()
            optimizer.step()
        # -- 進捗表示省略 ---
        scheduler.step()

    # --- 評価フェーズ省略 ---
    
    print(f"Fold {fold+1} Recall@1: {acc:.2f}%  "
          f"Recall@2: {recall_k[2]/n_val*100:.2f}%  "
          f"Recall@5: {recall_k[5]/n_val*100:.2f}%")

print(f"Mean Recall@1: {np.mean(results):.2f}% (+/- {np.std(results):.2f})")

 
 

学習結果

学習曲線(train lossの推移)です。私の環境では終了まで50分くらいかかりました。

ややギザギザが大きく見えますがこれはminerが毎エポック難しいペアを掘り起こしてくるためDMLでは通常こうなるとのことでした。
Recall@KがDMLでの画像検索における主な評価指標であり、指定したKの中に正解が入る確率を表しています。

  • Recall@1: 最も近い1つが正解(同じクラス)である確率
  • Recall@K(例: 2, 5): 近い順に上位K個を見た時に、その中に1つでも正解クラスが含まれる確率

Recall@2は、ほぼ全Foldで100%でした。
以下は@1のmean, stdと距離についての結果です。

Mean Recall@1: 99.95% (+/- 0.10)
Avg Intra-Class Dist (Same): 0.225  # クラス内距離
Avg Inter-Class Dist (Diff): 1.392    # クラス間距離
Min Inter-Class Dist (Diff): 0.498    # クラス間最小値

上で見たようにL2正規化済み埋め込みの2点間の距離範囲は0〜2.0ですが、結果ではクラス内平均は0.225と充分に小さい値でした。
またクラス間平均は1.392と1.414に近く、128次元の埋め込み空間上でお互いがほぼ直交している(均等に分散している)状態が得られました。
 
参考値として、クラス間距離の最小値(5Fold平均)も出力しています。
クラス内距離が、クラス間距離(別クラスとの近さ)の最小値と重なり始めると距離空間上で正解と不正解の境界が曖昧になってしまい、誤分類のリスクが高まります。
今回の結果では、クラス内距離の平均(Avg Same: 0.225)に対して、クラス間距離の最小値(Min Diff: 0.498)が2倍以上のマージンがありました。データセットが小規模(30枚/クラス)なため様々なバリエーションを網羅できているとは言えませんが、現状のデータセット内においては各クラスの分離は充分できていると評価しました。

t-SNEで128次元の埋め込みを2次元に圧縮して可視化すると、ほぼすべてのクラスが独立した塊を形成していることが確認できました。

他には、間違えやすいペアの例も出力させてみました。ただ、Recall@1が99.9%を超えており、5Fold全体での誤分類は数件程度にとどまったため、(何回か回すとある程度よく見る名前はあったものの)特定の薬が間違いやすいという明確な傾向は見られませんでした。
 
今回は132種類、30枚/クラスという規模での結果でしたが、実用レベルでは対象の種類・データ量ともに桁違いになります。もし実用レベルだったらおそらく数千種類は必要になるでしょうし、画像収集も1種類につき100枚単位で必要だったのではないかと思います。(ちなみに富士フィルムの製品ページでは”数万枚以上の監査画像を学習したAIモデル”との記述がありました)
 
 

閾値の設定

識別システムとして、未知の薬剤画像に対して「わからない」と答えさせるための閾値を決めます。また、最近傍クラスまでの距離に基づく信頼度を<高・中・低(要確認)・一致なし(未特定)>の4段階判定とするため、最近傍クラスまでの距離分布を確認した上で閾値を設定することにしました。
簡潔さを優先して<高・低(要確認)・一致なし>の3段階にするか迷いましたが、どんな薬がうまく分離できていないか確認したかったので刻むことにし、システムの判定結果部分は距離に応じて以下のように表示することにしました。

 
正解・不正解の距離分布を確認したところ、正解は距離0.1〜0.4に集中しており、不正解サンプルは出ていても全体の約1%未満と非常に少ない結果でした。上でも書いたように、これだと不正解のデータが少なすぎて統計的に閾値を決めるのは難しいため*2、今回は以下のように正解サンプルの分布から閾値を決定しました。
 

  • THRESHOLD_HIGH(正解90th percentile):0.302
  • THRESHOLD_MID(正解99th percentile):0.402
  • THRESHOLD_LOW(正解最大値 × 1.1):0.520

THRESHOLD_HIGHが正解サンプルの90パーセンタイル、THRESHOLD_MIDが99パーセンタイルでこの設定では正解サンプルの90%、99%がそれぞれの範囲に収まることになります。
未知薬剤の可能性が高い「一致なし(未特定)」となる閾値は、本来は実際に未知薬剤の画像をテストして決めるべき、とのことでしたが取り敢えず今回は正解サンプルの最大距離に10%のマージンを加えた値を目安にすることにしました。

 
 
次回はArucoマーカーを使った実寸測定やDB作成からシステム完成までを書こうと思います。
 
 

*1:Softmax関数は、入力された画像を学習したクラスのどれかに必ず分類し、合計が100%になるように確率を割り振るため。

*2:グラフ右側にもう一つ不正解サンプルの分布ができていれば良かったのですが