MCPcopy
hub / github.com/Vchitect/Latte / clip_grad_norm_

Function clip_grad_norm_

utils.py:72–125  ·  view source on GitHub ↗

r""" Copy from torch.nn.utils.clip_grad_norm_. Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place. Args: parameters (Iterable[Tensor] or Tensor): …

(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False, clip_grad = True)

Source from the content-addressed store, hash-verified

70 return total_norm
71
def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False, clip_grad: bool = True) -> torch.Tensor:
    r"""
    Copy from torch.nn.utils.clip_grad_norm_, extended with ``clip_grad``.

    Computes the total gradient norm of an iterable of parameters and,
    when ``clip_grad`` is True, clips the gradients to ``max_norm``.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place only
    when ``clip_grad`` is True; otherwise they are left untouched and the
    function just reports the norm.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. This applies whether or not ``clip_grad``
            is set. Default: False
        clip_grad (bool): if True, scale the gradients in-place so that their
            total norm does not exceed ``max_norm``. If False, only compute
            and return the total norm. Default: True

    Returns:
        Total norm of the parameter gradients (viewed as a single vector),
        measured before any clipping. A CPU scalar ``0.`` tensor is returned
        when no parameter has a gradient.

    Raises:
        RuntimeError: if ``error_if_nonfinite`` is True and the total norm is
            ``nan`` or infinite.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)

    # Per-tensor norms are moved to one device so they can be stacked even
    # when the gradients live on different devices.
    device = grads[0].device
    if norm_type == inf:
        norms = [g.detach().abs().max().to(device) for g in grads]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        per_grad_norms = [torch.norm(g.detach(), norm_type).to(device) for g in grads]
        total_norm = torch.norm(torch.stack(per_grad_norms), norm_type)

    # Checked before the clipping branch so that an explicit request for the
    # error is honoured even when the caller only wants the norm reported.
    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')

    if clip_grad:
        # The small epsilon guards against division by a zero norm.
        clip_coef = max_norm / (total_norm + 1e-6)
        # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
        # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
        # when the gradients do not reside in CPU memory.
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
        for g in grads:
            g.detach().mul_(clip_coef_clamped.to(g.device))
    return total_norm
126
127def get_experiment_dir(root_dir, args):
128 # if args.pretrained is not None and 'Latte-XL-2-256x256.pt' not in args.pretrained:

Callers 4

mainFunction · 0.90
mainFunction · 0.90
training_stepMethod · 0.90
training_stepMethod · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected