hub / github.com/hpcaitech/ColossalAI / assert_forward_equal

Function assert_forward_equal

tests/test_lazy/lazy_init_utils.py:41–66
(
    m1: torch.nn.Module,
    m2: torch.nn.Module,
    data_gen_fn: Callable[[], dict],
    output_transform_fn: Callable[[Any], dict],
)

Source from the content-addressed store, hash-verified

def assert_forward_equal(
    m1: torch.nn.Module,
    m2: torch.nn.Module,
    data_gen_fn: Callable[[], dict],
    output_transform_fn: Callable[[Any], dict],
) -> None:
    data = data_gen_fn()

    m1.eval()
    m2.eval()
    # run forward
    with torch.no_grad():
        outputs1 = m1(**data)
        outputs2 = m2(**data)

    # compare output
    transformed_out1 = output_transform_fn(outputs1)
    transformed_out2 = output_transform_fn(outputs2)

    assert len(transformed_out1) == len(transformed_out2)

    for key, out1 in transformed_out1.items():
        out2 = transformed_out2[key]
        assert torch.allclose(
            out1, out2, atol=1e-5
        ), f"{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}"
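
The listing above can be exercised on a toy model. The following is a minimal sketch, not taken from the ColossalAI test suite: `TinyModel`, the `"x"` input key, the `"output"` result key, and the bias perturbation are hypothetical names and choices made for illustration, and the function body is repeated here only so the snippet runs without ColossalAI installed.

```python
import copy
from typing import Any, Callable

import torch


# Same logic as the listing above, repeated so this sketch is self-contained.
def assert_forward_equal(
    m1: torch.nn.Module,
    m2: torch.nn.Module,
    data_gen_fn: Callable[[], dict],
    output_transform_fn: Callable[[Any], dict],
) -> None:
    data = data_gen_fn()  # generated once, so both models see identical inputs
    m1.eval()
    m2.eval()
    with torch.no_grad():
        outputs1 = m1(**data)
        outputs2 = m2(**data)
    transformed_out1 = output_transform_fn(outputs1)
    transformed_out2 = output_transform_fn(outputs2)
    assert len(transformed_out1) == len(transformed_out2)
    for key, out1 in transformed_out1.items():
        out2 = transformed_out2[key]
        assert torch.allclose(
            out1, out2, atol=1e-5
        ), f"{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}"


class TinyModel(torch.nn.Module):
    """Hypothetical model used only for this illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def data_gen_fn() -> dict:
    # Keys must match the keyword arguments of forward(), since the
    # function calls m(**data).
    return {"x": torch.ones(3, 4)}


def output_transform_fn(output: Any) -> dict:
    # Wrap the raw tensor in a dict so outputs can be compared key by key.
    return {"output": output}


torch.manual_seed(0)
model = TinyModel()
twin = copy.deepcopy(model)  # identical weights, so the check should pass
assert_forward_equal(model, twin, data_gen_fn, output_transform_fn)

# Shift the copy's bias by 1.0; the outputs now differ far beyond atol=1e-5.
with torch.no_grad():
    twin.linear.bias.add_(1.0)

try:
    assert_forward_equal(model, twin, data_gen_fn, output_transform_fn)
    mismatch_detected = False
except AssertionError:
    mismatch_detected = True
```

Two details of the original function are worth keeping in mind when writing `output_transform_fn`: the length assertion compares only the number of keys, so a key present in the first dict but missing from the second surfaces as a `KeyError` rather than an assertion failure, and `torch.allclose` also applies its default relative tolerance in addition to the explicit `atol=1e-5`.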

Callers 1

check_lazy_init · Function · 0.85

Calls 4

output_transform_fn · Function · 0.85
no_grad · Method · 0.80
data_gen_fn · Function · 0.50
eval · Method · 0.45

Tested by

no test coverage detected
