(self)
| 207 | self.assertEqual(PythonOpStats.num_calls, 256) |
| 208 | |
| 209 | def test_multi_instance(self): |
| 210 | NUM_INSTANCES = 10 |
| 211 | NUM_ITERS = 15 |
| 212 | with TaskGroup() as tg: |
| 213 | with Task(num_instances=NUM_INSTANCES): |
| 214 | with ops.task_init(): |
| 215 | counter1 = ops.CreateCounter([], ['global_counter']) |
| 216 | counter2 = ops.CreateCounter([], ['global_counter2']) |
| 217 | counter3 = ops.CreateCounter([], ['global_counter3']) |
| 218 | # task_counter and local_counter should both be thread local, i.e. one counter per task instance rather than one shared by all NUM_INSTANCES instances |
| 219 | with ops.task_instance_init(): |
| 220 | task_counter = ops.CreateCounter([], ['task_counter']) |
| 221 | local_counter = ops.CreateCounter([], ['local_counter']) |
| 222 | with ops.loop(NUM_ITERS): |
| 223 | ops.CountUp(counter1) |
| 224 | ops.CountUp(task_counter) |
| 225 | ops.CountUp(local_counter) |
| 226 | # Accumulate the sum of squares of the per-instance counts (into counter2) to check that |
| 227 | # each local counter counted exactly up to NUM_ITERS, and |
| 228 | # that there was no false sharing of counter instances (see the NUM_ITERS ** 2 assertions below). |
| 229 | with ops.task_instance_exit(): |
| 230 | count2 = ops.RetrieveCount(task_counter) |
| 231 | with ops.loop(ops.Mul([count2, count2])): |
| 232 | ops.CountUp(counter2) |
| 233 | # This should have the same effect as the above, but reads local_counter and accumulates into counter3 |
| 234 | count3 = ops.RetrieveCount(local_counter) |
| 235 | with ops.loop(ops.Mul([count3, count3])): |
| 236 | ops.CountUp(counter3) |
| 237 | # The code below should run only once; it exposes the three counter totals via final_output |
| 238 | with ops.task_exit(): |
| 239 | total1 = final_output(ops.RetrieveCount(counter1)) |
| 240 | total2 = final_output(ops.RetrieveCount(counter2)) |
| 241 | total3 = final_output(ops.RetrieveCount(counter3)) |
| 242 | |
| 243 | with LocalSession() as session: |
| 244 | session.run(tg) |
| 245 | self.assertEqual(total1.fetch(), NUM_INSTANCES * NUM_ITERS) |
| 246 | self.assertEqual(total2.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2)) |
| 247 | self.assertEqual(total3.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2)) |
| 248 | |
| 249 | def test_if_net(self): |
| 250 | with NetBuilder() as nb: |
nothing calls this directly
no test coverage detected