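Stops the gradient flow through a graph. Args: graph: An instance of `graphs.GraphsTuple` containing `Tensor`s. stop_edges: (bool, default=True) indicates whether to stop gradients for the edges. stop_nodes: (bool, default=True) indicates whether to stop gradients for the nodes. stop_globals: (bool, default=True) indicates whether to stop gradients for the globals. name: (string, optional) A name for the operation.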
(graph,
stop_edges=True,
stop_nodes=True,
stop_globals=True,
name="graph_stop_gradient")
| 417 | |
| 418 | |
def stop_gradient(graph,
                  stop_edges=True,
                  stop_nodes=True,
                  stop_globals=True,
                  name="graph_stop_gradient"):
  """Blocks backpropagation through selected fields of a graph.

  Each enabled field is passed through `tf.stop_gradient`; disabled fields
  are left untouched.

  Args:
    graph: A `graphs.GraphsTuple` whose fields hold `Tensor`s.
    stop_edges: (bool, default=True) whether gradients through the edges
      are stopped.
    stop_nodes: (bool, default=True) whether gradients through the nodes
      are stopped.
    stop_globals: (bool, default=True) whether gradients through the globals
      are stopped.
    name: (string, optional) Name of the enclosing name scope.

  Returns:
    The GraphsTuple obtained by applying `tf.stop_gradient` to every selected
    field of `graph`.

  Raises:
    ValueError: If a field selected for stopping is `None` in `graph`.
      Fields are validated in the order globals, nodes, edges, so the first
      offending field in that order is the one reported.
  """

  def _require_present(field, value):
    # Returns the field name unchanged once its value is known to be usable.
    if value is None:
      raise ValueError(
          "Cannot stop gradient through {0} if {0} are None".format(field))
    return field

  selected = []
  # Attributes are read only for requested fields, and in this fixed order,
  # so the error raised for a graph with several `None` fields is stable.
  if stop_globals:
    selected.append(_require_present(GLOBALS, graph.globals))
  if stop_nodes:
    selected.append(_require_present(NODES, graph.nodes))
  if stop_edges:
    selected.append(_require_present(EDGES, graph.edges))

  with tf.name_scope(name):
    stopped = graph.map(tf.stop_gradient, selected)
  return stopped
| 462 | |
| 463 | |
| 464 | def identity(graph, name="graph_identity"): |