MCPcopy Index your code
hub / github.com/pytorch/pytorch / test_if_net

Method test_if_net

caffe2/python/net_builder_test.py:249–311  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

247 self.assertEqual(total3.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))
248
249 def test_if_net(self):
250 with NetBuilder() as nb:
251 x0 = ops.Const(0)
252 x1 = ops.Const(1)
253 x2 = ops.Const(2)
254 y0 = ops.Const(0)
255 y1 = ops.Const(1)
256 y2 = ops.Const(2)
257
258 # basic logic
259 first_res = ops.Const(0)
260 with ops.IfNet(ops.Const(True)):
261 ops.Const(1, blob_out=first_res)
262 with ops.Else():
263 ops.Const(2, blob_out=first_res)
264
265 second_res = ops.Const(0)
266 with ops.IfNet(ops.Const(False)):
267 ops.Const(1, blob_out=second_res)
268 with ops.Else():
269 ops.Const(2, blob_out=second_res)
270
271 # nested and sequential ifs,
272 # empty then/else,
273 # passing outer blobs into branches,
274 # writing into outer blobs, incl. into input blob
275 # using local blobs
276 with ops.IfNet(ops.LT([x0, x1])):
277 local_blob = ops.Const(900)
278 ops.Add([ops.Const(100), local_blob], [y0])
279
280 gt = ops.GT([x1, x2])
281 with ops.IfNet(gt):
282 # empty then
283 pass
284 with ops.Else():
285 ops.Add([y1, local_blob], [local_blob])
286 ops.Add([ops.Const(100), y1], [y1])
287
288 with ops.IfNet(ops.EQ([local_blob, ops.Const(901)])):
289 ops.Const(7, blob_out=y2)
290 ops.Add([y1, y2], [y2])
291 with ops.Else():
292 # empty else
293 pass
294
295 plan = Plan('if_net_test')
296 plan.AddStep(to_execution_step(nb))
297 ws = workspace.C.Workspace()
298 ws.run(plan)
299
300 first_res_value = ws.blobs[str(first_res)].fetch()
301 second_res_value = ws.blobs[str(second_res)].fetch()
302 y0_value = ws.blobs[str(y0)].fetch()
303 y1_value = ws.blobs[str(y1)].fetch()
304 y2_value = ws.blobs[str(y2)].fetch()
305
306 self.assertEqual(first_res_value, 1)

Callers

nothing calls this directly

Calls 11

AddStep (method) · 0.95
NetBuilder (class) · 0.90
Plan (class) · 0.90
to_execution_step (function) · 0.90
Const (method) · 0.80
IfNet (method) · 0.80
Else (method) · 0.45
Add (method) · 0.45
run (method) · 0.45
fetch (method) · 0.45
assertEqual (method) · 0.45

Tested by

no test coverage detected