機械学習のチュートリアルでおなじみの MNIST を TensorFlow.js でやってみました。
下のキャンバスに 0〜9 の数字を書いて、predict ボタンを押してみてください!!
作ったもの
Number | Accuracy |
---|---|
0 | - |
1 | - |
2 | - |
3 | - |
4 | - |
5 | - |
6 | - |
7 | - |
8 | - |
9 | - |
モデルの学習
訓練は Mnist cnn - Keras Documentation に従って行っていますが、モデルの重みを TensorFlow.js の形式に変換するところまでを含めて、一応 notebook にまとめてあります。
また、TensorFlow.js については README.md に簡単にまとめてあります。
コード
<!-- MNIST demo widget: drawing pad and buttons on the left, per-digit score table on the right. -->
<div class="mnist">
  <div class="drawing">
    <!-- Visible drawing surface: 280x280 is 10x the 28x28 image the scripts feed to the model. -->
    <canvas id="drawing-pad" width="280" height="280" style="border: 2px solid;"></canvas>
    <!-- Never shown; the scripts render the downscaled 28x28 image here to read its pixels back. -->
    <canvas id="hidden-pad" style="display: none;"></canvas><br/>
    <!-- Shows only a spinner until the model has loaded; the loader script then replaces the
         button's contents with its "predict" label. -->
    <button id="predict-button" class="predict" onclick="prediction()">
      <i id="loading" class="fa fa-spinner fa-spin"></i>
    </button>
    <button id="reset-button" class="reset" onclick="reset()">
      reset
    </button>
  </div>
  <div class="result">
    <!-- One row per digit; each .accuracy cell's data-row-index is the index of that digit's
         score in the model output. -->
    <table>
      <thead>
        <tr>
          <th>Number</th>
          <th>Accuracy</th>
        </tr>
      </thead>
      <tbody>
        <tr>
          <th>0</th>
          <td class="accuracy" data-row-index="0">-</td>
        </tr>
        <tr>
          <th>1</th>
          <td class="accuracy" data-row-index="1">-</td>
        </tr>
        <tr>
          <th>2</th>
          <td class="accuracy" data-row-index="2">-</td>
        </tr>
        <tr>
          <th>3</th>
          <td class="accuracy" data-row-index="3">-</td>
        </tr>
        <tr>
          <th>4</th>
          <td class="accuracy" data-row-index="4">-</td>
        </tr>
        <tr>
          <th>5</th>
          <td class="accuracy" data-row-index="5">-</td>
        </tr>
        <tr>
          <th>6</th>
          <td class="accuracy" data-row-index="6">-</td>
        </tr>
        <tr>
          <th>7</th>
          <td class="accuracy" data-row-index="7">-</td>
        </tr>
        <tr>
          <th>8</th>
          <td class="accuracy" data-row-index="8">-</td>
        </tr>
        <tr>
          <th>9</th>
          <td class="accuracy" data-row-index="9">-</td>
        </tr>
      </tbody>
    </table>
  </div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/signature_pad/1.5.3/signature_pad.min.js"></script>
<!-- Pinned to a pre-1.0 release: the page script calls tf.loadModel and tf.fromPixels,
     which TensorFlow.js 1.x renamed (tf.loadLayersModel, tf.browser.fromPixels). -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.8.0"></script>
<script src="https://docs.opencv.org/3.4/opencv.js" type="text/javascript"></script>
JavaScript
// Drawing surface for the user's digit. Equal minWidth/maxWidth yields a
// constant-width stroke, and white ink on a black background mirrors the
// light-on-dark look of MNIST digits.
const drawElement = document.getElementById('drawing-pad');
const signaturePad = new SignaturePad(drawElement, {
  backgroundColor: 'black',
  penColor: 'white',
  minWidth: 6,
  maxWidth: 6,
});
// Pre-trained MNIST model (converted to the TensorFlow.js format). It stays
// undefined until loading succeeds; prediction() reads it.
let model;
const model_path = '../../js/TensorFlowJs/MNIST/tf-model/model.json';
tf.loadModel(model_path)
  .then(function (pretrainedModel) {
    model = pretrainedModel;
    // Swap the loading spinner for the button label. Overwriting the button's
    // contents also removes the #loading <i> element, so it must not be looked
    // up afterwards (getElementById('loading') would return null).
    document.getElementById('predict-button').textContent = 'predict';
  })
  .catch(function (err) {
    // Without this the spinner would spin forever and the failure would only
    // surface as an unhandled promise rejection.
    console.error(`Failed to load model from ${model_path}`, err);
    document.getElementById('predict-button').textContent = 'load failed';
  });
/**
 * Converts the current drawing into a 28x28 grayscale image.
 *
 * Reads the drawing canvas with OpenCV.js, converts RGBA to one gray channel,
 * downscales to 28x28 with area interpolation, and renders the result onto the
 * hidden canvas so its pixels can be read back with the canvas API.
 *
 * @returns {ImageData} The 28x28 pixels read back from #hidden-pad.
 */
