Function stop_gradient

graph_nets/utils_tf.py, lines 419–461, in github.com/google-deepmind/graph_nets

Stops the gradient flow through the selected feature fields (edges, nodes, globals) of a `graphs.GraphsTuple`.
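In TensorFlow, `tf.stop_gradient` returns its input unchanged in the forward pass but contributes no gradient in the backward pass, and this function applies it field by field. To show that behavior without TensorFlow, here is a sketch that swaps in a toy forward-mode autodiff based on dual numbers. The `Dual` class and the local `stop_gradient` helper are hypothetical illustrations, not graph_nets or TensorFlow APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Toy forward-mode autodiff value: a primal value plus its derivative."""
    value: float
    grad: float

    def __mul__(self, other: "Dual") -> "Dual":
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.value * other.value,
                    self.grad * other.value + self.value * other.grad)


def stop_gradient(x: Dual) -> Dual:
    """Hypothetical toy analogue of tf.stop_gradient: same value, derivative zeroed."""
    return Dual(x.value, 0.0)


x = Dual(3.0, 1.0)               # seed dx/dx = 1
full = x * x                     # d(x^2)/dx = 2x = 6
partial = x * stop_gradient(x)   # second factor treated as a constant, so derivative = x = 3
print(full.grad, partial.grad)   # 6.0 3.0
```

Both results have the same forward value (9.0); only the derivative changes, which is the property the graph-level function relies on.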

```python
stop_gradient(graph,
              stop_edges=True,
              stop_nodes=True,
              stop_globals=True,
              name="graph_stop_gradient")
```
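The three flags select which feature fields are frozen, and any selected field that is `None` raises a `ValueError`. A minimal pure-Python sketch of that selection step follows, assuming a hypothetical `ToyGraph` stand-in for `graphs.GraphsTuple` and plain strings assumed to match `graphs.NODES`, `graphs.EDGES`, and `graphs.GLOBALS`; `select_fields` is an illustrative helper, not part of graph_nets.

```python
from typing import Any, List, NamedTuple, Optional

# Field-name strings, assumed to match graph_nets.graphs.NODES/EDGES/GLOBALS.
NODES, EDGES, GLOBALS = "nodes", "edges", "globals"


class ToyGraph(NamedTuple):
    """Hypothetical stand-in for graphs.GraphsTuple holding only feature fields."""
    nodes: Optional[Any]
    edges: Optional[Any]
    globals: Optional[Any]


def select_fields(graph, stop_edges=True, stop_nodes=True,
                  stop_globals=True) -> List[str]:
    """Mirrors the validation in stop_gradient: returns the fields to stop,
    checked in the order globals, nodes, edges, raising ValueError if a
    requested field is None."""
    base_err_msg = "Cannot stop gradient through {0} if {0} are None"
    fields = []
    for flag, field in ((stop_globals, GLOBALS), (stop_nodes, NODES),
                        (stop_edges, EDGES)):
        if flag:
            if getattr(graph, field) is None:
                raise ValueError(base_err_msg.format(field))
            fields.append(field)
    return fields


g = ToyGraph(nodes=[1.0, 2.0], edges=None, globals=[0.5])
print(select_fields(g, stop_edges=False))  # ['globals', 'nodes']
try:
    select_fields(g)  # edges are requested by default but are None
except ValueError as err:
    print(err)  # Cannot stop gradient through edges if edges are None
```

A graph without edge features therefore needs `stop_edges=False` to be accepted.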


```python
# Module-level names this function relies on, assumed to be defined at the
# top of graph_nets/utils_tf.py (shown here only so the excerpt reads standalone).
import tensorflow as tf
from graph_nets import graphs

NODES = graphs.NODES
EDGES = graphs.EDGES
GLOBALS = graphs.GLOBALS


def stop_gradient(graph,
                  stop_edges=True,
                  stop_nodes=True,
                  stop_globals=True,
                  name="graph_stop_gradient"):
  """Stops the gradient flow through a graph.

  Args:
    graph: An instance of `graphs.GraphsTuple` containing `Tensor`s.
    stop_edges: (bool, default=True) indicates whether to stop gradients for
      the edges.
    stop_nodes: (bool, default=True) indicates whether to stop gradients for
      the nodes.
    stop_globals: (bool, default=True) indicates whether to stop gradients for
      the globals.
    name: (string, optional) A name for the operation.

  Returns:
    GraphsTuple after stopping the gradients according to the provided
    parameters.

  Raises:
    ValueError: If attempting to stop gradients through a field which has a
      `None` value in `graph`.
  """

  base_err_msg = "Cannot stop gradient through {0} if {0} are None"
  # Validate and collect the requested fields (checked as globals, nodes, edges).
  fields_to_stop = []
  if stop_globals:
    if graph.globals is None:
      raise ValueError(base_err_msg.format(GLOBALS))
    fields_to_stop.append(GLOBALS)
  if stop_nodes:
    if graph.nodes is None:
      raise ValueError(base_err_msg.format(NODES))
    fields_to_stop.append(NODES)
  if stop_edges:
    if graph.edges is None:
      raise ValueError(base_err_msg.format(EDGES))
    fields_to_stop.append(EDGES)

  # map applies tf.stop_gradient only to the selected fields; the remaining
  # fields of the GraphsTuple are passed through.
  with tf.name_scope(name):
    return graph.map(tf.stop_gradient, fields_to_stop)
```
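The last line delegates to `graph.map(tf.stop_gradient, fields_to_stop)`. The sketch below assumes `map` applies the function to each named field and leaves every other field untouched; it uses a hypothetical `ToyGraphsTuple` built on `collections.namedtuple` and a hypothetical `mark_stopped` function in place of `tf.stop_gradient`, so the affected fields are visible. It is not the real `GraphsTuple` class, which has more fields and methods.

```python
import collections

# Hypothetical minimal stand-in for graphs.GraphsTuple; the real class also has
# receivers, n_node and n_edge, among others.
_Base = collections.namedtuple("_Base", ["nodes", "edges", "globals", "senders"])


class ToyGraphsTuple(_Base):

    def map(self, field_fn, fields=("nodes", "edges", "globals")):
        """Applies field_fn to each named field; other fields are left as-is."""
        return self._replace(**{k: field_fn(getattr(self, k)) for k in fields})


def mark_stopped(x):
    """Hypothetical stand-in for tf.stop_gradient that tags the value it sees."""
    return ("stopped", x)


g = ToyGraphsTuple(nodes=[1.0, 2.0], edges=[0.5], globals=[3.0], senders=[0])
# Equivalent of stop_gradient(g, stop_edges=False): globals and nodes only.
out = g.map(mark_stopped, ["globals", "nodes"])
print(out.nodes, out.edges, out.senders)  # ('stopped', [1.0, 2.0]) [0.5] [0]
```

Because only the listed feature fields are rewritten, the graph topology (`senders` here) always comes back unchanged.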

Callers: nothing calls this directly.

Calls (1): `map` (method) · 0.80

Tested by: no test coverage detected.