Internal function to batch feature dictionaries. Parameters ---------- frames : list[Frame] List of frames keys : list[str] Feature keys. Can be '__ALL__', meaning batching all features. feat_dict_name : str Name of the feature dictionary for reporting errors.
(frames, keys, feat_dict_name)
| 223 | |
| 224 | |
| 225 | def _batch_feat_dicts(frames, keys, feat_dict_name): |
| 226 | """Internal function to batch feature dictionaries. |
| 227 | |
| 228 | Parameters |
| 229 | ---------- |
| 230 | frames : list[Frame] |
| 231 | List of frames |
| 232 | keys : list[str] |
| 233 | Feature keys. Can be '__ALL__', meaning batching all features. |
| 234 | feat_dict_name : str |
| 235 | Name of the feature dictionary for reporting errors. |
| 236 | |
| 237 | Returns |
| 238 | ------- |
| 239 | dict[str, Tensor] |
| 240 | New feature dict. |
| 241 | """ |
| 242 | if len(frames) == 0: |
| 243 | return {} |
| 244 | schemas = [frame.schemes for frame in frames] |
| 245 | # sanity checks |
| 246 | if is_all(keys): |
| 247 | utils.check_all_same_schema(schemas, feat_dict_name) |
| 248 | keys = schemas[0].keys() |
| 249 | else: |
| 250 | utils.check_all_same_schema_for_keys(schemas, keys, feat_dict_name) |
| 251 | # concat features |
| 252 | ret_feat = {k: F.cat([fd[k] for fd in frames], 0) for k in keys} |
| 253 | return ret_feat |
| 254 | |
| 255 | |
| 256 | def unbatch(g, node_split=None, edge_split=None): |