# Source code for ClearMap.ImageProcessing.MachineLearning.VesselFilling.VesselFillingNetwork

#  -*- coding: utf-8 -*-
"""
VesselFillingNetwork
====================

This module implements the vessel filling neural network in PyTorch.
"""
__author__    = 'Sophie Skriabin, Christoph Kirst <christoph.kirst.ck@gmail.com>'
__license__   = 'GPLv3 - GNU General Pulic License v3 (see LICENSE.txt)'
__copyright__ = 'Copyright © 2020 by Christoph Kirst'
__webpage__   = 'http://idisco.info'
__download__  = 'http://www.github.com/ChristophKirst/ClearMap2'


import torch
import torch.nn as nn


###############################################################################
### Convolutional neural network architecture
###############################################################################

class DeSepConv3d(nn.Module):
    """Depthwise separable 3D convolutional layer.

    A per-channel (depthwise) 3D convolution followed by a 1x1x1 (pointwise)
    convolution that mixes the channels.

    Arguments
    ---------
    nin : int
        Number of input channels.
    nout : int
        Number of output channels.
    kernel_size : int
        Edge length of the cubic depthwise kernel (default 3, the original
        hard-coded value). The depthwise convolution is zero-padded by
        ``kernel_size // 2`` on each side, so odd kernel sizes preserve the
        spatial size of the input.
    """

    def __init__(self, nin, nout, kernel_size=3):
        super(DeSepConv3d, self).__init__()
        # groups=nin: every input channel is convolved with its own filter.
        self.depthwise = nn.Conv3d(nin, nin, kernel_size=kernel_size,
                                   padding=kernel_size // 2, groups=nin)
        # 1x1x1 convolution combines the nin channels into nout outputs.
        self.pointwise = nn.Conv3d(nin, nout, kernel_size=1)

    def forward(self, x):
        """Apply the depthwise and then the pointwise convolution to ``x``."""
        return self.pointwise(self.depthwise(x))
class VesselFillingNetwork(nn.Module):
    """Vessel filling neural network.

    A small 3D convolutional encoder/decoder:

    1. The input is reduced by max-pooling and strided convolutions.
    2. The result is processed at low resolution.
    3. It is upsampled again and concatenated with the unmodified input
       (skip connection).
    4. A final convolution reduces it to two channels, of which the softmax
       probability of the first channel is returned.

    Arguments
    ---------
    load : str or None
        If given, passed to :func:`torch.load`; the loaded object is used as
        this network's state dict.
    """

    def __init__(self, load=None):
        super().__init__()

        # Parameterised layers. NOTE: keep this creation order -- it fixes the
        # random initialisation sequence and the order of the state-dict keys.
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=7,
                               stride=2, padding=(3, 3, 3), dilation=1)
        self.conv2 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=5,
                               stride=1, padding=(2, 2, 2), dilation=1)
        self.desepconv4 = DeSepConv3d(32, 32)
        self.convbin1 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3,
                                  stride=1, padding=(1, 1, 1), dilation=1)
        self.conv3 = nn.ConvTranspose3d(in_channels=32, out_channels=16,
                                        kernel_size=7, stride=2,
                                        padding=(3, 3, 3), dilation=1)
        # 17 input channels = 1 skip (input) channel + 16 decoder channels.
        self.conv4 = nn.Conv3d(in_channels=17, out_channels=2, kernel_size=3,
                               stride=1, padding=(1, 1, 1), dilation=1)

        # Parameter-free operations.
        self.maxpool = nn.MaxPool3d(kernel_size=2)
        self.relu = nn.ReLU(True)
        self.dropout = nn.Dropout(p=0.15)
        self.softmax = nn.Softmax(dim=1)
        self.out_act = nn.Sigmoid()
        self.upsample = nn.Upsample(mode="trilinear", scale_factor=2,
                                    align_corners=False)

        if load is None:
            return

        # Without a GPU, remap all stored tensors to the CPU; otherwise keep
        # the devices recorded in the file.
        location = torch.device('cpu') if not torch.cuda.is_available() else None
        # NOTE(review): depending on the torch version, torch.load may unpickle
        # arbitrary objects -- only load trusted checkpoint files.
        state = torch.load(load, map_location=location)
        self.load_state_dict(state)

    def encode(self, inp):
        """Run the full network on ``inp``.

        Arguments
        ---------
        inp : torch.Tensor
            Batched volume of shape (batch, 1, d1, d2, d3); the single channel
            is required by ``conv1``.

        Returns
        -------
        torch.Tensor
            Shape (batch, 1, s1, s2, s3) with values in [0, 1]. Each ``si`` is
            the minimum of the input size and the decoder output size along
            that axis, so the result may be smaller than the input.
        """
        skip = inp

        # Contracting path; each line is pool/conv -> ReLU -> dropout.
        low = self.dropout(self.relu(self.conv1(self.maxpool(inp))))
        low = self.dropout(self.relu(self.conv2(self.maxpool(low))))
        low = self.dropout(self.relu(self.convbin1(self.desepconv4(low))))

        # Expanding path back towards the input resolution.
        high = self.dropout(self.relu(self.conv3(self.upsample(low))))
        high = self.upsample(high)

        # Input and decoder output can differ in size; crop both to the
        # common leading region before concatenating along the channels.
        common = [min(a, b) for a, b in zip(skip.shape[2:], high.shape[2:])]
        window = (slice(None), slice(None)) + tuple(slice(None, n) for n in common)
        merged = torch.cat((skip[window], high[window]), dim=1)

        scores = self.out_act(4 * (self.conv4(merged) - 0.5))
        # Keep only the first channel's probability, retaining the channel axis.
        return self.softmax(scores)[:, [0]]

    def decode(self, x):
        """Identity: return ``x`` unchanged."""
        return x

    def forward(self, x):
        """Network forward pass; equivalent to :meth:`encode`."""
        return self.encode(x)