MCPcopy Index your code
hub / github.com/pytorch/pytorch / test_multi_instance

Method test_multi_instance

caffe2/python/net_builder_test.py:209–247  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

207 self.assertEqual(PythonOpStats.num_calls, 256)
208
def test_multi_instance(self):
    """Check shared vs. per-instance counters in a multi-instance Task.

    A single Task is built with ``num_instances=NUM_INSTANCES``. Each
    instance runs a loop of NUM_ITERS iterations that bumps three counters:

    * a counter created in ``ops.task_init()``; the final assertion expects
      it to be shared, reaching NUM_INSTANCES * NUM_ITERS;
    * a counter created in ``ops.task_instance_init()``;
    * a counter created directly in the task body.

    The last two are expected to be thread local. To prove that each one
    stopped at exactly NUM_ITERS (no false sharing between instances), the
    square of each per-instance count is accumulated into a shared counter.
    The sum only equals NUM_INSTANCES * NUM_ITERS ** 2 when every instance
    counted independently.
    """
    NUM_INSTANCES = 10
    NUM_ITERS = 15

    with TaskGroup() as tg:
        with Task(num_instances=NUM_INSTANCES):
            # Counters meant to be visible to every instance.
            with ops.task_init():
                iter_total = ops.CreateCounter([], ['global_counter'])
                squares_via_exit = ops.CreateCounter([], ['global_counter2'])
                squares_via_body = ops.CreateCounter([], ['global_counter3'])

            # Two ways of getting a per-instance counter: explicitly in
            # task_instance_init, or implicitly as a blob of the task body.
            with ops.task_instance_init():
                instance_init_counter = ops.CreateCounter([], ['task_counter'])
            body_counter = ops.CreateCounter([], ['local_counter'])

            with ops.loop(NUM_ITERS):
                ops.CountUp(iter_total)
                ops.CountUp(instance_init_counter)
                ops.CountUp(body_counter)

            # Add count**2 of each per-instance counter to a shared counter.
            # A wrong per-instance count (e.g. from two instances sharing
            # one counter) makes the sum of squares miss the expected value.
            with ops.task_instance_exit():
                n_exit = ops.RetrieveCount(instance_init_counter)
                with ops.loop(ops.Mul([n_exit, n_exit])):
                    ops.CountUp(squares_via_exit)

            # Same check for the body-local counter, done in the task body
            # itself instead of in task_instance_exit.
            n_body = ops.RetrieveCount(body_counter)
            with ops.loop(ops.Mul([n_body, n_body])):
                ops.CountUp(squares_via_body)

            # Runs once for the whole task, after all instances.
            with ops.task_exit():
                out_iters = final_output(ops.RetrieveCount(iter_total))
                out_squares_exit = final_output(
                    ops.RetrieveCount(squares_via_exit))
                out_squares_body = final_output(
                    ops.RetrieveCount(squares_via_body))

    expected_iters = NUM_INSTANCES * NUM_ITERS
    expected_squares = NUM_INSTANCES * (NUM_ITERS ** 2)

    with LocalSession() as session:
        session.run(tg)
        self.assertEqual(out_iters.fetch(), expected_iters)
        self.assertEqual(out_squares_exit.fetch(), expected_squares)
        self.assertEqual(out_squares_body.fetch(), expected_squares)
248
249 def test_if_net(self):
250 with NetBuilder() as nb:

Callers

nothing calls this directly

Calls 12

TaskGroup · class · 0.90
Task · class · 0.90
final_output · function · 0.90
LocalSession · class · 0.90
task_init · method · 0.80
task_instance_init · method · 0.80
task_instance_exit · method · 0.80
task_exit · method · 0.80
loop · method · 0.45
run · method · 0.45
assertEqual · method · 0.45
fetch · method · 0.45

Tested by

no test coverage detected