(data, index, shape, force_format=False)
| 78 | |
| 79 | |
def sparse_matrix(data, index, shape, force_format=False):
    """Wrap COO coordinates and values into a ``tf.SparseTensor``.

    Args:
        data: Values of the stored entries; passed through unchanged as the
            ``values`` of the sparse tensor.
        index: Indexable whose item 0 is a format tag and whose item 1 holds
            the coordinates. Only the ``"coo"`` tag is accepted. The
            coordinates are transposed with permutation ``(1, 0)`` before
            use, so they must be rank 2; presumably laid out as
            ``(2, nnz)`` row/column ids -- TODO confirm against callers.
        shape: Dense shape given to the sparse tensor.
        force_format: Part of the shared backend signature; not read by
            this implementation.

    Returns:
        A 2-tuple ``(spmat, None)``. The second element is always ``None``
        here (NOTE(review): presumably a slot other backends fill with
        auxiliary data -- confirm).

    Raises:
        TypeError: If ``index[0]`` is not ``"coo"``.
    """
    layout = index[0]
    if layout != "coo":
        raise TypeError(
            "Tensorflow backend only supports COO format. But got %s." % layout
        )

    # Flip the coordinate array from (2, nnz)-style to (nnz, 2)-style rows.
    coords = tf.transpose(index[1], (1, 0))
    # tf.SparseTensor accepts only int64 indices, so int32 input is cast.
    coords = tf.cast(coords, tf.int64)

    sparse_tensor = tf.SparseTensor(
        indices=coords,
        values=data,
        dense_shape=shape,
    )
    return sparse_tensor, None
| 94 | |
| 95 | |
| 96 | def sparse_matrix_indices(spmat): |