manvaのエンジニアリング魂

エンジニアリング・ものづくり・DIYをもっと身近にするためのブログ。インスピレーションを刺激します。

Jetson NanoにJAXをインストールする方法

Jetson Nano 2GB版↓にJAXをインストールできたのでやり方を書いておく。
www.switch-science.com


JAXとは、GPU対応のNumpyのようなものらしい。知らんけど。なぜJetson NanoやJAXを始めたのかは後日別途語りたい。

まずはJAXの公式サイト↓
github.com
の手順で普通にインストールしようとしたが、この手順はCUDA10+cuDNN7かCUDA11+cuDNN8の組み合わせにしか対応しておらず、Jetson NanoのデフォルトのSDカードイメージでは,CUDAはバージョン10.2、cuDNNはバージョン8であるため、うまく行かなかった。それ以外の組み合わせはソースからのビルドが必要になる。
調べてみると、ソースからビルドすることで、Jetson Nano 4GB版でインストールできたという人がいて、やり方を書いてくれている↓(以下、これのことを「インストラクション★」と記載する)。
https://forums.developer.nvidia.com/t/jax-on-jetson-nano/182593/9
Jetson NanoのデフォルトのSDカードイメージでは、Pythonのバージョンが3.6.9なのだが、JAXをビルドするにはバージョン3.9に上げる必要がある。しかし、デフォルトのpython3をバージョン3.9に置き換えてしまうと他の多くのソフトの依存関係が壊れてしまうようだ。そのため、virtualenvというのを使ってデフォルトのPythonバージョンを3.6.9のまま変えないで、必要なときだけPython3.9を使えるようにインストールしている。
この情報のおかげで非常に助かったが、その通りにやってもまだうまくいかない。ビルドの途中で,以下のようなエラーが出て失敗する。

jaxlib/cuda_lu_pivot_kernels.cu.cc failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command 
nvcc fatal   : Unsupported gpu architecture 'compute_80'
Target //build:build_wheel failed to build
$ nvcc --help

と打って確認してみると,--gpu-architectureオプションの説明のところに'compute_80'がない。'compute_75'まで。'compute_80' はCUDAのバージョンが11以降でないと使えないらしい。たぶん,インストラクション★が書かれた後でJAXがバージョンアップしたためにこのエラーが出るようになったと思われる。CUDAのバージョンを上げるとまたいろいろ問題発生しそうなので、JAXの方を古いバージョンに落として使うことにした。
Jetson Nanoを公式の手順、または↓の本の手順で立ち上げたところからの手順を以下にまとめる。

(以下、この本を「超入門」と記載する)

インストラクション★では、USB3.0のポートにSSDをつないで使っているが、SDカード64GB以上あればそれだけでできそう。

手順

1.スワップ領域追加

インストラクション★は,Jetson Nanoのメモリ4GB版用に書かれている。Jetson Nano 4GB版ではデフォルトでスワップ領域が設定されておらず,スワップ領域は5GBでも足りないくらいで10GBにしたとのこと。Jetson Nano 2GB版ではデフォルトでスワップ領域が5GBあるが,10GBに増やしておく。
まずはデフォルトのスワップ領域を一旦削除。

$ sudo swapoff /swapfile

スワップ領域の追加は「超入門」の通りにした。

$ sudo dd if=/dev/zero of=/var/swapfile bs=1G count=10
$ sudo mkswap /var/swapfile
$ sudo chmod 600 /var/swapfile

/etc/fstabというファイルを修正する。

$ sudo gedit /etc/fstab

と打ってgeditで開く(geditはsudoで使うと終了時に警告が出るが別に問題なさそう)と、最後の方に

# <file system> <mount point>             <type>          <options>                               <dump> <pass>
/dev/root            /                     ext4           defaults                                     0 1
/swapfile            swap                  swap           defaults                                     0 0

とある。最後の行の/swapfileを/var/swapfileに書き換えて保存。

