@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
440 lines • 20.1 kB
JavaScript
;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
Object.defineProperty(exports, "__esModule", { value: true });
var tf = require("../index");
var jasmine_util_1 = require("../jasmine_util");
var indexed_db_1 = require("./indexed_db");
var local_storage_1 = require("./local_storage");
// Disabled for non-Chrome browsers due to:
// https://github.com/tensorflow/tfjs/issues/427
jasmine_util_1.describeWithFlags('ModelManagement', jasmine_util_1.CHROME_ENVS, function () {
// Test data: a minimal Keras Sequential topology with one Dense layer
// (3 inputs -> 1 unit), the matching weight specs, and the weight bytes.
// `artifacts1` bundles them and is what every spec saves.
var modelTopology1 = {
  'class_name': 'Sequential',
  'keras_version': '2.1.4',
  'config': [{
    'class_name': 'Dense',
    'config': {
      'kernel_initializer': {
        'class_name': 'VarianceScaling',
        'config': {
          'distribution': 'uniform',
          'scale': 1.0,
          'seed': null,
          'mode': 'fan_avg'
        }
      },
      'name': 'dense',
      'kernel_constraint': null,
      'bias_regularizer': null,
      'bias_constraint': null,
      'dtype': 'float32',
      'activation': 'linear',
      'trainable': true,
      'kernel_regularizer': null,
      'bias_initializer': { 'class_name': 'Zeros', 'config': {} },
      'units': 1,
      'batch_input_shape': [null, 3],
      'use_bias': true,
      'activity_regularizer': null
    }
  }],
  'backend': 'tensorflow'
};
var weightSpecs1 = [
  {
    name: 'dense/kernel',
    shape: [3, 1],
    dtype: 'float32',
  },
  {
    name: 'dense/bias',
    shape: [1],
    dtype: 'float32',
  }
];
// 16 bytes = 4 float32 values: 3 for 'dense/kernel' (shape [3, 1]) followed by
// 1 for 'dense/bias' (shape [1]). Distinct non-zero values are used so that
// the byte-for-byte round-trip checks in the copy/move specs can actually
// detect corrupted or reordered weights; an all-zero buffer of the right
// length would compare equal no matter what happened to its contents.
var weightData1 = new Float32Array([1, 2, 3, 4]).buffer;
var artifacts1 = {
  modelTopology: modelTopology1,
  weightSpecs: weightSpecs1,
  weightData: weightData1,
};
/**
 * Clears both storage media used by these specs (Local Storage artifacts and
 * the IndexedDB database), then signals completion through Jasmine's `done`.
 *
 * If `deleteDatabase()` rejects, the error is reported with `done.fail` so the
 * hook fails immediately with its cause instead of hanging until the Jasmine
 * async timeout.
 *
 * @param {Function} done Jasmine async-completion callback.
 */
function resetModelStorage(done) {
  local_storage_1.purgeLocalStorageArtifacts();
  indexed_db_1.deleteDatabase()
      .then(function () {
        // Call `done` with no arguments so whatever `deleteDatabase` resolves
        // with is not forwarded to Jasmine.
        done();
      })
      .catch(function (err) {
        done.fail(err);
      });
}
// Every spec starts and ends with empty model storage, so models saved by one
// spec cannot leak into another spec's listing.
beforeEach(resetModelStorage);
afterEach(resetModelStorage);
// TODO(cais): Reenable this test once we fix
// https://github.com/tensorflow/tfjs/issues/1198
// tslint:disable-next-line:ban
xit('List models: 0 result', function (done) {
  // With nothing saved yet, the listing should be an empty mapping.
  var onListed = function (models) {
    expect(models).toEqual({});
    done();
  };
  var onError = function (error) {
    return done.fail(error.stack);
  };
  tf.io.listModels().then(onListed).catch(onError);
});
// TODO(cais): Reenable this test once we fix
// https://github.com/tensorflow/tfjs/issues/1198
// tslint:disable-next-line:ban
xit('List models: 1 result', function (done) {
  var url = 'localstorage://baz/QuxModel';
  var saveHandler = tf.io.getSaveHandlers(url)[0];
  var saved;
  // Save one model, then check that the listing contains exactly that model
  // with the same artifact info that saving reported.
  saveHandler.save(artifacts1)
      .then(function (saveResult) {
        saved = saveResult;
        return tf.io.listModels();
      })
      .then(function (models) {
        expect(Object.keys(models).length).toEqual(1);
        var listed = models[url];
        var expected = saved.modelArtifactsInfo;
        expect(listed.modelTopologyType).toEqual(expected.modelTopologyType);
        expect(listed.modelTopologyBytes).toEqual(expected.modelTopologyBytes);
        expect(listed.weightSpecsBytes).toEqual(expected.weightSpecsBytes);
        expect(listed.weightDataBytes).toEqual(expected.weightDataBytes);
        done();
      })
      .catch(function (err) { return done.fail(err.stack); });
});
// TODO(cais): Reenable this test once we fix
// https://github.com/tensorflow/tfjs/issues/1198
// tslint:disable-next-line:ban
xit('Manager: List models: 2 results in 2 mediums', function (done) {
  var url1 = 'localstorage://QuxModel';
  var url2 = 'indexeddb://QuxModel';
  var firstSave;
  var secondSave;
  // Asserts that a listing entry reports the same artifact info that was
  // returned when the model was saved.
  function expectListedAs(entry, savedInfo) {
    expect(entry.modelTopologyType).toEqual(savedInfo.modelTopologyType);
    expect(entry.modelTopologyBytes).toEqual(savedInfo.modelTopologyBytes);
    expect(entry.weightSpecsBytes).toEqual(savedInfo.weightSpecsBytes);
    expect(entry.weightDataBytes).toEqual(savedInfo.weightDataBytes);
  }
  // Save the same model to Local Storage first, then to IndexedDB; afterwards
  // both copies should be listed.
  var localHandler = tf.io.getSaveHandlers(url1)[0];
  localHandler.save(artifacts1)
      .then(function (result) {
        firstSave = result;
        var indexedDbHandler = tf.io.getSaveHandlers(url2)[0];
        return indexedDbHandler.save(artifacts1);
      })
      .then(function (result) {
        secondSave = result;
        return tf.io.listModels();
      })
      .then(function (models) {
        expect(Object.keys(models).length).toEqual(2);
        expectListedAs(models[url1], firstSave.modelArtifactsInfo);
        expectListedAs(models[url2], secondSave.modelArtifactsInfo);
        done();
      })
      .catch(function (err) { return done.fail(err.stack); });
});
// TODO(cais): Reenable this test once we fix
// https://github.com/tensorflow/tfjs/issues/1198
// tslint:disable-next-line:ban
xit('Successful removeModel', function (done) {
  var localStorageUrl = 'localstorage://QuxModel';
  var indexedDbUrl = 'indexeddb://repeat/QuxModel';
  // Save the same model in both media, then remove the copies one at a time,
  // checking the listing after each removal.
  tf.io.getSaveHandlers(localStorageUrl)[0].save(artifacts1)
      .then(function () {
        return tf.io.getSaveHandlers(indexedDbUrl)[0].save(artifacts1);
      })
      .then(function () {
        // Remove the second (IndexedDB) save. Removing by a path that
        // includes the indexeddb:// scheme prefix should work.
        return tf.io.removeModel(indexedDbUrl);
      })
      .then(function () {
        return tf.io.listModels();
      })
      .then(function (models) {
        // Only the Local Storage copy should remain.
        expect(Object.keys(models)).toEqual([localStorageUrl]);
        // Then remove the remaining model.
        return tf.io.removeModel(localStorageUrl);
      })
      .then(function () {
        return tf.io.listModels();
      })
      .then(function (models) {
        expect(Object.keys(models)).toEqual([]);
        done();
      })
      // One terminal handler reports a failure from any step the same way.
      .catch(function (err) { return done.fail(err.stack); });
});
// TODO(cais): Reenable this test once we fix
// https://github.com/tensorflow/tfjs/issues/1198
// tslint:disable-next-line:ban
xit('Successful copyModel between mediums', function (done) {
  var url1 = 'localstorage://a1/FooModel';
  var url2 = 'indexeddb://a1/FooModel';
  var savedInfo;
  // Asserts that a listing entry reports the same artifact info that was
  // returned when the source model was saved.
  function expectListedAs(entry, expected) {
    expect(entry.modelTopologyType).toEqual(expected.modelTopologyType);
    expect(entry.modelTopologyBytes).toEqual(expected.modelTopologyBytes);
    expect(entry.weightSpecsBytes).toEqual(expected.weightSpecsBytes);
    expect(entry.weightDataBytes).toEqual(expected.weightDataBytes);
  }
  // First, save a model.
  tf.io.getSaveHandlers(url1)[0].save(artifacts1)
      .then(function (saveResult) {
        savedInfo = saveResult.modelArtifactsInfo;
        // Once the model is saved, copy it to the other medium.
        return tf.io.copyModel(url1, url2);
      })
      .then(function () {
        return tf.io.listModels();
      })
      .then(function (models) {
        // Both the source and the copy should be listed with the saved info.
        expect(Object.keys(models).length).toEqual(2);
        expectListedAs(models[url1], savedInfo);
        expectListedAs(models[url2], savedInfo);
        // Load the copy and verify its content.
        return tf.io.getLoadHandlers(url2)[0].load();
      })
      .then(function (loaded) {
        expect(loaded.modelTopology).toEqual(modelTopology1);
        expect(loaded.weightSpecs).toEqual(weightSpecs1);
        expect(new Uint8Array(loaded.weightData))
            .toEqual(new Uint8Array(weightData1));
        done();
      })
      // This single terminal handler also catches exceptions thrown inside the
      // callbacks above (e.g. a missing listing entry), so the spec fails with
      // the cause instead of timing out on an unhandled rejection.
      .catch(function (err) { return done.fail(err.stack); });
});
// TODO(cais): Reenable this test once we fix
// https://github.com/tensorflow/tfjs/issues/1198
// tslint:disable-next-line:ban
xit('Successful moveModel between mediums', function (done) {
  var url1 = 'localstorage://a1/FooModel';
  var url2 = 'indexeddb://a1/FooModel';
  var savedInfo;
  // First, save a model.
  tf.io.getSaveHandlers(url1)[0].save(artifacts1)
      .then(function (saveResult) {
        savedInfo = saveResult.modelArtifactsInfo;
        // Once the model is saved, move it to the other medium.
        return tf.io.moveModel(url1, url2);
      })
      .then(function () {
        return tf.io.listModels();
      })
      .then(function (models) {
        // Only the destination should remain, with the info reported at save
        // time.
        expect(Object.keys(models)).toEqual([url2]);
        var moved = models[url2];
        expect(moved.modelTopologyType).toEqual(savedInfo.modelTopologyType);
        expect(moved.modelTopologyBytes).toEqual(savedInfo.modelTopologyBytes);
        expect(moved.weightSpecsBytes).toEqual(savedInfo.weightSpecsBytes);
        expect(moved.weightDataBytes).toEqual(savedInfo.weightDataBytes);
        // Load the moved model and verify its content.
        return tf.io.getLoadHandlers(url2)[0].load();
      })
      .then(function (loaded) {
        expect(loaded.modelTopology).toEqual(modelTopology1);
        expect(loaded.weightSpecs).toEqual(weightSpecs1);
        expect(new Uint8Array(loaded.weightData))
            .toEqual(new Uint8Array(weightData1));
        done();
      })
      // This single terminal handler also catches exceptions thrown inside the
      // callbacks above (e.g. a missing listing entry), so the spec fails with
      // the cause instead of timing out on an unhandled rejection.
      .catch(function (err) { return done.fail(err.stack); });
});
it('Failed copyModel to invalid source URL', function (done) {
  var sourceUrl = 'invalidurl';
  var destinationUrl = 'localstorage://a1/FooModel';
  // No load handler should be found for the source, so the copy must reject
  // with a message naming the source URL.
  var onUnexpectedSuccess = function () {
    done.fail('Copying from invalid URL succeeded unexpectedly.');
  };
  var onExpectedFailure = function (error) {
    var expectedMessage =
        'Copying failed because no load handler is found for ' +
        'source URL invalidurl.';
    expect(error.message).toEqual(expectedMessage);
    done();
  };
  tf.io.copyModel(sourceUrl, destinationUrl)
      .then(onUnexpectedSuccess)
      .catch(onExpectedFailure);
});
it('Failed copyModel to invalid destination URL', function (done) {
  var sourceUrl = 'localstorage://a1/FooModel';
  var destinationUrl = 'invalidurl';
  // Attempts the copy; no save handler should be found for the destination,
  // so the copy must reject with a message naming the destination URL.
  function attemptCopy() {
    tf.io.copyModel(sourceUrl, destinationUrl)
        .then(function () {
          done.fail('Copying to invalid URL succeeded unexpectedly.');
        })
        .catch(function (error) {
          expect(error.message)
              .toEqual('Copying failed because no save handler is found for ' +
                  'destination URL invalidurl.');
          done();
        });
  }
  // Save a valid source model first, so only the destination is invalid.
  var sourceHandler = tf.io.getSaveHandlers(sourceUrl)[0];
  sourceHandler.save(artifacts1)
      .then(attemptCopy)
      .catch(function (err) { return done.fail(err.stack); });
});
it('Failed moveModel to invalid destination URL', function (done) {
  var url1 = 'localstorage://a1/FooModel';
  var url2 = 'invalidurl';
  // First, save a model.
  var handler1 = tf.io.getSaveHandlers(url1)[0];
  handler1.save(artifacts1)
      .then(function () {
        // Once the model is saved, move it to an invalid path, which should
        // fail.
        tf.io.moveModel(url1, url2)
            .then(function () {
              done.fail('Moving to invalid URL succeeded unexpectedly.');
            })
            .catch(function (moveErr) {
              // The expected message refers to copying: moveModel's
              // rejection here comes from its copy step.
              expect(moveErr.message)
                  .toEqual('Copying failed because no save handler is found for ' +
                      'destination URL invalidurl.');
              // Verify that the source has not been removed.
              tf.io.listModels()
                  .then(function (models) {
                    expect(Object.keys(models)).toEqual([url1]);
                    done();
                  })
                  .catch(function (listErr) { return done.fail(listErr.stack); });
            });
      })
      .catch(function (err) { return done.fail(err.stack); });
});
it('Failed deletedModel: Absent scheme', function (done) {
  // Removing a model by a URL with no scheme ('foo') is expected to fail, and
  // the error should mention the 'localstorage' and 'indexeddb' schemes.
  tf.io.removeModel('foo')
      .then(function () {
        done.fail('Removing model with missing scheme succeeded unexpectedly.');
      })
      .catch(function (err) {
        expect(err.message)
            .toMatch(/The url string provided does not contain a scheme/);
        // toContain matches the substring at any position, whereas the
        // previous `indexOf(...) > 0` check would reject a match at index 0.
        expect(err.message).toContain('localstorage');
        expect(err.message).toContain('indexeddb');
        done();
      });
});
it('Failed deletedModel: Invalid scheme', function (done) {
  // Removing a model under a scheme for which no model manager can be found
  // is expected to fail.
  tf.io.removeModel('invalidscheme://foo')
      .then(function () {
        done.fail('Removing model with invalid scheme succeeded unexpectedly.');
      })
      .catch(function (err) {
        expect(err.message)
            .toEqual('Cannot find model manager for scheme \'invalidscheme\'');
        done();
      });
});
it('Failed deletedModel: Nonexistent model', function (done) {
  // Removing a model that was never saved to IndexedDB must reject with an
  // error naming the missing path.
  var expectedMessage =
      'Cannot find model with path \'nonexistent\' in IndexedDB.';
  var removal = tf.io.removeModel('indexeddb://nonexistent');
  removal
      .then(function () {
        done.fail('Removing nonexistent model succeeded unexpectedly.');
      })
      .catch(function (error) {
        expect(error.message).toEqual(expectedMessage);
        done();
      });
});
it('Failed copyModel', function (done) {
  var missingUrl = 'indexeddb://nonexistent';
  var targetUrl = 'indexeddb://destination';
  // Copying from a source that was never saved is expected to reject with an
  // error naming the missing path.
  var onUnexpectedSuccess = function () {
    done.fail('Copying nonexistent model succeeded unexpectedly.');
  };
  var onExpectedFailure = function (error) {
    expect(error.message)
        .toEqual('Cannot find model with path \'nonexistent\' in IndexedDB.');
    done();
  };
  tf.io.copyModel(missingUrl, targetUrl)
      .then(onUnexpectedSuccess)
      .catch(onExpectedFailure);
});
it('copyModel: Identical oldPath and newPath leads to Error', function (done) {
  // Copying a model onto its own path is expected to reject.
  var samePath = 'a/1';
  tf.io.copyModel(samePath, samePath)
      .then(function () {
        done.fail('Copying with identical old & new paths succeeded unexpectedly.');
      })
      .catch(function (error) {
        var expectedMessage = 'Old path and new path are the same: \'a/1\'';
        expect(error.message).toEqual(expectedMessage);
        done();
      });
});
it('moveModel: Identical oldPath and newPath leads to Error', function (done) {
  // Moving a model onto its own path is expected to reject.
  tf.io.moveModel('a/1', 'a/1')
      .then(function () {
        done.fail('Moving with identical old & new paths succeeded unexpectedly.');
      })
      .catch(function (err) {
        expect(err.message)
            .toEqual('Old path and new path are the same: \'a/1\'');
        done();
      });
});
});
//# sourceMappingURL=model_management_test.js.map