(net, schema_or_record, enforce_types=False)
| 1224 | |
| 1225 | |
def InitEmptyRecord(net, schema_or_record, enforce_types=False):
    """Add ops to `net` that fill every field blob of a record with an empty tensor.

    If `schema_or_record` has no blobs yet, a record is first built with
    `NewRecord(net, schema_or_record)`; otherwise it is used as-is. For each
    (field type, field blob) pair, a `ConstantFill` op is added to `net` whose
    shape is `[0]` followed by the field dtype's own shape, and whose dtype is
    resolved with `data_type_for_dtype`.

    Args:
        net: the net that receives the `ConstantFill` ops.
        schema_or_record: a schema without blobs, or a record that already
            has blobs.
        enforce_types: if True, re-raise the `TypeError` produced when a
            field's dtype cannot be resolved. If False, log a warning and add
            a fallback `ConstantFill` (default dtype, shape `[0]`) instead.

    Returns:
        The record whose blobs were filled: either the newly created one or
        the one that was passed in.

    Raises:
        TypeError: if `enforce_types` is True and a field's dtype cannot be
            resolved by `data_type_for_dtype` (or its shape cannot be read).
    """
    if not schema_or_record.has_blobs():
        record = NewRecord(net, schema_or_record)
    else:
        record = schema_or_record

    for blob_type, blob in zip(record.field_types(), record.field_blobs()):
        # Keep only dtype/shape resolution inside the try. A TypeError raised
        # by net.ConstantFill itself is a different problem. It must not be
        # reported as a dtype error, and it must not trigger the fallback fill
        # for the same blob.
        try:
            data_type = data_type_for_dtype(blob_type)
            shape = [0] + list(blob_type.shape)
        except TypeError:
            logger.warning("Blob {} has type error".format(blob))
            # data_type_for_dtype can raise TypeError when it cannot map a
            # numpy type to core.DataType (e.g. unknown types such as
            # np.void). That is harmless when the record is going to be
            # overwritten by some operator later, though it may hurt
            # type/shape inference.
            if enforce_types:
                raise
            # Without type enforcement, fall back to ConstantFill's default
            # dtype (FLOAT) and a plain 1-D empty shape [0], i.e. without the
            # dtype's trailing dimensions.
            net.ConstantFill([], blob, shape=[0])
        else:
            net.ConstantFill([], blob, shape=shape, dtype=data_type)

    return record
| 1252 | |
| 1253 | |
| 1254 | _DATA_TYPE_FOR_DTYPE = [ |
searching dependent graphs…