github.com/TensorSpeech/TensorFlowTTS / calculate_2d_loss

Function calculate_2d_loss

tensorflow_tts/utils/strategy.py:54–77

Calculate a loss between two 2D (batch × time) tensors, typically the duration, F0, or energy loss.

(y_gt, y_pred, loss_fn)
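The signature does not say what `loss_fn` must return. In the body below, its result may be either a single per-element loss or a tuple of losses, and each one is averaged over every axis except the batch axis. The framework-free sketch below illustrates only that reduction step on nested Python lists; `mean_over_non_batch_axes` and `reduce_per_sample` are hypothetical helpers, not part of TensorFlowTTS.

```python
from statistics import fmean


def _flatten(x):
    """Yield every scalar in an arbitrarily nested list."""
    if isinstance(x, list):
        for item in x:
            yield from _flatten(item)
    else:
        yield x


def mean_over_non_batch_axes(loss):
    """Mimic tf.reduce_mean(loss, list(range(1, rank))) on nested lists.

    A rank-0 value has no axes to reduce and is returned unchanged; for
    rank >= 1, each batch item is averaged over all of its remaining axes.
    """
    if not isinstance(loss, list):
        return loss
    return [fmean(list(_flatten(item))) for item in loss]


def reduce_per_sample(loss):
    """Apply the reduction to a single loss or to each loss in a tuple."""
    if isinstance(loss, tuple):
        # Like the original, a tuple of losses comes back as a list.
        return [mean_over_non_batch_axes(part) for part in loss]
    return mean_over_non_batch_axes(loss)
```

A loss that is already one value per batch item (rank 1) passes through unchanged, which matches `tf.reduce_mean` being called with an empty axis list.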

```python
import tensorflow as tf


def calculate_2d_loss(y_gt, y_pred, loss_fn):
    """Calculate a 2D loss, typically the duration, F0, or energy loss."""
    y_gt_T = tf.shape(y_gt)[1]
    y_pred_T = tf.shape(y_pred)[1]

    # Ground-truth and predicted lengths can mismatch when training on
    # multiple GPUs, so slice the longer tensor along the time axis to
    # make sure the loss is calculated correctly.
    if y_gt_T > y_pred_T:
        y_gt = tf.slice(y_gt, [0, 0], [-1, y_pred_T])
    elif y_pred_T > y_gt_T:
        y_pred = tf.slice(y_pred, [0, 0], [-1, y_gt_T])

    loss = loss_fn(y_gt, y_pred)
    if not isinstance(loss, tuple):
        # Average over every axis except the batch axis.
        loss = tf.reduce_mean(loss, list(range(1, len(loss.shape))))  # shape = [B]
    else:
        # loss_fn returned several losses; reduce each of them the same way.
        loss = list(loss)
        for i in range(len(loss)):
            loss[i] = tf.reduce_mean(
                loss[i], list(range(1, len(loss[i].shape)))
            )  # shape = [B]

    return loss
```
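To make the length-alignment step concrete without a TensorFlow dependency, here is a framework-free sketch for 2D inputs given as nested Python lists whose rows all have the same length. `align_time_axis` and `calculate_2d_loss_sketch` are hypothetical names; `elementwise_loss` scores one pair of scalars, whereas the real `loss_fn` receives whole tensors, and the tuple branch is omitted.

```python
from statistics import fmean


def align_time_axis(y_gt, y_pred):
    """Trim both [B][T] batches to the shorter time length, like the tf.slice calls."""
    # Assumes non-empty, rectangular batches (every row has the same length).
    t = min(len(y_gt[0]), len(y_pred[0]))
    return [row[:t] for row in y_gt], [row[:t] for row in y_pred]


def calculate_2d_loss_sketch(y_gt, y_pred, elementwise_loss):
    """Framework-free sketch of calculate_2d_loss for 2D inputs and a single loss.

    Returns one mean loss per batch item, matching the shape = [B]
    comment in the original function.
    """
    y_gt, y_pred = align_time_axis(y_gt, y_pred)
    return [
        fmean([elementwise_loss(g, p) for g, p in zip(row_gt, row_pred)])
        for row_gt, row_pred in zip(y_gt, y_pred)
    ]
```

With squared error, a 3-frame ground truth scored against 2-frame predictions drops the third ground-truth frame before averaging, so each row's loss is the mean of two errors.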

Calls

no outgoing calls

Tested by

no test coverage detected