Skip to content

Commit 1d8c91b

Browse files
fix: Fix serialization from json and provide test
1 parent 8d6dd01 commit 1d8c91b

7 files changed

Lines changed: 76 additions & 64 deletions

File tree

browser.js

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* license: MIT (http://opensource.org/licenses/MIT)
77
* author: Heather Arthur <fayearthur@gmail.com>
88
* homepage: https://github.com/brainjs/brain.js#readme
9-
* version: 1.6.0
9+
* version: 1.6.1
1010
*
1111
* acorn:
1212
* license: MIT (http://opensource.org/licenses/MIT)
@@ -1538,11 +1538,12 @@ var NeuralNetwork = function () {
15381538
}, {
15391539
key: 'updateTrainingOptions',
15401540
value: function updateTrainingOptions(options) {
1541-
var _this2 = this;
1542-
1543-
Object.keys(this.constructor.trainDefaults).forEach(function (p) {
1544-
return _this2.trainOpts[p] = options.hasOwnProperty(p) ? options[p] : _this2.trainOpts[p];
1545-
});
1541+
var trainDefaults = this.constructor.trainDefaults;
1542+
for (var p in trainDefaults) {
1543+
if (!trainDefaults.hasOwnProperty(p)) continue;
1544+
if (!options.hasOwnProperty(p)) continue;
1545+
this.trainOpts[p] = options[p];
1546+
}
15461547
this.validateTrainingOptions(this.trainOpts);
15471548
this.setLogMethod(options.log || this.trainOpts.log);
15481549
this.activation = options.activation || this.activation;
@@ -1585,11 +1586,13 @@ var NeuralNetwork = function () {
15851586
return typeof val === 'number' && val > 0;
15861587
}
15871588
};
1588-
Object.keys(this.constructor.trainDefaults).forEach(function (key) {
1589-
if (validations.hasOwnProperty(key) && !validations[key](options[key])) {
1590-
throw new Error('[' + key + ', ' + options[key] + '] is out of normal training range, your network will probably not train.');
1589+
for (var p in validations) {
1590+
if (!validations.hasOwnProperty(p)) continue;
1591+
if (!options.hasOwnProperty(p)) continue;
1592+
if (!validations[p](options[p])) {
1593+
throw new Error('[' + p + ', ' + options[p] + '] is out of normal training range, your network will probably not train.');
15911594
}
1592-
});
1595+
}
15931596
}
15941597

15951598
/**
@@ -1601,11 +1604,11 @@ var NeuralNetwork = function () {
16011604
}, {
16021605
key: 'getTrainOptsJSON',
16031606
value: function getTrainOptsJSON() {
1604-
var _this3 = this;
1607+
var _this2 = this;
16051608

16061609
return Object.keys(this.constructor.trainDefaults).reduce(function (opts, opt) {
1607-
if (opt === 'timeout' && _this3.trainOpts[opt] === Infinity) return opts;
1608-
if (_this3.trainOpts[opt]) opts[opt] = _this3.trainOpts[opt];
1610+
if (opt === 'timeout' && _this2.trainOpts[opt] === Infinity) return opts;
1611+
if (_this2.trainOpts[opt]) opts[opt] = _this2.trainOpts[opt];
16091612
if (opt === 'log') opts.log = typeof opts.log === 'function';
16101613
return opts;
16111614
}, {});
@@ -1762,7 +1765,7 @@ var NeuralNetwork = function () {
17621765
}, {
17631766
key: 'trainAsync',
17641767
value: function trainAsync(data) {
1765-
var _this4 = this;
1768+
var _this3 = this;
17661769

17671770
var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
17681771

@@ -1778,10 +1781,10 @@ var NeuralNetwork = function () {
17781781

17791782
return new Promise(function (resolve, reject) {
17801783
try {
1781-
var thawedTrain = new _thaw2.default(new Array(_this4.trainOpts.iterations), {
1784+
var thawedTrain = new _thaw2.default(new Array(_this3.trainOpts.iterations), {
17821785
delay: true,
17831786
each: function each() {
1784-
return _this4.trainingTick(data, status, endTime) || thawedTrain.stop();
1787+
return _this3.trainingTick(data, status, endTime) || thawedTrain.stop();
17851788
},
17861789
done: function done() {
17871790
return resolve(status);
@@ -2115,7 +2118,7 @@ var NeuralNetwork = function () {
21152118
}, {
21162119
key: 'test',
21172120
value: function test(data) {
2118-
var _this5 = this;
2121+
var _this4 = this;
21192122

21202123
data = this.formatData(data);
21212124
// for binary classification problems with one output node
@@ -2133,9 +2136,9 @@ var NeuralNetwork = function () {
21332136
var trueNeg = 0;
21342137

21352138
var _loop = function _loop(i) {
2136-
var output = _this5.runInput(data[i].input);
2139+
var output = _this4.runInput(data[i].input);
21372140
var target = data[i].output;
2138-
var actual = output[0] > _this5.binaryThresh ? 1 : 0;
2141+
var actual = output[0] > _this4.binaryThresh ? 1 : 0;
21392142
var expected = target[0];
21402143

21412144
if (actual !== expected) {
@@ -2182,7 +2185,7 @@ var NeuralNetwork = function () {
21822185
}
21832186

21842187
var _loop2 = function _loop2(i) {
2185-
var output = _this5.runInput(data[i].input);
2188+
var output = _this4.runInput(data[i].input);
21862189
var target = data[i].output;
21872190
var actual = output.indexOf((0, _max2.default)(output));
21882191
var expected = target.indexOf((0, _max2.default)(target));
@@ -2302,6 +2305,7 @@ var NeuralNetwork = function () {
23022305
}, {
23032306
key: 'fromJSON',
23042307
value: function fromJSON(json) {
2308+
Object.assign(this, this.constructor.defaults, json);
23052309
this.sizes = json.sizes;
23062310
this.initialize();
23072311

@@ -2326,7 +2330,6 @@ var NeuralNetwork = function () {
23262330
if (json.hasOwnProperty('trainOpts')) {
23272331
this.updateTrainingOptions(json.trainOpts);
23282332
}
2329-
this.setActivation(this.activation || 'sigmoid');
23302333
return this;
23312334
}
23322335

@@ -2397,15 +2400,15 @@ var NeuralNetwork = function () {
23972400
}, {
23982401
key: 'isRunnable',
23992402
get: function get() {
2400-
var _this6 = this;
2403+
var _this5 = this;
24012404

24022405
if (!this.runInput) {
24032406
console.error('Activation function has not been initialized, did you run train()?');
24042407
return false;
24052408
}
24062409

24072410
var checkFns = ['sizes', 'outputLayer', 'biases', 'weights', 'outputs', 'deltas', 'changes', 'errors'].filter(function (c) {
2408-
return _this6[c] === null;
2411+
return _this5[c] === null;
24092412
});
24102413

24112414
if (checkFns.length > 0) {

browser.min.js

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/neural-network.js

Lines changed: 25 additions & 22 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/neural-network.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "brain.js",
33
"description": "Neural network library",
4-
"version": "1.6.0",
4+
"version": "1.6.1",
55
"author": "Heather Arthur <fayearthur@gmail.com>",
66
"repository": {
77
"type": "git",

src/neural-network.js

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ export default class NeuralNetwork {
116116
* @param activation supported inputs: 'sigmoid', 'relu', 'leaky-relu', 'tanh'
117117
*/
118118
setActivation(activation) {
119-
this.activation = (activation) ? activation : this.activation;
119+
this.activation = activation ? activation : this.activation;
120120
switch (this.activation) {
121121
case 'sigmoid':
122122
this.runInput = this.runInput || this._runInputSigmoid;
@@ -305,7 +305,12 @@ export default class NeuralNetwork {
305305
* activation: 'sigmoid', 'relu', 'leaky-relu', 'tanh'
306306
*/
307307
updateTrainingOptions(options) {
308-
Object.keys(this.constructor.trainDefaults).forEach(p => this.trainOpts[p] = (options.hasOwnProperty(p)) ? options[p] : this.trainOpts[p]);
308+
const trainDefaults = this.constructor.trainDefaults;
309+
for (const p in trainDefaults) {
310+
if (!trainDefaults.hasOwnProperty(p)) continue;
311+
if (!options.hasOwnProperty(p)) continue;
312+
this.trainOpts[p] = options[p];
313+
}
309314
this.validateTrainingOptions(this.trainOpts);
310315
this.setLogMethod(options.log || this.trainOpts.log);
311316
this.activation = options.activation || this.activation;
@@ -327,11 +332,13 @@ export default class NeuralNetwork {
327332
callbackPeriod: (val) => { return typeof val === 'number' && val > 0; },
328333
timeout: (val) => { return typeof val === 'number' && val > 0 }
329334
};
330-
Object.keys(this.constructor.trainDefaults).forEach(key => {
331-
if (validations.hasOwnProperty(key) && !validations[key](options[key])) {
332-
throw new Error(`[${key}, ${options[key]}] is out of normal training range, your network will probably not train.`);
335+
for (const p in validations) {
336+
if (!validations.hasOwnProperty(p)) continue;
337+
if (!options.hasOwnProperty(p)) continue;
338+
if (!validations[p](options[p])) {
339+
throw new Error(`[${p}, ${options[p]}] is out of normal training range, your network will probably not train.`);
333340
}
334-
});
341+
}
335342
}
336343

337344
/**
@@ -956,6 +963,7 @@ export default class NeuralNetwork {
956963
* @returns {NeuralNetwork}
957964
*/
958965
fromJSON(json) {
966+
Object.assign(this, this.constructor.defaults, json);
959967
this.sizes = json.sizes;
960968
this.initialize();
961969

@@ -981,7 +989,6 @@ export default class NeuralNetwork {
981989
if (json.hasOwnProperty('trainOpts')) {
982990
this.updateTrainingOptions(json.trainOpts);
983991
}
984-
this.setActivation(this.activation || 'sigmoid');
985992
return this;
986993
}
987994

test/base/json.js

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@ import assert from 'assert';
22
import NeuralNetwork from './../../src/neural-network';
33

44
describe('JSON', () => {
5-
const originalNet = new NeuralNetwork();
6-
7-
let trainingOpts = {
5+
const originalNet = new NeuralNetwork({ activation: 'leaky-relu' });
6+
const trainingOpts = {
87
iterations: 200,
98
errorThresh: 0.05,
109
log: () => {},
@@ -200,7 +199,7 @@ describe('JSON', () => {
200199

201200

202201
describe('default net json', () => {
203-
const originalNet = new NeuralNetwork();
202+
const originalNet = new NeuralNetwork({ activation: 'leaky-relu' });
204203

205204
originalNet.train([
206205
{
@@ -210,7 +209,7 @@ describe('default net json', () => {
210209
input: {'0': Math.random(), b: Math.random()},
211210
output: {c: Math.random(), '0': Math.random()}
212211
}
213-
]);
212+
], { timeout: 4 });
214213

215214
const serialized = originalNet.toJSON();
216215
const serializedNet = new NeuralNetwork()
@@ -276,4 +275,4 @@ describe('default net json', () => {
276275
net.fromJSON({ sizes: [], layers: [] });
277276
assert(net.activation === 'sigmoid');
278277
})
279-
})
278+
})

0 commit comments

Comments
 (0)