BNN-PYNQ の走り書き

アドレスはハードコーディング

library/driver/platform-xlnk.cpp:               platform = new XlnkDriver(0x43c00000, 64 * 1024);

FoldedMVInit でメモリアロケーションしている。accelBufIn と AccelBufOut は thePlatform->allocAccelBuffer で取得している。これは最終的に sds_mmap を呼ぶ。sds つかってなくてもね。sds_mmap は UIO の一部を使っているか使ってないかを無視してマップするので危険。使っているか使っていないかはドライバやアプリしだいなので、つかってくれるな~~とお願いしながら使うことになる(と思う。改善されてるかもしれない。)

で IP にこの物理アドレスを教えてあげる。

    thePlatform->write64BitJamRegAddr(0x10, (AccelDblReg) accelBufIn);
    thePlatform->write64BitJamRegAddr(0x1c, (AccelDblReg) accelBufOut);
    thePlatform->writeJamRegAddr(0x28, 0);

0x28 へのアクセスは // disable weight loading mode ということみたい。

writeJamRegAddr は最初の 32 個のシステム用のアドレスをスキップして(つまり 32xsizeof(uint32_t) = 0x80) てアクセスする。この場合 0x43c00000 + 0x80 + 0x28 だね。64bit の場合はリトルエンディアンで 2回 writeJamRegAddr を呼ぶ。

testPrebinarized_nolabel で実際の処理。1つのワードが 64bit ( 8x8 )で、これに切り上げられる。例えば、この場合 28x28 = 784 だけど、832 ビットになる。この LFC ネットワークの MNIST はグレースケールを二値化(0 と 0xff) になっていることを前提としていて、それをされに 0 と 1 に落とし込む。ので 832ビット。

参考までに書くとゼロから作る DL の本では 784 バイトでグレースケールを考慮している。 隠れが1層で 784(バイト) => (50x100) => 10だった?BNN-PYNQ は 832bit => 1000x1000x1000 => 10 。方や float で方や二値。たぶん、精度はゼロから作るの方がいいのか微妙。BNN-PYNQ で私の手書きの9は7になってしまった。

さて、ドライバに戻る。

立ち上げた瞬間はまだネットワークのウェイトが設定されていない。load_parameter で設定する。

extern "C" void load_parameters(const char* path)
{
#include "config.h"
FoldedMVInit("lfc-pynq");
network<mse, adagrad> nn;
makeNetwork(nn);
        cout << "Setting network weights and thresholds in accelerator..." << endl;
        FoldedMVLoadLayerMem(path, 0, L0_PE, L0_WMEM, L0_TMEM);
        FoldedMVLoadLayerMem(path, 1, L1_PE, L1_WMEM, L1_TMEM);
        FoldedMVLoadLayerMem(path, 2, L2_PE, L2_WMEM, L2_TMEM);
        FoldedMVLoadLayerMem(path, 3, L3_PE, L3_WMEM, L3_TMEM);
}

L0_PE = 32, L0_WMEM=416, L0_TMEM=32 う~ん。PE(Processor Entity?)とかSIMD とかよくわからんが、、、各レイヤーのパラメタを設定する。
L0_WMEM が 416 で L0_SIMD が 64 なので 416 x (64 / 32) = 832 bit が入力。
L0_TMEM が 32 で LO_PE が 32 なので 1024 bit が出力。
L1 は 入力 L1_SIMD / 32 * L1_WMEM => 32 / 32 * 512 => 512(あれ?)出力は 64 * 16 = 1024。てな具合にやっていくのでしょう。なんか計算が合わないけど。
HLS 側のソース見ると

static ap_uint<L0_SIMD> weightMem0[L0_PE][L0_WMEM];
static ap_fixed<24, 16> thresMem0[L0_PE][L0_TMEM];
static ap_uint<L1_SIMD> weightMem1[L1_PE][L1_WMEM];
static ap_uint<16> thresMem1[L1_PE][L1_TMEM];
static ap_uint<L2_SIMD> weightMem2[L2_PE][L2_WMEM];
static ap_uint<16> thresMem2[L2_PE][L2_TMEM];
static ap_uint<L3_SIMD> weightMem3[L3_PE][L3_WMEM];
static ap_uint<16> thresMem3[L3_PE][L3_TMEM];
static ap_uint<L4_SIMD> weightMem4[L4_PE][L4_WMEM];
static ap_uint<16> thresMem4[L4_PE][L4_TMEM];
static ap_uint<L5_SIMD> weightMem5[L5_PE][L5_WMEM];
static ap_uint<16> thresMem5[L5_PE][L5_TMEM];
static ap_uint<L6_SIMD> weightMem6[L6_PE][L6_WMEM];
static ap_uint<16> thresMem6[L6_PE][L6_TMEM];
static ap_uint<L7_SIMD> weightMem7[L7_PE][L7_WMEM];
static ap_uint<16> thresMem7[L7_PE][L7_TMEM];
static ap_uint<L8_SIMD> weightMem8[L8_PE][L8_WMEM];

ん?なんで L8 まであるんだ?

    switch (targetLayer) {
    case 0:
        weightMem0[targetMem][targetInd] = val;
        break;

いずれにせよこんな感じで入れている。だから、targetLayerと targetMem と targetInd を設定して val をいれればよい。
バイナリデータがあるので64bit ずつ、

  FoldedMVMemSet(layerNo*2, pe, line, e);

てなかんじ。ファイル名は layerno + pe + "-weights.bin" or "-thres.bin" だ。そうか、sigmoid つかうかわりにスレッショルド使っているのか(予測)。

  // enable weight loading mode
  thePlatform->writeJamRegAddr(0x28, 1);
  // set up init data
  thePlatform->writeJamRegAddr(0x30, targetLayer);
  thePlatform->writeJamRegAddr(0x38, targetMem);
  thePlatform->writeJamRegAddr(0x40, targetInd);
  thePlatform->write64BitJamRegAddr(0x48, (AccelDblReg) val);
  // do write
  ExecAccel();
  // disable weight loading mode
  thePlatform->writeJamRegAddr(0x28, 0);

これで値設定。設定するたびに待っているぞ。

void ExecAccel() {
  // invoke accelerator and wait for result
  thePlatform->writeJamRegAddr(0x00, 1);
  while((thePlatform->readJamRegAddr(0x00) & 0x2) == 0) usleep(1);
}

あとは 28x28 をビットに整形して(binarizeAndPack)、IP がアクセスできる連続領域にさらにコピーして(copyBufferHostToAccel)、実行

   FoldedMVOffloadBinarized(binImages, outLabel, count*psi, count*psl, count);

....
void FoldedMVOffloadBinarized(....

  thePlatform->writeJamRegAddr(0x54, numImages);

  // launch
  ExecAccel();

結果を accelBufOut からもってくる。