考慮一種常用于處理圖像的標(biāo)準(zhǔn)卷積神經(jīng)網(wǎng)絡(luò) (CNN)。輸入是像素網(wǎng)格。每個像素都有一個數(shù)據(jù)值向量,例如紅色、綠色和藍(lán)色通道。數(shù)據(jù)通過一系列卷積層。每層結(jié)合來自像素及其相鄰像素的數(shù)據(jù),為該像素生成新的數(shù)據(jù)向量。前面的卷積層檢測小規(guī)模的局部模式,而后面的卷積層檢測更大、更抽象的模式。通常,卷積層與池化層交替出現(xiàn),池化層在局部區(qū)域執(zhí)行一些操作,例如最大值或最小值。 圖卷積是一種基于圖結(jié)構(gòu)的卷積操作。與傳統(tǒng)的卷積操作不同,圖卷積的輸入是一個圖形,包含節(jié)點和邊,而不是一個二維或三維的張量。圖卷積的目的是通過學(xué)習(xí)節(jié)點之間的關(guān)系來進(jìn)行特征提取和分類等任務(wù)。 圖卷積的原理可以概括為以下幾個步驟: 聚合:對于每個節(jié)點,將其鄰近節(jié)點的特征進(jìn)行聚合,可以使用均值、最大值、加權(quán)和等方式來計算鄰近節(jié)點的特征。 更新:根據(jù)聚合后的鄰居節(jié)點特征以及當(dāng)前節(jié)點自身的特征,更新當(dāng)前節(jié)點的特征表示。 激活:對更新后的節(jié)點特征進(jìn)行激活函數(shù)操作,例如ReLU函數(shù)等。
圖1 圖卷積神經(jīng)網(wǎng)絡(luò)原理
DeepChem有一個名為GraphConvModel的類,它在其內(nèi)部包裝了一個標(biāo)準(zhǔn)的圖卷積網(wǎng)絡(luò)結(jié)構(gòu),以方便用戶使用。 import deepchem as dc
tasks, datasets, transformers = dc.molnet.load_tox21(featurizer='GraphConv') train_dataset, valid_dataset, test_dataset = datasets n_tasks = len(tasks) model = dc.models.GraphConvModel(n_tasks, mode='classification') model.fit(train_dataset, nb_epoch=50) metric = dc.metrics.Metric(dc.metrics.roc_auc_score) print('Training set score:', model.evaluate(train_dataset, [metric], transformers)) print('Test set score:', model.evaluate(test_dataset, [metric], transformers))
DeepChem為圖卷積中涉及的所有計算提供了Keras層。我們將使用以下DeepChem層: GraphConv層:該層實現(xiàn)了圖卷積。圖卷積以非線性方式將每個節(jié)點的特征向量與相鄰節(jié)點的特征向量組合在一起,從而“混合”了圖的局部鄰域信息。 GraphPool層:該層對鄰域中原子的特征向量進(jìn)行最大池化??梢詫⒋藢右暈轭愃朴?D卷積的最大池化層,但是在圖上操作。 GraphGather層:許多圖卷積網(wǎng)絡(luò)會操作每個節(jié)點的特征向量,例如對于分子,每個節(jié)點可能表示一個原子,網(wǎng)絡(luò)會操作匯總該原子局部化學(xué)信息的原子特征向量。然而,在應(yīng)用結(jié)束時,我們可能希望使用分子級別的特征表示。該層通過組合所有節(jié)點級別的特征向量創(chuàng)建一個圖級別的特征向量。 batch_size = 100
class MyGraphConvModel(tf.keras.Model):
def __init__(self): super(MyGraphConvModel, self).__init__() self.gc1 = GraphConv(128, activation_fn=tf.nn.tanh) self.batch_norm1 = layers.BatchNormalization() self.gp1 = GraphPool()
self.gc2 = GraphConv(128, activation_fn=tf.nn.tanh) self.batch_norm2 = layers.BatchNormalization() self.gp2 = GraphPool()
self.dense1 = layers.Dense(256, activation=tf.nn.tanh) self.batch_norm3 = layers.BatchNormalization() self.readout = GraphGather(batch_size=batch_size, activation_fn=tf.nn.tanh)
self.dense2 = layers.Dense(n_tasks*2) self.logits = layers.Reshape((n_tasks, 2)) self.softmax = layers.Softmax()
def call(self, inputs): gc1_output = self.gc1(inputs) batch_norm1_output = self.batch_norm1(gc1_output) gp1_output = self.gp1([batch_norm1_output] + inputs[1:])
gc2_output = self.gc2([gp1_output] + inputs[1:]) batch_norm2_output = self.batch_norm1(gc2_output) gp2_output = self.gp2([batch_norm2_output] + inputs[1:])
dense1_output = self.dense1(gp2_output) batch_norm3_output = self.batch_norm3(dense1_output) readout_output = self.readout([batch_norm3_output] + inputs[1:])
logits_output = self.logits(self.dense2(readout_output)) return self.softmax(logits_output)
model = dc.models.KerasModel(MyGraphConvModel(), loss=dc.models.losses.CategoricalCrossEntropy())
我們現(xiàn)在可以更清晰地看到:有兩個卷積層,每個卷積層由一個GraphConv組成,后跟歸一化,然后是一個GraphPool進(jìn)行最大池化。最后我們使用一個圖聚合層,另一個歸一化,一個GraphGather來組合所有不同節(jié)點的數(shù)據(jù),最后一個全連接層產(chǎn)生全局輸出。
|