"""Show two ways to "replace" one element of an immutable TensorFlow tensor.

``tf.Tensor`` values cannot be assigned in place, so each approach below
builds and prints a new tensor:

* by position, with ``tf.tensor_scatter_nd_update``;
* by value, with ``tf.where`` on an equality mask.
"""
import tensorflow as tf

x = tf.constant([[4.0, 43.0, 45.0],
                 [2.0, 22.0, 6664.0],
                 [-4543.0, 0.0, 43.0]])
value = 45.0      # replacement value written into the new tensor
indices = [1, 1]  # (row, col) of the element to replace
target = 22.0     # the value currently stored at x[1, 1]

# tensor_scatter_nd_update takes a batch of index tuples (shape [N, rank])
# and one update per tuple (shape [N]); wrapping in a list makes N == 1.
by_indices = tf.tensor_scatter_nd_update(x, [indices], [value])
tf.print('Using indices\n', by_indices, '\n')

# NOTE: this replaces EVERY element equal to `target`, not only the one at
# `indices`; the scalar `value` is broadcast against x by tf.where.
by_value = tf.where(tf.equal(x, target), value, x)
tf.print('Using value\n', by_value)
# Expected output:
"""
Using indices
[[4 43 45]
[2 45 6664]
[-4543 0 43]]
Using value
[[4 43 45]
[2 45 6664]
[-4543 0 43]]
"""