Skip to content

How to add a mask branch? #18

@ics091

Description

@ics091

Hello, I want to add a mask branch on top of your object detection code. Is the following code correct? The problem is that the mask loss is always 0 during training. Thank you!

# Detection head followed by a Mask R-CNN-style mask branch that runs on the
# post-processed detections of each image in the batch.
feature_pyramid = self.build_base_network(input_img_batch)  # [P3, P4, P5, P6, P7]

rpn_cls_score, rpn_cls_prob, rpn_cnt_scores, rpn_box = self.rpn_net(feature_pyramid)

# Squash the 'cnt' scores to (0, 1) and broadcast them over the last axis of
# rpn_cls_prob so they rescale every class probability at each location.
rpn_cnt_prob = tf.nn.sigmoid(rpn_cnt_scores)
rpn_cnt_prob = tf.expand_dims(rpn_cnt_prob, axis=2)
rpn_cnt_prob = tf.broadcast_to(rpn_cnt_prob,
                               [self.batch_size, tf.shape(rpn_cls_prob)[1], tf.shape(rpn_cls_prob)[2]])

rpn_prob = rpn_cls_prob * rpn_cnt_prob

# Pyramid levels P3..P7, in ascending order, as inputs for PyramidROIAlign.
ftmaps = [feature_pyramid['P%d' % level] for level in range(3, 8)]

# MASK
with tf.variable_scope('mask_target', reuse=tf.AUTO_REUSE):
    # Post-process detections per image; the resulting boxes are the RoIs
    # fed to the mask head.  rpn_box: (batch_size, ?, 4)
    final_box = []
    for i in range(self.batch_size):
        boxes, _, _ = postprocess_detctions(rpn_bbox=rpn_box[i, :, :],
                                            rpn_cls_prob=rpn_prob[i, :, :],
                                            img_shape=img_shape,
                                            is_training=self.is_training)
        final_box.append(boxes)
    # NOTE(review): tf.stack needs every image to yield the same number of
    # boxes. If postprocess_detctions returns a variable count (e.g. after
    # NMS), pad/truncate each entry to a fixed size first -- confirm.
    final_box = tf.stack(final_box, axis=0)

    # rois: (batch_size, ?, 14, 14, 256)
    croped_rois = self.PyramidROIAlign(final_box, ftmaps, img_shape)

    mask = []
    for i in range(self.batch_size):
        m = croped_rois[i]
        # Explicit scope names are what make the mask head shared across the
        # batch. Without them slim uniquifies its default scope (Conv, Conv_1,
        # ...), so each image would get its own separate weights and
        # AUTO_REUSE would never reuse anything.
        for k in range(4):
            m = slim.conv2d(m, 256, [3, 3], stride=1, padding='SAME',
                            activation_fn=tf.nn.relu, scope='mask_conv%d' % (k + 1))
        # Upsample 14 x 14 -> 28 x 28.
        m = slim.conv2d_transpose(m, 256, 2, stride=2, padding='VALID',
                                  activation_fn=tf.nn.relu, scope='mask_deconv')
        tf.add_to_collection('__TRANSPOSED__', m)
        m = slim.conv2d(m, cfgs.CLASS_NUM + 1, [1, 1], stride=1, padding='VALID',
                        activation_fn=None, scope='mask_logits')
        # NOTE(review): the branch outputs probabilities, not logits. If the
        # mask loss uses tf.nn.sigmoid_cross_entropy_with_logits, feed it the
        # pre-sigmoid tensor instead so the sigmoid is not applied twice.
        m = tf.nn.sigmoid(m)
        mask.append(m)

    mask = tf.stack(mask, axis=0)
    # mask: (batch_size, ?, 28, 28, CLASS_NUM + 1)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions