PyTorch の基本的な使い方を記載します。
テンソルの計算
初期化されていない空のテンソルを作成
乱数で初期化 (一様分布、標準正規分布)
0 で初期化
1 で初期化
値を指定して初期化
同じ形状のテンソルを作成
サイズを取得
形状を変更
テンソルの値を取り出す
NumPy に変換
NumPy から変換
自動微分
バックプロパゲーションによる微分が利用できます。
これを PyTorch で表現すると以下のようになります。
PyTorch で実装されている関数には backward 処理が実装されているため、例えば以下のような ReLU (rectified liner unit) を考えたときにも対応できます。
torch.clamp を利用します。
非線形なデータセットを用いた 2 層のニューラルネットワークの学習
一般的なニューラルネットワークは
「線形変換」→「活性化関数 → 線形変換」→ ... →「活性化関数 → 線形変換」
という構造になっています。活性化関数には上述の ReLU や、以下のシグモイド関数があります。活性化関数では非線形な変換が行われるため、ニューラルネットワークでは非線形なデータセットにも対応できます。
活性化関数にはパラメータがありませんが、線形変換には重み w
とバイアス b
パラメータがあります。以下の例のように、パラメータを持つ層が二つ存在する場合は 2 層のニューラルネットワークとよびます。入力データ $x_i$
と出力データ $y_i$
をもとに、ニューラルネットワークのパラメータ w1
、w2
、b1
、b2
を学習します。損失関数 (Loss Function) として用いている平均二乗誤差 Mean Squared Error (MSE) が小さくなるように、勾配降下法によって最適化を行います。
実行例
データのプロットでは Matplotlibを利用しています。
PyTorch の関数は自分で追加することができます。forward
と backward
を実装します。シグモイド関数の例は以下のようになります。
これを用いると上記サンプルプログラムは以下のように変更できます。バックプロパゲーションの計算グラフが簡単になるためメモリ効率が良くなります。
既存の nn モジュールを組み合わせてモデルを構築
torch.nn
ではよく利用するモジュールが提供されています。モジュールを組み合わせてモデルを作ることで、先程の例における、線形変換およびシグモイド関数による 2 層のニューラルネットワークを構築できます。損失関数も提供されているものを利用できます。
結果はモデルを利用しない場合と同様です。
以下のように独自のモデルを定義して使うこともできます。
torch.optim の利用
ニューラルネットワークのパラメータを学習する際には最適化問題を解きます。ここまでの例では、簡単な勾配降下法を実装して利用していましたが、torch.optim
で提供されているものを利用することもできます。勾配降下法の一つである SGD (stochastic gradient descent) や、Adam を利用するためには以下のようにします。
実行例
学習済みモデルのファイル保存
学習したパラメータはファイルシステムに保存することができます。
読み込んで利用
記事の執筆者にステッカーを贈る
有益な情報に対するお礼として、またはコメント欄における質問への返答に対するお礼として、 記事の読者は、執筆者に有料のステッカーを贈ることができます。
さらに詳しく →Feedbacks
ログインするとコメントを投稿できます。
関連記事
- Python コードスニペット (条件分岐)if-elif-else sample.py #!/usr/bin/python # -*- coding: utf-8 -*- # コメント内であっても、ASCII外の文字が含まれる場合はエンコーディング情報が必須 x = 1 # 一行スタイル if x==0: print 'a' # 参考: and,or,notが使用可能 (&&,||はエラー) elif x==1: p...
- Python コードスニペット (リスト、タプル、ディクショナリ)リスト range 「0から10まで」といった範囲をリスト形式で生成します。 sample.py print range(10) # for(int i=0; i<10; ++i) ← C言語などのfor文と比較 print range(5,10) # for(int i=5; i<10; ++i) print range(5,10,2) # for(int i=5; i<10;...
- ZeroMQ (zmq) の Python サンプルコードZeroMQ を Python から利用する場合のサンプルコードを記載します。 Fixing the World To fix the world, we needed to do two things. One, to solve the general problem of "how to connect any code to any code, anywhere". Two, to wra...
- Matplotlib/SciPy/pandas/NumPy サンプルコードPython で数学的なことを試すときに利用される Matplotlib/SciPy/pandas/NumPy についてサンプルコードを記載します。 Matplotlib SciPy pandas [NumPy](https://www.numpy
- pytest の基本的な使い方pytest の基本的な使い方を記載します。 適宜参照するための公式ドキュメントページ Full pytest documentation API Reference インストール 適当なパッケージ