function getImageData() {
  // OpenCV.js Mats are not garbage-collected and must be freed with delete().
  // Track every Mat we allocate so `finally` frees them even if a cv call throws.
  let src = null;
  let gray = null;
  let resized = null;
  try {
    src = cv.imread(drawElement);
    gray = new cv.Mat();
    cv.cvtColor(src, gray, cv.COLOR_RGBA2GRAY, 0);
    resized = new cv.Mat();
    // INTER_AREA is OpenCV's recommended interpolation when shrinking images.
    cv.resize(gray, resized, new cv.Size(28, 28), 0, 0, cv.INTER_AREA);
    // Render the 1-channel Mat onto the hidden canvas so it can be read back
    // as ImageData.
    cv.imshow('hidden-pad', resized);
    return document.getElementById('hidden-pad').getContext('2d').getImageData(0, 0, 28, 28);
  } finally {
    for (const mat of [src, gray, resized]) {
      if (mat !== null) {
        mat.delete();
      }
    }
  }
}
/**
 * Runs the loaded model on a 28x28 image and returns its raw output scores.
 *
 * @param {ImageData} imageData - Pixels produced by getImageData().
 * @returns {TypedArray} The model output copied out of the tensor; prediction()
 *   reads index i as the score for digit i.
 */
function getAccuracyScores(imageData) {
  // tf.tidy disposes every tensor created inside the callback, so repeated
  // predictions do not leak memory. dataSync() copies the result into a plain
  // typed array, which survives that cleanup.
  return tf.tidy(function () {
    const pixels = tf.fromPixels(imageData, 1); // keep a single channel
    const normalized = tf.cast(pixels, 'float32').div(tf.scalar(255)); // 0-255 -> [0, 1]
    const batch = normalized.expandDims(); // prepend a batch axis of size 1
    return model.predict(batch).dataSync();
  });
}
/**
 * Returns the index of the largest non-NaN value in a numeric array-like
 * (first occurrence on ties), or -1 if there is none.
 *
 * @param {ArrayLike<number>} values
 * @returns {number}
 */
function argMax(values) {
  let best = -1;
  for (let i = 0; i < values.length; i++) {
    if (Number.isNaN(values[i])) {
      continue;
    }
    if (best === -1 || values[i] > values[best]) {
      best = i;
    }
  }
  return best;
}

/**
 * Classifies the current drawing and fills the score table.
 *
 * Writes each digit's score into its .accuracy cell and marks the row of the
 * highest-scoring digit with the 'is-selected' class. Does nothing (apart from
 * a console warning) if the model has not loaded.
 */
function prediction() {
  if (!model) {
    // The model is still loading, or failed to load; model.predict would throw.
    console.warn('Model is not loaded yet.');
    return;
  }
  const imageData = getImageData();
  const accuracyScores = getAccuracyScores(imageData);
  const bestIndex = argMax(accuracyScores);
  const elements = document.querySelectorAll('.accuracy');
  elements.forEach(function (el) {
    const rowIndex = Number(el.dataset.rowIndex);
    el.parentNode.classList.toggle('is-selected', rowIndex === bestIndex);
    el.innerText = accuracyScores[rowIndex];
  });
}
/**
 * Clears the drawing pad and restores the score table to its initial state:
 * every score cell shows '-' and no row is highlighted.
 */
function reset() {
  signaturePad.clear();
  for (const cell of document.querySelectorAll('.accuracy')) {
    cell.innerText = '-';
    cell.parentNode.classList.remove('is-selected');
  }
}
CSS
/* Widget container; overflow: hidden makes it enclose its floated columns. */
.mnist {
  padding: 10px;
  width: 100%;
  /* Under the default content-box sizing, width: 100% plus padding overflows
     the parent; border-box keeps the padding inside the 100%. */
  box-sizing: border-box;
  overflow: hidden;
}
/* Left column: drawing pad and buttons. */
.drawing {
  float: left;
  width: 50%;
  text-align: center;
}
/* Right column: score table. */
.result {
  float: right;
  width: 50%;
}
.predict {
  padding: 10px;
  background-color: #80160e;
  /* Default black text is barely readable on this dark red (about 2:1
     contrast); white matches the highlighted-row style below. */
  color: white;
}
.reset {
  padding: 10px;
  background-color: #c8c8a0;
}
/* Row of the digit with the highest score. */
.is-selected {
  background-color: #80160e;
  color: white;
}
/* Responsive column widths. The float and text-align settings come from the
   base rules and are unchanged at every breakpoint; only the widths switch
   between stacked (100%) and side-by-side (50%).
   NOTE(review): the return to two columns at <=760px presumably follows the
   surrounding page layout (e.g. a sidebar collapsing); confirm against the theme. */
@media only screen and (max-width: 1200px) {
  .drawing,
  .result {
    width: 100%;
  }
}
@media only screen and (max-width: 760px) {
  .drawing,
  .result {
    width: 50%;
  }
}
@media only screen and (max-width: 640px) {
  .drawing,
  .result {
    width: 100%;
  }
}