(data_format, keras_mode=True)
| 97 | |
| 98 | |
def get_data_format(data_format, keras_mode=True):
    """Translate a data-format name into the requested naming convention.

    Both spellings are accepted as input: the tensor-layout form
    ('NCHW' / 'NHWC') and the Keras form ('channels_first' /
    'channels_last'). A value already in the target convention is
    returned unchanged.

    Args:
        data_format: the format name to translate.
        keras_mode: when truthy, return the Keras form
            ('channels_first' / 'channels_last'); otherwise return the
            layout form ('NCHW' / 'NHWC').

    Returns:
        The format name expressed in the target convention.

    Raises:
        ValueError: if ``data_format`` is not one of the four known names.
    """
    layout_to_keras = {'NCHW': 'channels_first', 'NHWC': 'channels_last'}
    if keras_mode:
        mapping = layout_to_keras
    else:
        # Invert the single source-of-truth table for the opposite direction.
        mapping = {keras: layout for layout, keras in layout_to_keras.items()}

    # Names already in the target convention pass through the lookup as-is.
    converted = mapping.get(data_format, data_format)
    if converted in mapping.values():
        return converted
    raise ValueError("Unknown data_format: {}".format(data_format))
| 108 | |
| 109 | |
| 110 | def shape4d(a, data_format='NHWC'): |
no test coverage detected