MCPcopy
hub / github.com/tensorflow/tfjs / fitCallbacks

Function fitCallbacks

tfjs-vis/src/show/history.ts:254–307  ·  view source on GitHub ↗
(
    container: Drawable, metrics: string[],
    opts: FitCallbackOptions = {})

Source from the content-addressed store, hash-verified

252 * 'show'}
253 */
export function fitCallbacks(
    container: Drawable, metrics: string[],
    opts: FitCallbackOptions = {}): FitCallbackHandlers {
  // Builds one async handler per requested callback name. Every handler
  // appends the metrics found in the incoming log to per-callback,
  // per-metric accumulators and redraws the history chart for that callback.
  //
  // container: where to draw; resolved once through getDrawArea.
  // metrics:   names of the log entries to plot; entries missing from a
  //            given log are skipped for that call.
  // opts:      forwarded to history() minus `callbacks`; `opts.callbacks`
  //            selects which handlers to create (defaults to onEpochEnd and
  //            onBatchEnd).
  // Returns an object mapping each callback name to its handler.
  const accumulators: FitCallbackLogs = {};
  const callbackNames = opts.callbacks || ['onEpochEnd', 'onBatchEnd'];
  const drawArea = getDrawArea(container);

  // Options shared by all handlers. `callbacks` is consumed here and is not
  // forwarded to history().
  const baseHistoryOpts = Object.assign({}, opts);
  delete baseHistoryOpts.callbacks;

  function makeCallbackFor(callbackName: string) {
    // Each handler owns its copy of the options. With a single shared
    // object, a handler whose name matches neither pattern below would pick
    // up whatever xLabel the most recently invoked handler had written.
    const historyOpts = Object.assign({}, baseHistoryOpts);

    // Set a nicer x axis name where possible. The label depends only on the
    // callback name, so it is decided once rather than on every invocation.
    if ((/batch/i).test(callbackName)) {
      historyOpts.xLabel = 'Batch';
    } else if ((/epoch/i).test(callbackName)) {
      historyOpts.xLabel = 'Epoch';
    }

    return async (_: number, log: Logs) => {
      // Because of how the _ (iteration) numbers are given in the layers api
      // we have to store each metric for each callback in different arrays
      // else we cannot get accurate 'global' batch numbers for onBatchEnd.
      //
      // However at render time we want to be able to combine metrics for a
      // given callback. So here we make a nested list of metrics, the first
      // level are arrays for each callback, the second level contains arrays
      // (of logs) for each metric within that callback.
      const metricLogs: Logs[][] = [];
      const presentMetrics: string[] = [];
      for (const metric of metrics) {
        // Not all logs have all kinds of metrics.
        if (log[metric] != null) {
          presentMetrics.push(metric);

          const accumulator =
              getAccumulator(accumulators, callbackName, metric);
          accumulator.push({[metric]: log[metric]});
          metricLogs.push(accumulator);
        }
      }

      // Each callback renders into its own sub-surface, keyed and titled by
      // the callback name.
      const subContainer =
          subSurface(drawArea, callbackName, {title: callbackName});
      // NOTE(review): the value returned by history() is not awaited, as in
      // the previous version; confirm whether render failures should
      // propagate to the caller.
      history(subContainer, metricLogs, presentMetrics, historyOpts);
      await nextFrame();
    };
  }

  const callbacks: FitCallbackHandlers = {};
  callbackNames.forEach((name: string) => {
    callbacks[name] = makeCallbackFor(name);
  });
  return callbacks;
}
308
309interface FitCallbackHandlers {
310 [key: string]: (iteration: number, log: Logs) => Promise<void>;

Callers 1

history_test.ts — File · 0.90

Calls 3

getDrawArea — Function · 0.90
makeCallbackFor — Function · 0.85
assign — Method · 0.80

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…