Repository: github.com/ZhengPeng7/BiRefNet

Method get_patches_batch

models/birefnet.py, lines 200–210
Signature: get_patches_batch(self, x, p)


def get_patches_batch(self, x, p):
    _size_h, _size_w = p.shape[2:]
    patches_batch = []
    for idx in range(x.shape[0]):
        columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
        patches_x = []
        for column_x in columns_x:
            patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
        patch_sample = torch.cat(patches_x, dim=1)
        patches_batch.append(patch_sample)
    return torch.cat(patches_batch, dim=0)
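For an input x of shape (B, C, H, W), get_patches_batch cuts each sample into tiles whose height and width are p.shape[2] and p.shape[3] (only p's spatial size is used). It walks the tiles column by column, left to right, and top to bottom within each column. It then concatenates them along the channel axis, so the result has shape (B, n_tiles * C, h, w). The sketch below shows the same layout as a single reshape/permute. `get_patches_batch_vectorized` is a hypothetical name, not part of BiRefNet, and the sketch assumes H and W are exact multiples of the tile size. When they are not, torch.split leaves a smaller last tile and the torch.cat in the original fails.

```python
import torch


def get_patches_batch_vectorized(x: torch.Tensor, patch_h: int, patch_w: int) -> torch.Tensor:
    """Hypothetical vectorized equivalent of get_patches_batch (a sketch).

    Assumes x has shape (B, C, H, W) with H % patch_h == 0 and W % patch_w == 0.
    Returns shape (B, n_w * n_h * C, patch_h, patch_w); channel block
    k = j * n_h + i holds the tile in grid column j, row i.
    """
    b, c, h, w = x.shape
    n_h, n_w = h // patch_h, w // patch_w
    # Split each spatial axis into (grid index, offset inside the tile):
    # (B, C, n_h, patch_h, n_w, patch_w)
    x = x.reshape(b, c, n_h, patch_h, n_w, patch_w)
    # Put the column index outermost, then the row index, then channels,
    # matching the loop order of the original method:
    # (B, n_w, n_h, C, patch_h, patch_w)
    x = x.permute(0, 4, 2, 1, 3, 5)
    # Flatten (n_w, n_h, C) into the channel axis (reshape copies the
    # non-contiguous permuted view as needed).
    return x.reshape(b, n_w * n_h * c, patch_h, patch_w)


x = torch.arange(2 * 3 * 4 * 6, dtype=torch.float32).reshape(2, 3, 4, 6)
out = get_patches_batch_vectorized(x, patch_h=2, patch_w=3)
print(out.shape)  # torch.Size([2, 12, 2, 3])
```

With a 4x6 input and 2x3 tiles the grid is 2x2, so each sample yields 4 tiles of 3 channels each, hence 12 output channels. The loop form in the original works on any shapes torch.split accepts. The vectorized form trades that generality for one memory copy instead of many small torch.cat calls.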

Callers (1): forward (method, score 0.95)

Calls: no outgoing calls reported by the index (the body uses only torch.split and torch.cat)

Tested by: no test coverage detected
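Since no tests are listed for this method, here is a minimal test sketch. It assumes a standalone copy of the method with `self` dropped. `get_patches_batch_standalone` is a hypothetical name, and the inner comprehension variable is renamed from `p` to `patch` for readability. In Python 3 the comprehension has its own scope, so the original's reuse of `p` does not change behavior. The tests check the output shape and the column-major tile order, and show that a height not divisible by the tile height raises an error.

```python
import torch


def get_patches_batch_standalone(x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # Copy of get_patches_batch with `self` dropped (hypothetical helper name);
    # the comprehension variable is renamed from `p` to `patch`.
    _size_h, _size_w = p.shape[2:]
    patches_batch = []
    for idx in range(x.shape[0]):
        columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
        patches_x = []
        for column_x in columns_x:
            patches_x += [patch.unsqueeze(0) for patch in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
        patch_sample = torch.cat(patches_x, dim=1)
        patches_batch.append(patch_sample)
    return torch.cat(patches_batch, dim=0)


def test_shape_and_order():
    x = torch.randn(2, 3, 8, 8)
    p = torch.empty(1, 1, 4, 4)  # only p.shape[2:] is used
    out = get_patches_batch_standalone(x, p)
    # 2x2 grid of 4x4 tiles, 3 channels per tile.
    assert out.shape == (2, 4 * 3, 4, 4)
    # Tile k = 1 is column 0, row 1: the bottom-left quadrant.
    assert torch.equal(out[:, 3:6], x[:, :, 4:8, 0:4])
    # Tile k = 2 is column 1, row 0: the top-right quadrant.
    assert torch.equal(out[:, 6:9], x[:, :, 0:4, 4:8])


def test_non_divisible_size_fails():
    # H = 5 with tile height 2 leaves a 1-row last tile, which torch.cat rejects.
    x = torch.randn(1, 1, 5, 4)
    p = torch.empty(1, 1, 2, 2)
    try:
        get_patches_batch_standalone(x, p)
    except RuntimeError:
        pass
    else:
        raise AssertionError("expected torch.cat to reject a ragged last tile")


test_shape_and_order()
test_non_divisible_size_fails()
```

Under pytest, the two trailing calls can be dropped; pytest collects the `test_` functions on its own.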