Repository: github.com/snap-stanford/GraphGym

Function create_model

Defined in graphgym/model_builder.py, lines 13–31.

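Create a model for graph machine learning. The model class is looked up from cfg.model.type and, by default, moved to the device named by cfg.device.

Args:
  to_device (bool, optional): Whether to move the model to cfg.device. Default: True.
  dim_in (int, optional): Input dimension of the model. Falls back to cfg.share.dim_in.
  dim_out (int, optional): Output dimension of the model. Falls back to cfg.share.dim_out.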

Signature: create_model(to_device=True, dim_in=None, dim_out=None)

```python
# Module-level context used by this function:
import torch

from graphgym.config import cfg

# network_dict (defined at module level in model_builder.py) maps
# cfg.model.type strings to model classes.


def create_model(to_device=True, dim_in=None, dim_out=None):
    r"""
    Create model for graph machine learning

    Args:
        to_device (bool, optional): Whether to move the model to the
            device given by cfg.device
        dim_in (int, optional): Input dimension to the model
        dim_out (int, optional): Output dimension to the model
    """
    dim_in = cfg.share.dim_in if dim_in is None else dim_in
    dim_out = cfg.share.dim_out if dim_out is None else dim_out
    # binary classification, output dim = 1
    if 'classification' in cfg.dataset.task_type and dim_out == 2:
        dim_out = 1

    model = network_dict[cfg.model.type](dim_in=dim_in, dim_out=dim_out)
    if to_device:
        model.to(torch.device(cfg.device))
    return model
```
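The dimension handling in create_model can be isolated and exercised without GraphGym installed. The sketch below is a minimal reimplementation under stated assumptions: `cfg` is replaced by a hypothetical `SimpleNamespace` stand-in, and `resolve_dims` is a hypothetical helper name. It copies only the default-fallback and two-class collapse logic, not model construction or device placement.

```python
from types import SimpleNamespace

# Hypothetical stand-in for GraphGym's global cfg; only the fields that
# create_model reads for dimension handling are modeled here.
cfg = SimpleNamespace(
    share=SimpleNamespace(dim_in=16, dim_out=2),
    dataset=SimpleNamespace(task_type='classification'),
)


def resolve_dims(cfg, dim_in=None, dim_out=None):
    """Mirror create_model's dimension logic (a sketch, not GraphGym code)."""
    # Explicit arguments win; otherwise fall back to the shared config values.
    dim_in = cfg.share.dim_in if dim_in is None else dim_in
    dim_out = cfg.share.dim_out if dim_out is None else dim_out
    # A two-class classification output is collapsed to a single unit,
    # using the same substring test on task_type as create_model.
    if 'classification' in cfg.dataset.task_type and dim_out == 2:
        dim_out = 1
    return dim_in, dim_out


print(resolve_dims(cfg))             # (16, 1): two classes collapsed to one output
print(resolve_dims(cfg, dim_out=5))  # (16, 5): multiclass output left unchanged
```

Because the check is a substring test, any task_type containing 'classification' with a two-unit output is collapsed, while a regression task with dim_out == 2 keeps both outputs. A single output unit is presumably intended to pair with a sigmoid-based binary loss rather than a two-way softmax.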

Callers (2)

- get_stats (function) · 0.90
- main.py (file) · 0.90

Calls

no outgoing calls detected by static analysis
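The static-analysis result likely understates what create_model does at runtime: the function never names a model class directly but looks one up in network_dict using the string cfg.model.type, so the constructor it calls is only known once the config is loaded. The sketch below illustrates that registry-dispatch pattern. `ToyGNN`, `network_registry`, and `build_from_config` are hypothetical names, and the friendlier `ValueError` for unknown types is an addition of this sketch; create_model itself would raise a plain `KeyError`.

```python
from types import SimpleNamespace


class ToyGNN:
    """Hypothetical model class standing in for a real GNN."""

    def __init__(self, dim_in, dim_out):
        self.dim_in = dim_in
        self.dim_out = dim_out


# Hypothetical registry mapping a config string to a constructor, in the
# same spirit as network_dict[cfg.model.type] in create_model.
network_registry = {'toy_gnn': ToyGNN}


def build_from_config(cfg, dim_in, dim_out):
    # The class is chosen at runtime from a config string, so a static
    # call-graph tool sees a dict lookup rather than a direct constructor call.
    try:
        model_cls = network_registry[cfg.model.type]
    except KeyError:
        raise ValueError(f'unknown model type: {cfg.model.type!r}') from None
    return model_cls(dim_in=dim_in, dim_out=dim_out)


cfg = SimpleNamespace(model=SimpleNamespace(type='toy_gnn'))
model = build_from_config(cfg, dim_in=16, dim_out=1)
print(type(model).__name__, model.dim_in, model.dim_out)  # ToyGNN 16 1
```

The registry keeps model selection in configuration files: adding an architecture means registering a new class under a new key, with no change to the builder function.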

Tested by

no test coverage detected