モーダルを閉じる工作HardwareHub ロゴ画像

工作HardwareHubは、ロボット工作や電子工作に関する情報やモノが行き交うコミュニティサイトです。さらに詳しく

利用規約プライバシーポリシー に同意したうえでログインしてください。

PyTorch を用いた手書き数字の分類

モーダルを閉じる

ステッカーを選択してください

お支払い手続きへ
モーダルを閉じる

お支払い内容をご確認ください

購入商品
」ステッカーの表示権
メッセージ
料金
(税込)
決済方法
GooglePayマーク
決済プラットフォーム
確認事項

利用規約をご確認のうえお支払いください

※カード情報はGoogleアカウント内に保存されます。本サイトやStripeには保存されません

※記事の執筆者は購入者のユーザー名を知ることができます

※購入後のキャンセルはできません

作成日作成日
2020/07/21
最終更新最終更新
2022/11/02
記事区分記事区分
一般公開

目次

    C/C++ の基礎知識を初心者向けに紹介しています!

    PyTorch を用いて、手書き数字の分類を行ってみます。サポートベクターマシンを用いた場合は HOG などの特徴量を考える必要がありましたが、ディープラーニングでは十分な質の良いデータがあればその必要がありません。

    MNIST データの読み込み

    手書き数字のデータとして、MNIST データをダウンロードして利用することにします。Matplotlibで描画する例は以下のようになります。

    # -*- coding: utf-8 -*-
    import gzip
    import pickle
    import matplotlib.pyplot as plt
    import torch
    
    def Main():
    
        # pickle 形式で保存されています。
        with gzip.open('mnist.pkl.gz', 'rb') as f:
            ((xTrain, yTrain), (xValid, yValid), _) = pickle.load(f, encoding='latin-1')
    
        # 28x28 ピクセルの画像データが 50000 枚分あります。
        print(xTrain.shape)  #=> (50000, 784)
    
        # 描画してみます。
        print(yTrain[0])  #=> 5
        plt.imshow(xTrain[0].reshape(28, 28), cmap='gray')
        plt.show()
    
        # pytorch で利用するためには torch.tensor に変換します。
        print(type(xTrain[0]))  #=> <class 'numpy.ndarray'>
    
        xTrain, yTrain, xValid, yValid = map(
            torch.tensor, (xTrain, yTrain, xValid, yValid)
        )
        print(type(xTrain[0]))  #=> <class 'torch.Tensor'>
    
    if __name__ == '__main__':
        Main()
    

    ニューラルネットワークの定義

    手書き数字の描かれた画像を分類するニューラルネットワークとして、ディープラーニングでよく利用される「畳み込みニューラルネットワーク (CNN; Convolutional Neural Network)」を用いてみます。torch.nn.Module を継承したクラスを利用してネットワークを定義できます。

    MNIST データの分類を考えたときには、以下のようなネットワーク定義となります。ただしこれは CNN の一つの例であり、一般形ではありません。

    #!/usr/bin/python
    # -*- coding: utf-8 -*-
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    def Main():
    
        # MNIST データは RGB ではなくグレースケールです。
        inChannels = 1
    
        # 0 から 9 までの数字への分類を考えます。
        outFeatures = 10
    
        # MNIST データは 28x28 の画像です。
        inputSize = 28
    
        # ネットワークを定義します。
        cnn = CNN(inChannels, outFeatures, inputSize)
    
        # 一つのミニバッチに含まれるデータの個数
        bs = 1
    
        # 乱数で MNIST と同じサイズのデータを用意してみます。
        x = torch.randn(bs, inChannels, inputSize, inputSize)
        yPred = cnn(x)
        print(cnn)
        print(yPred.shape)
    
    
    class CNN(nn.Module):
    
        def __init__(self, inChannels, outFeatures3, inputSize):
            super(CNN, self).__init__()
    
            # 隠れ層の次元数など
            outChannels = 6
            kernelSize = 3
            outChannels2 = 16
            outFeatures = 120
            outFeatures2 = 84
            poolingStride = 2
    
            sz = inputSize - kernelSize + 1
            sz = sz // poolingStride
            sz = sz - kernelSize + 1
            sz = sz // poolingStride
    
            self.__poolingStride = poolingStride
            self.__conv1 = nn.Conv2d(inChannels, outChannels, kernelSize)
            self.__conv2 = nn.Conv2d(outChannels, outChannels2, kernelSize)
            self.__fc1 = nn.Linear(outChannels2 * sz * sz, outFeatures)
            self.__fc2 = nn.Linear(outFeatures, outFeatures2)
            self.__fc3 = nn.Linear(outFeatures2, outFeatures3)
    
        def forward(self, x):
    
            # 1 x 1 x 28 x 28
    
            x = self.__conv1(x)  #=> 1 x 6 x 26 x 26
            x = F.relu(x)  #=> 1 x 6 x 26 x 26
            x = F.max_pool2d(x, self.__poolingStride)  #=> 1 x 6 x 13 x 13
    
            x = self.__conv2(x)  #=> 1 x 16 x 11 x 11
            x = F.relu(x)  #=> 1 x 16 x 11 x 11
            x = F.max_pool2d(x, self.__poolingStride)  #=> 1 x 16 x 5 x 5
    
            # note: 第一引数を -1 とすることで、第二引数の値から形状を推定させることができます。
            x = x.reshape(-1, self.__GetNumFlatFeatures(x))  #=> 1 x 400
    
            x = F.relu(self.__fc1(x))  #=> 1 x 120
            x = F.relu(self.__fc2(x))  #=> 1 x 84
            x = self.__fc3(x)  #=> 1 x 10
            return x
    
        def __GetNumFlatFeatures(self, x):
            size = x.size()[1:]  # ミニバッチの個数の次元を除く、すべての次元
            numFeatures = 1
            for sz in size:
                numFeatures *= sz
            return numFeatures
    
    if __name__ == '__main__':
        Main()
    

    Conv2d(inChannels, outChannels, kernelSize)

    inChannels 方向には動かさず、画像の平面内で畳み込みを行います。この畳み込みを独立に outChannels 個のフィルタで行い、結果を一つのテンソルとしてまとめます。kernelSize は OpenCV での畳み込み処理におけるカーネルと同じ概念です。

    Linear(inFeatures, outFeatures)

    線形変換です。重みとバイアスをパラメータとして持ちます。

    max_pool2d(x, stride)

    stride x stride において最大となる値をフィルタします。

    出力例

    CNN(
      (_CNN__conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
      (_CNN__conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
      (_CNN__fc1): Linear(in_features=400, out_features=120, bias=True)
      (_CNN__fc2): Linear(in_features=120, out_features=84, bias=True)
      (_CNN__fc3): Linear(in_features=84, out_features=10, bias=True)
    )
    torch.Size([1, 10])
    

    分類問題で利用する損失関数について

    回帰問題を扱う際に利用した平均二乗誤差は、分類問題ではそのまま利用できません。そのため、ここでは手書き数字の分類を行うために、以下の式で表される、交差エントロピー誤差という損失関数を用いてニューラルネットワークを学習します。

    L=1Nj=1N=64i=110ti,jlog(pi,j)L = -\frac{1}{N} \sum_{j=1}^{N=64} \sum_{i=1}^{10} t_{i,j} \log( p_{i,j} )

    MNIST データには学習用のデータが 50000 枚あります。ディープラーニングでパラメータの学習のためにループを回す際に、利用可能なすべての学習用のデータを分割して、小さなバッチデータ毎にループを回す手法があります。本ページではミニバッチのサイズ NN を 64 として学習することにします。

    ある一つの画像データをニューラルネットワークに入力として与えると、入力画像が 0-9 の数字である確率が、長さ 10 のベクトルとして出力されます。実際には NN 個のデータを一度に入力するため、このベクトルが NN 個出力されます。

    交差エントロピー誤差では、10 の長さのベクトルのうち、例えば入力画像が 0 という数字であった場合は、最初の要素だけを取り出して対数を取ります。10 個の確率から一つのデータを取り出せるように ti,jt_{i,j} は 0 または 1 の値を取ります。NN 個のデータについて同様の処理を行い、平均を計算したものが誤差となります。

    例えば log(1)\log(1) は 0 となるため、正しい分類ができている場合の誤差は 0 となります。

    ソフトマックス関数について

    交差エントロピー誤差を計算するためには、ニューラルネットワークの出力を確率として扱えるように変換する必要があります。PyTorch では交差エントロピー誤差を計算する関数 nn.CrossEntropyLoss の内部で、ソフトマックス関数 nn.Softmax を利用して出力を確率として扱えるように変換しています。

    nn.CrossEntropyLoss の利用例

    以下では nn.CrossEntropyLoss で計算した誤差と、定義に基いて手動計算した誤差が一致することを確認しています。

    import torch
    import torch.nn as nn
    
    lossFn = nn.CrossEntropyLoss()
    
    N = 64
    y = torch.empty(N, dtype=torch.long).random_(10)
    
    yPred = torch.randn(N, 10)
    
    loss = lossFn(yPred, y)
    
    loss2 = 0.0
    for j in range(N):
        loss2 += -torch.log(torch.exp(yPred[j][y[j]]) / sum(torch.exp(yPred[j])))
    loss2 /= N
    
    print(loss)  #=> 2.8512
    print(loss2)  #=> 2.8512
    

    ニューラルネットワークの学習

    上述の CNN と交差エントロピー誤差を用いて MNIST データの分類を試してみます。

    # -*- coding: utf-8 -*-
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import TensorDataset
    from torch.utils.data import DataLoader
    
    import gzip
    import pickle
    import matplotlib.pyplot as plt
    
    def Main():
    
        # MNIST データ
        xTrain, yTrain, xValid, yValid = GetMnistData()
    
        # CNN モデル
        inChannels = 1
        outFeatures = 10
        inputSize = 28
        model = CNN(inChannels, outFeatures, inputSize)
    
        # 交差エントロピー誤差
        lossFn = F.cross_entropy
    
        # 学習率、学習の反復回数
        learningRate = 0.001
        iters = 10
    
        # 最適化関数
        optimizer = torch.optim.Adam(model.parameters(), lr=learningRate)
    
        # ミニバッチのサイズ
        bs = 64
        trainDs = TensorDataset(xTrain, yTrain)
        trainDl = DataLoader(trainDs, batch_size=bs, shuffle=True)
    
        # 全体のループ
        for t in range(iters):
    
            # ミニバッチ毎のループ
            for x, y in trainDl:
                yPred = model(x)
                loss = lossFn(yPred, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    
            # 誤差の出力
            print(t, loss.item())
    
        # 学習済みモデルの検証
        CheckOutput(model, xTrain, yTrain)
        CheckOutput(model, xValid, yValid)
    
    def CheckOutput(model, x, y):
        yPred = map(lambda xx: xx.max(0).indices.item(), model(x))
        wrong = 0
        for xx, yy, yyPred in zip(x, y, yPred):
            if yy.item() == yyPred:
                continue
            # print('{} != {}'.format(yy.item(), yyPred))
            # plt.imshow(xx.reshape(28, 28), cmap='gray')
            # plt.show()
            wrong += 1
        print('Accuracy: {}'.format(100 - wrong / len(x) * 100))
    
    def GetMnistData():
        with gzip.open('mnist.pkl.gz', 'rb') as f:
            (xTrain, yTrain), (xValid, yValid), _ = pickle.load(f, encoding='latin-1')
        xTrain = list(map(lambda x: x.reshape(1, 28, 28), xTrain))
        xValid = list(map(lambda x: x.reshape(1, 28, 28), xValid))
        return map(torch.tensor, (xTrain, yTrain, xValid, yValid))
    
    class CNN(nn.Module):
    
        def __init__(self, inChannels, outFeatures3, inputSize):
            super(CNN, self).__init__()
            outChannels = 6
            kernelSize = 3
            outChannels2 = 16
            outFeatures = 120
            outFeatures2 = 84
            poolingStride = 2
            sz = inputSize - kernelSize + 1
            sz = sz // poolingStride
            sz = sz - kernelSize + 1
            sz = sz // poolingStride
            self.__poolingStride = poolingStride
            self.__conv1 = nn.Conv2d(inChannels, outChannels, kernelSize)
            self.__conv2 = nn.Conv2d(outChannels, outChannels2, kernelSize)
            self.__fc1 = nn.Linear(outChannels2 * sz * sz, outFeatures)
            self.__fc2 = nn.Linear(outFeatures, outFeatures2)
            self.__fc3 = nn.Linear(outFeatures2, outFeatures3)
    
        def forward(self, x):
            x = F.max_pool2d(F.relu(self.__conv1(x)), self.__poolingStride)
            x = F.max_pool2d(F.relu(self.__conv2(x)), self.__poolingStride)
            x = x.reshape(-1, self.__GetNumFlatFeatures(x))
            x = F.relu(self.__fc1(x))
            x = F.relu(self.__fc2(x))
            x = self.__fc3(x)
            return x
    
        def __GetNumFlatFeatures(self, x):
            size = x.size()[1:]
            numFeatures = 1
            for sz in size:
                numFeatures *= sz
            return numFeatures
    
    if __name__ == '__main__':
        Main()
    

    実行例

    0 0.10681144148111343
    1 0.06327979266643524
    2 0.040145404636859894
    3 0.015086745843291283
    4 0.005156606901437044
    5 0.0018728474387899041
    6 0.000745357247069478
    7 0.0005938038229942322
    8 5.65591617487371e-05
    9 0.0003749439201783389
    Accuracy: 99.456
    Accuracy: 98.58
    

    訓練用のデータで 99.456%、未知のデータで 98.58% となりました。分類に失敗したデータの例としては以下のようなものがあります。

    2 と認識 (正しくは 3)

    8 と認識 (正しくは 3)

    6 と認識 (正しくは 5)

    Likeボタン(off)0
    詳細設定を開く/閉じる
    アカウント プロフィール画像

    C/C++ の基礎知識を初心者向けに紹介しています!

    記事の執筆者にステッカーを贈る

    有益な情報に対するお礼として、またはコメント欄における質問への返答に対するお礼として、 記事の読者は、執筆者に有料のステッカーを贈ることができます。

    >>さらに詳しくステッカーを贈る
    ステッカーを贈る コンセプト画像

    Feedbacks

    Feedbacks コンセプト画像

      ログインするとコメントを投稿できます。

      ログインする

      関連記事