| 72 | class NetGradientChecker: |
@staticmethod
def CompareNets(nets, outputs, outputs_with_grad_ids,
                inputs_with_grads, input_values=None,
                threshold=0.0000001, print_net_images=False):
    """Check that several nets agree on forward outputs and gradients.

    Every net is run through ``_get_grad``. The first net's forward
    results and gradients serve as the reference. Every other net's
    results are compared against the reference with ``_assert_close``.

    Args:
        nets: Nets to compare. Each must provide ``Name()`` (used in the
            PNG file names) and be accepted by ``_get_grad``.
        outputs: One list of output blobs per net, paired with ``nets``
            by position.
        outputs_with_grad_ids: Indices into each net's entry of
            ``outputs``. The selected blobs are passed to ``_get_grad``
            as the outputs that receive gradients.
        inputs_with_grads: Forwarded unchanged to ``_get_grad`` for every
            net. Presumably these are the inputs whose gradients get
            compared (confirm against ``_get_grad``).
        input_values: Forwarded unchanged to ``_get_grad``.
        threshold: Tolerance forwarded to ``_assert_close``.
        print_net_images: If true, write a PNG of every forward net and of
            every backward net (the third value returned by ``_get_grad``)
            into the current working directory.

    Raises:
        ValueError: If ``nets`` is empty, or if ``nets`` and ``outputs``
            have different lengths.
        AssertionError: If a net produces a different number of forward
            results, or a different set of gradient blobs, than the first
            net. Value mismatches are reported by ``_assert_close``.
    """
    def _get_output_with_grad_names(net_outputs):
        # Select the outputs that should receive gradients, by index.
        return [net_outputs[i] for i in outputs_with_grad_ids]

    def _write_png(net, filename):
        # Render the net graph and dump it to disk for visual debugging.
        png = net_drawer.GetPydotGraph(net).create_png()
        with open(filename, 'wb') as f:
            f.write(png)

    if not nets:
        raise ValueError("CompareNets needs at least one net")
    # Without this check, zip() below would silently drop unmatched
    # nets or output lists.
    if len(nets) != len(outputs):
        raise ValueError(
            "Got {} nets but {} output lists".format(
                len(nets), len(outputs)))

    if print_net_images:
        for i, net in enumerate(nets):
            _write_png(
                net, "caffe2_net_forward_" + str(i) + net.Name() + ".png")

    # Each entry is a (forward_results, grads, backward_net) triple.
    results = [
        _get_grad(net, net_outputs,
                  _get_output_with_grad_names(net_outputs),
                  input_values, inputs_with_grads)
        for net, net_outputs in zip(nets, outputs)
    ]

    if print_net_images:
        _, _, backward_nets = zip(*results)
        for i, net in enumerate(backward_nets):
            _write_png(net, "caffe2_net_" + str(i) + net.Name() + ".png")

    first_net_results, first_net_grads, _ = results[0]
    for net_idx, (net_results, net_grads, _) in enumerate(results[1:], 1):
        # Explicit raise instead of `assert`, so the check also runs
        # under `python -O`.
        if len(net_results) != len(first_net_results):
            raise AssertionError(
                "Net {} produced {} forward results, expected {}".format(
                    net_idx, len(net_results), len(first_net_results)))
        for idx, ((blob1, blob_value1), (blob2, blob_value2)) in enumerate(
                zip(first_net_results, net_results)):
            _assert_close(
                blob_value1, blob_value2, threshold,
                err_msg="Different forward pass results for output id {}. "
                "Corresponding output blobs: {} and {}".format(
                    idx, blob1, blob2))

        if set(net_grads) != set(first_net_grads):
            raise AssertionError(
                "Net {} computed gradients for {}, but net 0 computed "
                "gradients for {}".format(
                    net_idx, list(net_grads), list(first_net_grads)))
        for blob, blob_grad_value in net_grads.items():
            _assert_close(
                first_net_grads[blob], blob_grad_value, threshold,
                err_msg="Different gradients for input {}".format(blob))
| 120 | @staticmethod |
| 121 | def Check(net, outputs_with_grad, input_values, |