Represents a framework-agnostic neural network layer in a directed graph.
| 56 | ########################################################################### |
| 57 | |
| 58 | class Node(): |
| 59 | """Represents a framework-agnostic neural network layer in a directed graph.""" |
| 60 | |
def __init__(self, uid, name, op, output_shape=None, params=None):
    """Create a framework-agnostic node for one layer.

    uid: unique ID for the layer that doesn't repeat in the computation graph.
    name: Name to display.
    op: Framework-agnostic operation name.
    output_shape: Optional tuple or list with the layer's output shape.
        Falsy values are stored as given without validation.
    params: Optional dict of layer parameters (the ``title`` property reads
        "kernel_shape" and "stride" from it). A falsy value is replaced by a
        new empty dict.

    Raises:
        TypeError: if ``output_shape`` is truthy but not a tuple or list.
    """
    self.id = uid
    self.name = name  # TODO: clarify the use of op vs name vs title
    self.op = op
    # Repeat count, initialised to 1.
    self.repeat = 1
    # Explicit check rather than ``assert`` so the validation still runs
    # when Python is started with -O (which strips assert statements).
    if output_shape and not isinstance(output_shape, (tuple, list)):
        raise TypeError(
            "output_shape must be a tuple or list but received {}".format(
                type(output_shape)))
    self.output_shape = output_shape
    self.params = params if params else {}
    # Caption override: returned by ``caption`` whenever it is non-empty.
    self._caption = ""
| 77 | |
@property
def title(self):
    """Return the display title for this node.

    The title starts from ``name``, falling back to ``op``. When ``params``
    has "kernel_shape", the kernel dimensions joined with "x" are appended
    directly (e.g. "Conv" + [3, 3] -> "Conv3x3"). When ``params`` has
    "stride", a suffix is appended as follows:

    * a scalar stride, or one whose values are all equal, is collapsed to a
      single value and shown as "/s<value>" unless that value is 1;
    * a stride whose values differ is shown as "/s<stride>";
    * an empty stride adds nothing.
    """
    title = self.name or self.op

    if "kernel_shape" in self.params:
        title += "x".join(map(str, self.params["kernel_shape"]))

    if "stride" in self.params:
        stride = self.params["stride"]
        distinct = np.unique(stride)
        if distinct.size == 1:
            # Uniform stride. np.ndim() == 0 identifies a scalar stride,
            # which cannot be indexed with [0].
            value = stride if np.ndim(stride) == 0 else stride[0]
            if value != 1:
                title += "/s{}".format(str(value))
        elif distinct.size > 1:
            # Non-uniform stride: show it whole. This branch deliberately
            # avoids `stride != 1`, which is ambiguous for NumPy arrays.
            title += "/s{}".format(str(stride))
    return title
| 97 | |
@property
def caption(self):
    """Return the caption text for this node.

    Yields ``_caption`` when it is truthy; otherwise an empty string.
    """
    return self._caption or ""
| 113 | |
| 114 | def __repr__(self): |
| 115 | args = (self.op, self.name, self.id, self.title, self.repeat) |
no outgoing calls
no test coverage detected