Calculate a 2D loss, typically the durations/F0s/energies loss.
(y_gt, y_pred, loss_fn)
| 52 | |
| 53 | |
def calculate_2d_loss(y_gt, y_pred, loss_fn):
    """Calculate a 2D loss, typically the durations/F0s/energies loss.

    ``y_gt`` and ``y_pred`` may differ in length along axis 1 (e.g. when
    training on multiple GPUs); both are truncated to the shorter length
    before ``loss_fn`` is applied.

    Args:
        y_gt: Ground-truth tensor. Axis 0 is the batch axis and axis 1 is the
            length axis that gets truncated; any further axes are kept.
        y_pred: Predicted tensor with the same layout as ``y_gt``.
        loss_fn: Callable ``loss_fn(y_gt, y_pred)`` returning either a single
            loss tensor or a tuple of loss tensors, each with a leading
            batch axis.

    Returns:
        If ``loss_fn`` returns a single tensor: that loss averaged over every
        non-batch axis, shape ``[B]``. If it returns a tuple: a list holding
        one such ``[B]`` tensor per tuple element.
    """
    # There can be a length mismatch when training on multiple GPUs, so
    # truncate both tensors to the shorter length. Slicing unconditionally
    # to the minimum avoids Python branching on tensor values (which relies
    # on AutoGraph in graph mode) and works for inputs of any rank >= 2.
    min_len = tf.minimum(tf.shape(y_gt)[1], tf.shape(y_pred)[1])
    y_gt = y_gt[:, :min_len]
    y_pred = y_pred[:, :min_len]

    def _batch_mean(x):
        # Average over every axis except the leading batch axis -> shape [B].
        return tf.reduce_mean(x, list(range(1, len(x.shape))))

    loss = loss_fn(y_gt, y_pred)
    if isinstance(loss, tuple):
        return [_batch_mean(part) for part in loss]
    return _batch_mean(loss)
no outgoing calls
no test coverage detected