MCPcopy
hub / github.com/tensorflow/tfjs / quantizationTest

Function quantizationTest

tfjs-core/src/io/weights_loader_test.ts:442–482  ·  view source on GitHub ↗
(quantizationDtype: 'uint8'|'uint16')

Source from the content-addressed store, hash-verified

440 });
441
/**
 * Shared body for the quantized-weights specs.
 *
 * Serves one fake weight file whose six stored values back two 3-element
 * weights: `weight0` (float32) and `weight1` (int32). Both weights are
 * declared with min = -1 and scale = 0.1. The helper then checks the values,
 * shapes and dtypes that `tf.io.loadWeights` returns.
 *
 * @param quantizationDtype Storage dtype of the quantized values in the
 *     fake weight file; selects Uint8Array or Uint16Array for the buffer.
 */
const quantizationTest = async (quantizationDtype: 'uint8'|'uint16') => {
  const StorageArray =
      quantizationDtype === 'uint8' ? Uint8Array : Uint16Array;
  // The same stored triple is used twice: the first copy backs 'weight0',
  // the second copy backs 'weight1'.
  const storedTriple = [0, 48, 255];
  setupFakeWeightFiles({
    './weightfile0': new StorageArray([...storedTriple, ...storedTriple])
  });

  const manifest: WeightsManifestConfig = [{
    'paths': ['weightfile0'],
    'weights': [
      {
        'name': 'weight0',
        'dtype': 'float32',
        'shape': [3],
        'quantization': {'min': -1, 'scale': 0.1, 'dtype': quantizationDtype}
      },
      {
        'name': 'weight1',
        'dtype': 'int32',
        'shape': [3],
        'quantization': {'min': -1, 'scale': 0.1, 'dtype': quantizationDtype}
      }
    ]
  }];

  const requestedNames = ['weight0', 'weight1'];
  const weights = await tf.io.loadWeights(manifest, './', requestedNames);

  // Both weights live in a single path, so exactly one fetch is expected.
  const fetchSpy = tf.env().platform.fetch as jasmine.Spy;
  expect(fetchSpy.calls.count()).toBe(1);
  expect(Object.keys(weights).length).toEqual(requestedNames.length);

  const {weight0, weight1} = weights;

  // float32: min + stored * scale -> [-1, 3.8, 24.5].
  expectArraysClose(await weight0.data(), [-1, 3.8, 24.5]);
  expect(weight0.shape).toEqual([3]);
  expect(weight0.dtype).toEqual('float32');

  // int32: the expected values are the float results rounded
  // (3.8 -> 4, 24.5 -> 25).
  expectArraysEqual(await weight1.data(), [-1, 4, 25]);
  expect(weight1.shape).toEqual([3]);
  expect(weight1.dtype).toEqual('int32');
};
483
484 it('quantized weights (uint8)', async () => {
485 await quantizationTest('uint8');

Callers 1

Calls 5

expectArraysClose · Function · 0.90
expectArraysEqual · Function · 0.90
setupFakeWeightFiles · Function · 0.70
data · Method · 0.65
loadWeights · Method · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…