(self, params)
| 67 | return logs |
| 68 | |
def build_inputs(self, params):
    """Build an endless dataset of all-zero (features, label) batches.

    Args:
        params: Input configuration; not read by this implementation.

    Returns:
        A `tf.data.Dataset` that repeats indefinitely. Each element is a
        batch of two examples: float32 zero features of shape (2, 2) and
        int32 zero labels of shape (2, 1).
    """

    def _zero_example(_):
        # The incoming range index is ignored; every example is identical.
        features = tf.zeros(shape=(2,), dtype=tf.float32)
        target = tf.zeros([1], dtype=tf.int32)
        return features, target

    # One-element range, repeated forever, mapped to constant examples.
    examples = tf.data.Dataset.range(1).repeat().map(
        _zero_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Stage order is kept as-is: prefetch single examples, then batch.
    buffered = examples.prefetch(buffer_size=1)
    return buffered.batch(2, drop_remainder=True)
| 81 | |
| 82 | def aggregate_logs(self, state, step_outputs): |
| 83 | if state is None: |