JS.11 TensorFlow.jsでMNIST

機械学習のチュートリアルでおなじみのMNISTを、TensorFlow.jsでやってみました。

下のキャンバスに0〜9の数字を1つ書いて、predictボタンを押してみてください!!（モデルの読み込みが終わるまでは、ボタンにスピナーが表示されます。）

作ったもの


Number Accuracy
0 -
1 -
2 -
3 -
4 -
5 -
6 -
7 -
8 -
9 -

モデルの学習

モデルの訓練はKerasのサンプル（Mnist cnn - Keras Documentation）に従って行っています。学習済みモデルの重みをTensorFlow.jsの形式に変換する手順までを含めて、notebookにまとめてあります。

また、TensorFlow.jsに関しては、README.mdに簡単にまとめてあります。

コード

<div class="mnist">
  <div class="drawing">
    <canvas id="drawing-pad" width="280" height="280" style="border: 2px solid;"></canvas>
    <!-- Off-screen canvas that receives the downscaled 28x28 image before it is read back for the model. -->
    <canvas id="hidden-pad" style="display: none;"></canvas><br/>
    <!-- Shows a spinner until the model has loaded; the script then replaces the content with the "predict" label. -->
    <button id="predict-button" class="predict" type="button" aria-label="predict" onclick="prediction()">
      <i id="loading" class="fa fa-spinner fa-spin" aria-hidden="true"></i>
    </button>
    <button id="reset-button" class="reset" type="button" onclick="reset()">
      reset
    </button>
  </div>
  <div class="result">
    <!-- One row per digit; each .accuracy cell's data-row-index is the digit whose score it displays. -->
    <table>
      <thead>
        <tr>
          <th scope="col">Number</th>
          <th scope="col">Accuracy</th>
        </tr>
      </thead>
      <tbody>
        <tr>
          <th scope="row">0</th>
          <td class="accuracy" data-row-index="0">-</td>
        </tr>
        <tr>
          <th scope="row">1</th>
          <td class="accuracy" data-row-index="1">-</td>
        </tr>
        <tr>
          <th scope="row">2</th>
          <td class="accuracy" data-row-index="2">-</td>
        </tr>
        <tr>
          <th scope="row">3</th>
          <td class="accuracy" data-row-index="3">-</td>
        </tr>
        <tr>
          <th scope="row">4</th>
          <td class="accuracy" data-row-index="4">-</td>
        </tr>
        <tr>
          <th scope="row">5</th>
          <td class="accuracy" data-row-index="5">-</td>
        </tr>
        <tr>
          <th scope="row">6</th>
          <td class="accuracy" data-row-index="6">-</td>
        </tr>
        <tr>
          <th scope="row">7</th>
          <td class="accuracy" data-row-index="7">-</td>
        </tr>
        <tr>
          <th scope="row">8</th>
          <td class="accuracy" data-row-index="8">-</td>
        </tr>
        <tr>
          <th scope="row">9</th>
          <td class="accuracy" data-row-index="9">-</td>
        </tr>
      </tbody>
    </table>
  </div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/signature_pad/1.5.3/signature_pad.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.8.0"></script>
<script src="https://docs.opencv.org/3.4/opencv.js" type="text/javascript"></script>

JavaScript

// Drawing surface: fixed-width white strokes on a black background,
// matching MNIST's white-on-black digits.
const drawElement = document.getElementById('drawing-pad');
const signaturePad = new SignaturePad(drawElement, {
  minWidth: 6,
  maxWidth: 6,
  penColor: 'white',
  backgroundColor: 'black',
});

// Pre-trained model (converted to the TensorFlow.js format); undefined until loading finishes.
let model;
const model_path = '../../js/TensorFlowJs/MNIST/tf-model/model.json';
const predictButton = document.getElementById('predict-button');
// Keep the button disabled while loading so prediction() cannot run while
// `model` is still undefined.
predictButton.disabled = true;
tf.loadModel(model_path)
  .then(function (pretrainedModel) {
    model = pretrainedModel;
    // Replacing the button content also removes the #loading spinner element,
    // so it must not be looked up afterwards (getElementById would return null).
    predictButton.textContent = 'predict';
    predictButton.disabled = false;
  })
  .catch(function (err) {
    // Surface the failure instead of leaving the spinner running forever.
    console.error('Failed to load the MNIST model:', err);
    predictButton.textContent = 'load failed';
  });
/**
 * Downsamples the drawing pad into the 28x28 image the model consumes.
 *
 * Reads the #drawing-pad canvas with OpenCV.js, converts it to grayscale,
 * shrinks it with INTER_AREA interpolation, draws the result on the hidden
 * #hidden-pad canvas and reads that back as ImageData.
 *
 * @returns {ImageData} the 28x28 pixels read from #hidden-pad.
 * @throws whatever OpenCV.js or the canvas API throws; all Mats are still freed.
 */
