I was following this tutorial (https://tree.rocks/a-simple-way-to-understand-and-implement-object-detection-from-scratch-by-pure-cnn-36cc28143ca8) and tried to implement a gating function for my own CNN. My goal is to classify if an image has smoke, then to draw a bounding box. If there is no smoke, no bounding box should be drawn. Without the gaiting function, my code works perfectly, but it draws a box even if there is no smoke:
Working code:
class CNN_subclass(Model):
def __init__(self):
super().__init__()
self.conv1 = Conv2D(32, (3,3), padding = 'same', activation = 'relu', input_shape = (128,128,3))
self.maxp1 = MaxPooling2D(pool_size = (2,2), padding = 'same')
self.drop = Dropout(0.1)
self.flat1 = Flatten()
self.dense1 = Dense(190, activation = 'relu')
self.dense2 = Dense(190, activation = 'relu')
self.dense3 = Dense(190, activation = 'relu')
self.dense4 = Dense(190, activation = 'relu')
self.dense5 = Dense(190, activation = 'relu')
self.dense6 = Dense(190, activation = 'relu')
self.flat2 = Flatten()
self.prob = Dense(2, activation='softmax')
self.boxes = Dense(4)
self.cat = Concatenate()
def call(self, input):
x = self.conv1(input)
x = self.maxp1(x)
x = self.drop(x)
x = self.flat1(x)
x = self.dense1(x)
x = self.dense2(x)
x = self.dense3(x)
x = self.dense4(x)
x = self.dense5(x)
x = self.dense6(x)
x = self.flat2(x)
x_prob = self.prob(x)
x_boxes = self.boxes(x)
return [x_prob, x_boxes]
Here is the code with my gating function (it does not work):
class CNN_subclass(Model):
def __init__(self):
super().__init__()
self.conv1 = Conv2D(32, (3,3), padding = 'same', activation = 'relu', input_shape = (128,128,3))
self.maxp1 = MaxPooling2D(pool_size = (2,2), padding = 'same')
self.drop = Dropout(0.1)
self.flat1 = Flatten()
self.dense1 = Dense(190, activation = 'relu')
self.dense2 = Dense(190, activation = 'relu')
self.dense3 = Dense(190, activation = 'relu')
self.dense4 = Dense(190, activation = 'relu')
self.dense5 = Dense(190, activation = 'relu')
self.dense6 = Dense(190, activation = 'relu')
self.flat2 = Flatten()
self.prob = Dense(2, activation='softmax')
self.boxes = Dense(4)
self.cat = Concatenate()
def gating(self, prob, box):
gate = tf.cond(tf.math.equal(tf.math.argmax(prob),0), tf.ones(prob), tf.zeros(prob))
return box * gate
def call(self, input):
x = self.conv1(input)
x = self.maxp1(x)
x = self.drop(x)
x = self.flat1(x)
x = self.dense1(x)
x = self.dense2(x)
x = self.dense3(x)
x = self.dense4(x)
x = self.dense5(x)
x = self.dense6(x)
x = self.flat2(x)
x_prob = self.prob(x)
x_boxes = self.boxes(x)
gate = self.gating(x_prob, x_boxes)
return [x_prob, gate]
How do I fix my gating function? Thank you so much!
Ryan is a new contributor to this site. Take care in asking for clarification, commenting, and answering.
Check out our Code of Conduct.