# <file system> <mount point>             <type>          <options>                               <dump> <pass>
/dev/root            /                     ext4           defaults                                     0 1
/var/swapfile        swap                  swap           defaults                                     0 0

↓のように打つと

$ sudo swapon /var/swapfile

スワップ領域が確保される。
正しく設定できているかどうかは、

$ free -m

と打つと確認できる。

2..bashrcにパスを追加
$ gedit ~/.bashrc

で開いたファイルの最後に↓追加。

export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}

保存して一旦再起動。

$ export

と打つと正しく設定できたか確認できる。

3.Python3.9をインストール
$ sudo add-apt-repository ppa:deadsnakes/ppa 
$ sudo apt update
$ sudo apt install python3.9 python3.9-dev

ここで、インストラクション★にないがもう1つ追加で以下を実行する必要がある。

$ sudo apt install python3.9-distutils

ここまでで、Python3.9はインストールされるが、python3と打った場合には今まで通りPython3.6.9が使われる状態になる。

$ python3 -V

と打ってみると、3.6.9になっているはず。

4.仮想環境virtualenvをインストール
$ sudo apt install python3-pip
$ sudo -H pip3 install virtualenv
$ virtualenv -p /usr/bin/python3.9 py39

この後、↓のように打つと、プロンプトの先頭に(py39)と表示され、Python3.9が使われる状態になる。

$ source ~/py39/bin/activate

解除したいときは

$ deactivate

と打つ。

$ python3 -V

と打ってみると、3.9.9になっているはず。

この仮想環境は何もインストールされていない状態なので、必要なものをインストールする。

$ python3 -m pip install numpy scipy six wheel
$ sudo apt install g++

 

5.JAXをgitクローン

インストラクション★の通りまず↓を実行した(が、いらなかった)。

$ git clone https://github.com/google/jax

これだと現時点(2022/1/6)での最新バージョン0.2.26を使うことになるが、上記のようにうまく行かなかったため、旧バージョンのJAX0.2.20を使ったらできた。↓からJAX0.2.20をダウンロード、解凍してgitクローンしたjaxフォルダを削除して置き換えた。
Releases · google/jax · GitHub

私はこのやり方で旧バージョンにしたが、だったらたぶん最初から↓でできたと思う(未確認)。

$ git clone https://github.com/google/jax -b jax-v0.2.20 --depth 1

また、このバージョンは適当に選んだもので他は試していない。試していないのでわからないが、もう少し新しいバージョンでもできるかもしれない。


後はビルドするただけなのだが、これが非常に時間がかかる。最初ビルドしたとき、丸2日経っても終わらず、いつまで続くかもわからなかったので強制終了して高速化のために効きそうな対策をいくつか施した。どれが効いたかわからないが、半分以下の時間でできるようになった。

6. 高速化のための対策
  • 電源アダプタを5V2Aのものから5V3Aのものに変える  時々、System throttled due to overcurrentなどと出ていたので。
  • モニタ、キーボード、マウスをつないでいたが、全部外して、Wifiから小さめの解像度でリモートデスクトップ接続する  メモリを節約できると思って。
  • 「超入門」に書かれていたパフォーマンス最大化方策
$ sudo nvpmodel -m 0
$ sudo jetson_clocks

 

7. ビルド
$ cd jax
$ python3 build/build.py --enable_cuda

インストラクション★では12時間かかったそうだが、私の場合19時間くらいだった。メモリ4GBと2GBの差か。SDカードよりUSB3.0SSDの方が速い、というのもあるかもしれない。
途中、分数(の形をした何かの数字)で進捗が表示されるが、分母も大きくなるのでいつまで続くかわからない。途中いくつかメモした。3425/3941→7420/8065→9212/9632→10394/10618といった感じ。参考までに、私の場合、分母は10618が最後だった。最初は速いが9000ぐらいから非常に遅くなる。

終わったら、最後に以下実行。

$ pip3 install dist/*.whl
$ pip3 install -e .

(最後の . (ドット)を見落とさないように注意)
これでできるはず。