function getImageData() {
  const size = 28;
  // OpenCV.js Mats live outside the JS garbage collector and must be released
  // with delete(); track every allocation so it is freed even if a step throws.
  const mats = [];
  const track = (mat) => {
    mats.push(mat);
    return mat;
  };
  try {
    const src = track(cv.imread(drawElement));
    const gray = track(new cv.Mat());
    cv.cvtColor(src, gray, cv.COLOR_RGBA2GRAY, 0);
    const resized = track(new cv.Mat());
    // INTER_AREA is the interpolation OpenCV recommends for shrinking images.
    cv.resize(gray, resized, new cv.Size(size, size), 0, 0, cv.INTER_AREA);
    cv.imshow('hidden-pad', resized);
    return document
      .getElementById('hidden-pad')
      .getContext('2d')
      .getImageData(0, 0, size, size);
  } finally {
    mats.forEach((mat) => mat.delete());
  }
}
/**
 * Runs the loaded model on one image and returns its raw per-digit outputs.
 *
 * All intermediate tensors are created inside tf.tidy, which disposes them
 * when the callback returns, so only the plain array from dataSync() escapes.
 *
 * @param {ImageData} imageData - pixels to classify (getImageData() supplies 28x28).
 * @returns {TypedArray} one value per output class; index i is digit i
 *   (presumably a Float32Array — depends on the model's output dtype).
 */
function getAccuracyScores(imageData) {
  return tf.tidy(() => {
    const pixels = tf.fromPixels(imageData, 1); // keep a single channel
    const batch = tf.cast(pixels, 'float32')
      .div(tf.scalar(255)) // scale 0..255 into 0..1
      .expandDims(); // prepend the batch axis
    return model.predict(batch).dataSync();
  });
}
/**
 * Classifies the current drawing and renders the result table.
 *
 * Writes each digit's model output into its `.accuracy` cell (matched by
 * data-row-index) and marks only the row of the highest-scoring digit with
 * `is-selected`. Does nothing if the model has not finished loading.
 */
function prediction() {
  // Guard against being called before tf.loadModel has resolved; without it
  // model.predict would throw a TypeError on undefined.
  if (!model) {
    return;
  }
  const scores = getAccuracyScores(getImageData());
  // Index of the largest score, i.e. the predicted digit (-1 if scores has NaN).
  const predictedDigit = scores.indexOf(Math.max(...scores));
  for (const cell of document.querySelectorAll('.accuracy')) {
    const rowIndex = Number(cell.dataset.rowIndex);
    cell.parentNode.classList.toggle('is-selected', rowIndex === predictedDigit);
    cell.textContent = String(scores[rowIndex]);
  }
}
/**
 * Clears the drawing pad and puts the results table back into its placeholder
 * state: every `.accuracy` cell shows "-" and no row keeps `is-selected`.
 */
function reset() {
  signaturePad.clear();
  for (const cell of document.querySelectorAll('.accuracy')) {
    const row = cell.parentNode;
    row.classList.remove('is-selected');
    cell.innerText = '-';
  }
}

CSS

/* Widget container: drawing pad on the left, score table on the right. */
.mnist {
  /* border-box keeps the 10px padding inside the 100% width instead of
     making the container 20px wider than its parent. */
  box-sizing: border-box;
  padding: 10px;
  width: 100%;
  overflow: hidden;
}
.drawing {
  float: left;
  width: 50%;
  text-align: center;
}
.result {
  float: right;
  width: 50%;
}
.predict {
  padding: 10px;
  background-color: #80160e;
  /* Light text for readable contrast on the dark red (same pairing as .is-selected). */
  color: white;
}
.reset {
  padding: 10px;
  background-color: #c8c8a0;
}
/* Highlights the table row of the predicted digit. */
.is-selected {
  background-color: #80160e;
  color: white;
}
/* Responsive column widths. Floats and text-align come from the base rules;
   only the width changes per breakpoint.
   NOTE(review): 641-760px returns to two columns while 761-1200px stacks them;
   presumably tuned to the surrounding blog layout — confirm before changing. */
@media only screen and (max-width: 1200px) {
  .drawing,
  .result {
    width: 100%;
  }
}
@media only screen and (max-width: 760px) {
  .drawing,
  .result {
    width: 50%;
  }
}
@media only screen and (max-width: 640px) {
  .drawing,
  .result {
    width: 100%;
  }
}
other contents
social