TensorFlow之sparse tensor
媒介sparse tensor 希罕tensor, tensor中大部门元素是0, 少部门元素黑白0.
创建sparse tensor
import tensorflow as tf
# indices指示正常值的索引, 即哪些索引位置上是正常值.
# values表示这些正常值是多少.
# indices和values是一一对应的. 表示第0行第1列的值是1, 表示第一行第0列的值是2, 表示第2行第3列的值是3, 以此类推.
# dense_shape表示这个SparseTensor的shape是多少
s = tf.SparseTensor(indices = [, , ],
values = ,
dense_shape = )
print(s)
# 把sparse tensor转化为稠密矩阵
print(tf.sparse.to_dense(s))
结果如下:
SparseTensor(indices=tf.Tensor(
[
], shape=(3, 2), dtype=int64), values=tf.Tensor(, shape=(3,), dtype=float32), dense_shape=tf.Tensor(, shape=(2,), dtype=int64))
tf.Tensor(
[
], shape=(3, 4), dtype=float32)
sparse tensor的运算:
import tensorflow as tf
# indices指示正常值的索引, 即哪些索引位置上是正常值.
# values表示这些正常值是多少.
# indices和values是一一对应的. 表示第0行第1列的值是1, 表示第一行第0列的值是2, 表示第2行第3列的值是3, 以此类推.
# dense_shape表示这个SparseTensor的shape是多少
s = tf.SparseTensor(indices = [, , ],
values = ,
dense_shape = )
print(s)
# 乘法
s2 = s * 2.0
print(s2)
try:
# 加法不支持.
s3 = s + 1
except TypeError as ex:
print(ex)
s4 = tf.constant([,
,
,
])
# 得到一个3 * 2 的矩阵
print(tf.sparse.sparse_dense_matmul(s, s4))
结果如下:
SparseTensor(indices=tf.Tensor(
[
], shape=(3, 2), dtype=int64), values=tf.Tensor(, shape=(3,), dtype=float32), dense_shape=tf.Tensor(, shape=(2,), dtype=int64))
SparseTensor(indices=tf.Tensor(
[
], shape=(3, 2), dtype=int64), values=tf.Tensor(, shape=(3,), dtype=float32), dense_shape=tf.Tensor(, shape=(2,), dtype=int64))
unsupported operand type(s) for +: 'SparseTensor' and 'int'
tf.Tensor(
[[ 30.40.]
[ 20.40.]
], shape=(3, 2), dtype=float32)
注意在定义sparse tensor的时候 indices必须是排好序的. 假如不是,定义的时候不会报错, 但是在to_dense的时候会报错:
import tensorflow as tf
s5 = tf.SparseTensor(indices = [, , ],
values = ,
dense_shape = )
print(s5)
#print(tf.sparse.to_dense(s5)) #报错:indices = is out of order. Many sparse ops require sorted indices.
# 可以通过reorder对排序, 这样to_dense就没问题了.
s6 = tf.sparse.reorder(s5)
print(tf.sparse.to_dense(s6))
结果如下:
SparseTensor(indices=tf.Tensor(
[
], shape=(3, 2), dtype=int64), values=tf.Tensor(, shape=(3,), dtype=float32), dense_shape=tf.Tensor(, shape=(2,), dtype=int64))
tf.Tensor(
[
], shape=(3, 4), dtype=float32)
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页:
[1]