(quantizationDtype: 'uint8'|'uint16')
| 440 | }); |
| 441 | |
/**
 * Shared body for the quantized-weights tests.
 *
 * Serves one fake weight file holding two quantized weights (three raw values
 * each), loads both through `tf.io.loadWeights`, and checks the dequantized
 * float32 and int32 tensors that come back.
 *
 * @param quantizationDtype Storage dtype of the quantized values; also picks
 *     whether the fake file is backed by a Uint8Array or a Uint16Array.
 */
const quantizationTest = async (quantizationDtype: 'uint8'|'uint16') => {
  // Raw quantized values: weight0 followed by weight1, in manifest order.
  const rawValues = [0, 48, 255, 0, 48, 255];
  const fakeFileContents = quantizationDtype === 'uint8' ?
      new Uint8Array(rawValues) :
      new Uint16Array(rawValues);
  setupFakeWeightFiles({'./weightfile0': fakeFileContents});

  // Each weight gets its own (identical) quantization spec object.
  const quantizationSpec = () =>
      ({'min': -1, 'scale': 0.1, 'dtype': quantizationDtype});

  const manifest: WeightsManifestConfig = [{
    'paths': ['weightfile0'],
    'weights': [
      {
        'name': 'weight0',
        'dtype': 'float32',
        'shape': [3],
        'quantization': quantizationSpec()
      },
      {
        'name': 'weight1',
        'dtype': 'int32',
        'shape': [3],
        'quantization': quantizationSpec()
      }
    ]
  }];

  const requestedNames = ['weight0', 'weight1'];
  const loaded = await tf.io.loadWeights(manifest, './', requestedNames);

  // Both weights live in the same file, so exactly one fetch should occur.
  const fetchSpy = tf.env().platform.fetch as jasmine.Spy;
  expect(fetchSpy.calls.count()).toBe(1);
  expect(Object.keys(loaded).length).toEqual(requestedNames.length);

  // Expected floats equal min + scale * raw, i.e. -1 + 0.1 * [0, 48, 255].
  const floatWeight = loaded['weight0'];
  expectArraysClose(await floatWeight.data(), [-1, 3.8, 24.5]);
  expect(floatWeight.shape).toEqual([3]);
  expect(floatWeight.dtype).toEqual('float32');

  // The int32 weight is expected to hold those same values rounded to
  // integers.
  const intWeight = loaded['weight1'];
  expectArraysEqual(await intWeight.data(), [-1, 4, 25]);
  expect(intWeight.shape).toEqual([3]);
  expect(intWeight.dtype).toEqual('int32');
};
| 483 | |
| 484 | it('quantized weights (uint8)', async () => { |
| 485 | await quantizationTest('uint8'); |
no test coverage detected
searching dependent graphs…