4 - 索引与切片
1个月前 • 151次点击 • 来自 TensorFlow
收录专题: TensorFlow入门笔记
1. basic indexing
a = tf.ones([1,5,5,3])
print(a[0][0])
print(a[0][0][0])
print(a[0][0][0][2])
2. numpy-style indexing
a = tf.random.normal([4,28,28,3])
print(a[1].shape)
print(a[1,2].shape)
print(a[1,2,3].shape)
print(a[1,2,3,2].shape)
3. start:end
语法同python数组切片 [A:B)
a=tf.range(10)
print(a)
print(a[-1:])
print(a[-2:])
print(a[:2])
print(a[:-1])
4. indexing by :
a = tf.random.normal([4,28,28,3])
print(a[0,:,:,:].shape)
print(a[0,1,:,:].shape)
print(a[:,:,:,0].shape)
print(a[:,:,:,2].shape)
print(a[:,0,:,:].shape)
5. indexing by :: 隔行采样
a = tf.random.normal([4,28,28,3])
print(a[:,0:28:2,0:28:2,:].shape)
print(a[:,:14,:14,:].shape)
print(a[:,14:,14:,:].shape)
print(a[:,::2,::2,:].shape)
6. indexing by ...
a = tf.random.normal([2,4,28,28,3])
print(a[0,:,:,:,:].shape)
print(a[0,...].shape)
print(a[:,:,:,:,0].shape)
print(a[...,0].shape)
print(a[0,...,2].shape)
print(a[1,0,...,0].shape)
7. selective indexing
使用tf.gather、tf.gather_nd、tf.boolean_mask进行随机采样
(1)tf.gather(在某一维度指定index)
# 下面的tensor即表示,4个班级,每个班级35名学生,每个学生8门课的成绩
a = tf.random.normal([4,35,8])
# axis表示维度,indices表示在axis维度上要取数据的索引
print(tf.gather(a,axis=0,indices=[2,3]).shape) # 可理解为取第2、3个班级的学生成绩,同a[2:4].shape
print(tf.gather(a,axis=0,indices=[2,1,3,0]).shape) # 可理解为依次取第2、1、3、0个班级的学生成绩
print(tf.gather(a,axis=1,indices=[2,3,7,9,16]).shape) # 可理解为取所有班级第2,3,7,9,16个学生的成绩
print(tf.gather(a,axis=2,indices=[2,3,7]).shape) # 可理解为取所有班级所有学生第2,3,7门课的成绩
(2)tf.gather_nd(在多个维度指定index)
a = tf.random.normal([4,35,8])
# axis表示维度,indices表示在axis维度上要取数据的索引
print(tf.gather_nd(a,[0]).shape) # 可理解为取0号班级的所有成绩
print(tf.gather_nd(a,[0,1]).shape) # 可理解为取0号班级1号学生的成绩
print(tf.gather_nd(a,[0,1,2]).shape) # 可理解为取0号班级1号学生的第2门课成绩
print(tf.gather_nd(a,[[0,0],[1,1]]).shape) # 可理解为取0号班级0号学生和1号班级1号学生的成绩
print(tf.gather_nd(a,[[0,0],[1,1],[2,2]]).shape) # 可理解为取0号班级0号学生、1号班级1号学生、2号班级2号学生的成绩
print(tf.gather_nd(a,[[0,0,0],[1,1,1],[2,2,2]]).shape) # 可理解为0班0学0课,1班1学1课,2班2学2课的成绩
print(tf.gather_nd(a,[[[0,0,0],[1,1,1],[2,2,2]]]).shape) # shape与上不同
(3)tf.boolean_mask(通过True和False的方式选择数据)
a = tf.random.normal([4,28,28,3])
print(tf.boolean_mask(a,mask=[True,True,False,False]).shape)
print(tf.boolean_mask(a,mask=[True,True,False],axis=3).shape)
a = tf.ones([2,3,4])
print(tf.boolean_mask(a,mask=[[True,False,False],[False,True,True]]))
标签