MCP · Copy · Index your code
hub / github.com/pytorch/pytorch / CompareNets

Method CompareNets

caffe2/python/gradient_checker.py:74–118  ·  view source on GitHub ↗
(nets, outputs, outputs_with_grad_ids,
                    inputs_with_grads, input_values=None,
                    threshold=0.0000001, print_net_images=False)

Source from the content-addressed store, hash-verified

72class NetGradientChecker:
@staticmethod
def CompareNets(nets, outputs, outputs_with_grad_ids,
                inputs_with_grads, input_values=None,
                threshold=0.0000001, print_net_images=False):
    """Check that several nets compute the same outputs and gradients.

    Every net is run through ``_get_grad``. Its forward results and its
    gradients are then compared against those of the first net (the
    reference) with ``_assert_close``.

    Args:
        nets: Nets to compare; the first one is the reference.
        outputs: One list of output blob names per net, in the same order
            as ``nets``.
        outputs_with_grad_ids: Indices into each net's output list that
            select the outputs passed to ``_get_grad`` as the outputs with
            gradients.
        inputs_with_grads: Forwarded to ``_get_grad``. The compared
            gradient dicts are presumably keyed by these inputs (see the
            error message "Different gradients for input ...").
        input_values: Forwarded to ``_get_grad`` (default ``None``).
        threshold: Tolerance forwarded to ``_assert_close``.
        print_net_images: If true, write a PNG rendering of every forward
            net (``caffe2_net_forward_<i><name>.png``) and of every backward
            net returned by ``_get_grad`` (``caffe2_net_<i><name>.png``)
            into the current working directory.

    Raises:
        ValueError: If ``nets`` is empty, or if ``nets`` and ``outputs``
            have different lengths.
        AssertionError: If a net produces a different number of forward
            results, or gradients for a different set of blobs, than the
            first net. Value mismatches are reported by ``_assert_close``
            (presumably also as AssertionError; verify against its
            definition).
    """
    # Materialize the inputs. ``nets`` may be iterated twice (image dump,
    # then _get_grad) and ``outputs_with_grad_ids`` once per net, so a
    # one-shot iterator would otherwise be silently exhausted.
    nets = list(nets)
    outputs = list(outputs)
    outputs_with_grad_ids = list(outputs_with_grad_ids)

    # zip() would silently drop unmatched nets, so those nets would never
    # be compared.
    if len(nets) != len(outputs):
        raise ValueError(
            "Expected one output list per net, got {} nets and {} output "
            "lists".format(len(nets), len(outputs)))
    if not nets:
        raise ValueError("CompareNets requires at least one net")

    def _get_output_with_grad_names(net_outputs):
        # Select the outputs that gradients are requested for.
        return [net_outputs[i] for i in outputs_with_grad_ids]

    def _save_net_images(nets_to_draw, prefix):
        # Write one PNG per net, named <prefix><index><net name>.png.
        for i, net in enumerate(nets_to_draw):
            png = net_drawer.GetPydotGraph(net).create_png()
            with open(prefix + str(i) + net.Name() + ".png", 'wb') as f:
                f.write(png)

    if print_net_images:
        _save_net_images(nets, "caffe2_net_forward_")

    # Each entry is (forward results, gradients, backward net), where the
    # forward results are (blob, value) pairs and the gradients are a
    # mapping from blob to gradient value.
    results = [
        _get_grad(net, net_outputs,
                  _get_output_with_grad_names(net_outputs),
                  input_values, inputs_with_grads)
        for net, net_outputs in zip(nets, outputs)
    ]

    if print_net_images:
        _save_net_images([backward_net for _, _, backward_net in results],
                         "caffe2_net_")

    first_net_results, first_net_grads, _ = results[0]
    for net_idx, (net_results, net_grads, _) in enumerate(results[1:],
                                                          start=1):
        # Explicit raises (not assert) so the check survives `python -O`.
        if len(net_results) != len(first_net_results):
            raise AssertionError(
                "Net {} produced {} forward results, expected {}".format(
                    net_idx, len(net_results), len(first_net_results)))
        for idx, ((blob1, blob_value1), (blob2, blob_value2)) in enumerate(
                zip(first_net_results, net_results)):
            _assert_close(
                blob_value1, blob_value2, threshold,
                err_msg="Different forward pass results for output id {}. "
                "Corresponding output blobs: {} and {}".format(
                    idx, blob1, blob2))

        # Compare as sets so the check does not depend on key ordering.
        if set(net_grads.keys()) != set(first_net_grads.keys()):
            raise AssertionError(
                "Net {} computed gradients for {}, expected {}".format(
                    net_idx, list(net_grads.keys()),
                    list(first_net_grads.keys())))
        for blob, blob_grad_value in net_grads.items():
            _assert_close(
                first_net_grads[blob], blob_grad_value, threshold,
                err_msg="Different gradients for input {}".format(blob))
119
120 @staticmethod
121 def Check(net, outputs_with_grad, input_values,

Callers 4

test_net_comparison (Method) · 0.80
test_unroll_mul (Method) · 0.80
test_unroll_lstm (Method) · 0.80
test_unroll_attention (Method) · 0.80

Calls 7

_get_grad (Function) · 0.85
_assert_close (Function) · 0.85
Name (Method) · 0.45
write (Method) · 0.45
format (Method) · 0.45
keys (Method) · 0.45
items (Method) · 0.45

Tested by 4

test_net_comparison (Method) · 0.64
test_unroll_mul (Method) · 0.64
test_unroll_lstm (Method) · 0.64
test_unroll_attention (Method) · 0.64