github.com/pytorch/pytorch

Function build_conv_model

caffe2/python/test/executor_test_util.py:70–104
(model_name, batch_size)


def build_conv_model(model_name, batch_size):
    model_gen_map = conv_model_generators()
    assert model_name in model_gen_map, "Model " + model_name + " not found"
    model, input_size = model_gen_map[model_name]("NCHW", None)

    input_shape = [batch_size, 3, input_size, input_size]
    if model_name == "MLP":
        input_shape = [batch_size, input_size]

    model.param_init_net.GaussianFill(
        [],
        "data",
        shape=input_shape,
        mean=0.0,
        std=1.0
    )
    model.param_init_net.UniformIntFill(
        [],
        "label",
        shape=[batch_size, ],
        min=0,
        max=999
    )

    model.AddGradientOperators(["loss"])

    ITER = brew.iter(model, "iter")
    LR = model.net.LearningRate(
        ITER, "LR", base_lr=-1e-8, policy="step", stepsize=10000, gamma=0.999)
    ONE = model.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
    for param in model.params:
        param_grad = model.param_to_grad[param]
        model.net.WeightedSum([param, ONE, param_grad, LR], param)

    return model
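build_conv_model feeds the network synthetic inputs rather than a real dataset: a "data" blob of N(0, 1) samples shaped NCHW as [batch_size, 3, input_size, input_size] (or [batch_size, input_size] when the model is "MLP"), plus one integer label per example between 0 and 999. The pure-Python sketch that follows reproduces only that shape logic and fill. synthetic_batch is a hypothetical helper, Python's random module stands in for the GaussianFill and UniformIntFill operators, and treating 999 as an inclusive upper bound is an assumption.

```python
import random


def synthetic_batch(model_name, batch_size, input_size, seed=0):
    """Hypothetical pure-Python stand-in for the GaussianFill / UniformIntFill
    ops in build_conv_model. Returns (shape, flat_data, labels)."""
    # Conv models take 3-channel NCHW images; "MLP" takes flat feature vectors.
    if model_name == "MLP":
        shape = [batch_size, input_size]
    else:
        shape = [batch_size, 3, input_size, input_size]

    rng = random.Random(seed)
    numel = 1
    for dim in shape:
        numel *= dim

    # Flat buffer of N(0, 1) samples, mirroring GaussianFill(mean=0.0, std=1.0).
    data = [rng.gauss(0.0, 1.0) for _ in range(numel)]
    # One label per example in [0, 999]; the inclusive upper bound is assumed.
    labels = [rng.randint(0, 999) for _ in range(batch_size)]
    return shape, data, labels


shape, data, labels = synthetic_batch("MLP", 4, 8)
print(shape, len(data), len(labels))  # [4, 8] 32 4
```

Because both fills are added to param_init_net rather than model.net, they presumably run once at initialization, so each iteration of model.net reuses the same random batch.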

Callers (1)
test_executor · method · 0.90

Calls (3)
conv_model_generators · function · 0.85
AddGradientOperators · method · 0.45
iter · method · 0.45
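The calls to iter and AddGradientOperators set up a hand-rolled training step in build_conv_model: brew.iter creates an iteration counter, LearningRate turns it into a step-decayed rate, and WeightedSum computes param * ONE + grad * LR for every parameter. Because base_lr is negative (-1e-8), adding grad * LR moves each parameter against its gradient, i.e. plain SGD. The pure-Python sketch that follows reproduces that arithmetic only. step_lr and weighted_sum_update are hypothetical names, and the formula base_lr * gamma ** floor(iter / stepsize) is our reading of Caffe2's "step" policy, not code taken from it.

```python
def step_lr(iteration, base_lr=-1e-8, stepsize=10000, gamma=0.999):
    """Assumed "step" policy: base_lr * gamma ** floor(iteration / stepsize)."""
    return base_lr * gamma ** (iteration // stepsize)


def weighted_sum_update(params, grads, lr):
    """Mimics WeightedSum([param, ONE, grad, LR]) -> param * 1.0 + grad * lr.
    With a negative lr this is an ordinary gradient-descent step."""
    one = 1.0
    return [p * one + g * lr for p, g in zip(params, grads)]


params = [0.5, -0.25]
grads = [2.0, -4.0]
lr = step_lr(0)  # stays -1e-8 until iteration 10000, then decays by 0.999
new_params = weighted_sum_update(params, grads, lr)
```

Writing the update as a WeightedSum with a negative rate lets a single operator perform the SGD step in place, without a separate optimizer.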

Tested by

no test coverage detected
