# Source code for ClearMap.ImageProcessing.Skeletonization.SkeletonProcessing

"""
SkeletonProcessing
==================

Utils to post process skeletons.
"""
__author__    = 'Christoph Kirst <christoph.kirst.ck@gmail.com>'
__license__   = 'GPLv3 - GNU General Pulic License v3 (see LICENSE.txt)'
__copyright__ = 'Copyright © 2020 by Christoph Kirst'
__webpage__   = 'http://idisco.info'
__download__  = 'http://www.github.com/ChristophKirst/ClearMap2'

import numpy as np

import ClearMap.ParallelProcessing.DataProcessing.ConvolvePointList as cpl;
import ClearMap.ImageProcessing.Topology.Topology3d as t3d

import ClearMap.Utils.Timer as tmr;

###############################################################################
### Topology
###############################################################################

def clean_open_branches(skeleton, skelton_copy, points, radii, length, clean=True, verbose=False):
    """Remove short open branches by iteratively eroding skeleton end points.

    Every end point (degree 1, degree computed with the ``t3d.n26`` kernel) is
    followed along the skeleton one voxel per iteration for at most ``length``
    steps. If a walk steps onto a branch point (degree >= 3), all voxels it
    visited so far are marked for deletion; the branch point itself is kept.
    Walks that run out of neighbours before reaching a branch point (short
    isolated segments) are dropped and left untouched.

    Arguments
    ---------
    skeleton : array
        Binary 3d skeleton image. Must be Fortran-ordered (asserted).
        Modified in place when ``clean`` is True.
    skelton_copy : array
        Fortran-ordered copy of ``skeleton`` used as scratch space: visited end
        points are erased from it in place. It should not share memory with
        ``skeleton``, otherwise the erosion also alters ``skeleton``.
    points : array
        Flat (Fortran-order) indices of the non-zero voxels of ``skeleton``.
    radii : array
        Values aligned with ``points`` (presumably per-point radii); filtered
        together with ``points`` when ``clean`` is True.
    length : int
        Maximal number of steps to follow a branch from its end point.
    clean : bool
        If True, remove the detected voxels from ``skeleton`` and return the
        cleaned data; otherwise only return the voxels to delete.
    verbose : bool
        If True, print timing information.

    Returns
    -------
    skeleton, points, radii : arrays
        If ``clean`` is True: the cleaned skeleton (the input object), and
        ``points`` / ``radii`` with the deleted entries removed.
    delete_points : array
        If ``clean`` is False: sorted unique flat indices of voxels to delete
        (integer dtype, possibly empty).
    """
    assert np.isfortran(skeleton), 'skeleton must be Fortran-ordered'
    assert np.isfortran(skelton_copy), 'skelton_copy must be Fortran-ordered'

    timer = tmr.Timer()
    timer_all = tmr.Timer()

    # Find branch points (degree >= 3) and end points (degree == 1).
    deg = cpl.convolve_3d_indices(skeleton, t3d.n26, points, sink_dtype='uint8')
    branchpoints = points[deg >= 3]
    e_pts = points[deg == 1]

    if verbose:
        timer.printElapsedTime('Detected %d branch and %d endpoints'
                               % (branchpoints.shape[0], e_pts.shape[0]))
        timer.reset()

    # Flat view on the scratch copy (Fortran order, matching ``points``); used
    # to erase end points as the walks advance.
    skel_flat = np.reshape(skelton_copy, -1, order='A')
    # ``ndarray.strides`` is in bytes; convert to element strides so that the
    # offsets computed below are flat indices for any dtype (unchanged for
    # 1-byte dtypes such as bool/uint8).
    strides = np.array(skelton_copy.strides) // skelton_copy.itemsize

    if verbose:
        timer.printElapsedTime('Prepared temporary arrays')
        timer.reset()

    # Neighbour code kernel: each of the 26 neighbour positions gets a unique
    # non-zero code 1..27 and the centre gets 0, so a labelling result of 0
    # unambiguously means "no neighbour". (np.arange(27) would also assign 0 to
    # the corner [0, 0, 0], making that neighbour look like no neighbour.)
    # NOTE(review): like the original decoding, this assumes that for a voxel
    # with a single neighbour cpl.convolve_3d_indices returns the kernel value
    # at that neighbour's position (correlation orientation) -- confirm.
    label = np.arange(1, 28).reshape((3, 3, 3))
    label[1, 1, 1] = 0

    # critical_points[m] holds the m-th visited voxel of every active walk;
    # all entries stay aligned with the current end points ``e_pts``.
    critical_points = [e_pts]
    delete_points = []

    for step in range(1, length + 1):
        if e_pts.shape[0] == 0:
            break  # no active walks left; further iterations cannot delete anything

        # Code of the (single) remaining neighbour of each current end point.
        e_pts_label = cpl.convolve_3d_indices(skelton_copy, label, e_pts)

        if verbose:
            timer.printElapsedTime('Done labeling %d / %d' % (step, length))
            timer.reset()

        # End points without neighbours belong to short isolated segments that
        # never reach a branch point: stop following them.
        e_pts_zero = e_pts_label == 0
        n_zero = int(e_pts_zero.sum())
        if n_zero > 0:
            keep = np.logical_not(e_pts_zero)
            for m in range(step):
                critical_points[m] = critical_points[m][keep]
            e_pts_label = e_pts_label[keep]
            e_pts = e_pts[keep]

        if verbose:
            timer.printElapsedTime('Ignored %d small branches' % n_zero)
            timer.reset()

        # Advance each walk: decode the neighbour code into a (-1, 0, 1)^3
        # offset and convert it into a flat-index offset via the strides.
        codes = np.asarray(e_pts_label, dtype=np.intp) - 1
        offsets = np.vstack(np.unravel_index(codes, label.shape)).T - 1
        e_pts_new = e_pts + np.sum(offsets * strides, axis=1)

        # Walks that stepped onto a branch point form short open branches:
        # delete every voxel visited so far (the branch point is kept).
        delete = np.isin(e_pts_new, branchpoints)
        n_delete = int(delete.sum())
        if n_delete > 0:
            keep = np.logical_not(delete)
            for m in range(step):
                delete_points.append(critical_points[m][delete])
                critical_points[m] = critical_points[m][keep]
            e_pts_new = e_pts_new[keep]

        if verbose:
            timer.printElapsedTime('Deleted %d points' % n_delete)
            timer.reset()

        if step < length:
            # Erase the current end points from the scratch copy so that the
            # next labelling only sees the not-yet-visited neighbour.
            skel_flat[e_pts] = False
            critical_points.append(e_pts_new)
            e_pts = e_pts_new

        if verbose:
            timer.printElapsedTime('Cleanup iteration %d / %d done.' % (step, length))

    # Gather all points to delete.
    if delete_points:
        delete_points = np.unique(np.hstack(delete_points))
    else:
        # Integer dtype so the (empty) result is a valid index array.
        delete_points = np.zeros(0, dtype=points.dtype)

    if verbose:
        timer_all.printElapsedTime('Cleanup')

    if clean:
        # Flat Fortran-order view: writes go straight into ``skeleton``.
        skel_flat = np.reshape(skeleton, -1, order='F')
        skel_flat[delete_points] = False
        keep_ids = np.logical_not(np.isin(points, delete_points, assume_unique=True))
        points = points[keep_ids]
        radii = radii[keep_ids]
        return skeleton, points, radii

    return delete_points
############################################################################### ### Tests ############################################################################### def _test(): """Test""" pass #%% # import numpy as np # import ClearMap.Visualization.Plot3d as p3d; # import ClearMap.DataProcessing.ConvolvePointList as cpl; # import ClearMap.ImageProcessing.Skeletonization.SkeletonCleanUp as scu # reload(scu); # # data = np.load('/home/ckirst/Desktop/data.npy'); # skel = np.load('/home/ckirst/Desktop/skel.npy'); # points = np.load('/home/ckirst/Desktop/pts.npy'); # # #data = data[:50,:50,:50]; # #skel = skel[:50,:50,:50]; # #t3d.deleteBorder(skel); # #points = np.where(np.reshape(skel,-1))[0]; # skelfor = np.asarray(skel, order = 'F'); # skelfor_copy = np.asarray(skel, order = 'F'); # points = np.ravel_multi_index(np.where(skelfor), skelfor.shape, order = 'F'); # # #%% # reload(scu); # # clean = scu.cleanOpenBranches(skelfor, skelfor_copy, points = points, length = 3, clean=False, verbose = True) # # skel_clean = np.zeros_like(skelfor, dtype = bool, order = 'A'); # skel_clean_f = np.reshape(skel_clean, -1, order = 'A'); # skel_clean_f[clean] = True; # # deg = cpl.convolve_3d_indices(skelfor, t3d.n26, points); # branchpoints = points[deg >= 3]; # skel_3 = np.zeros_like(skelfor, dtype = bool, order = 'A'); # skel_3_f = np.reshape(skel_3, -1, order = 'A'); # skel_3_f[branchpoints] = True; # # endpoints = points[deg ==1]; # skel_1 = np.zeros_like(skelfor, dtype = bool, order = 'A'); # skel_1_f = np.reshape(skel_1, -1, order = 'A'); # skel_1_f[endpoints] = True; # # #data_f = np.reshape(data, -1); # #data_f[points] = 160; # p3d.multi_plot([np.asarray(skel, dtype = int) + 2 * skel_3 + 4 * skel_1, # np.asarray(skel, dtype = int) + 2 * skel_clean, minMax = [0,7]]); # #%% #clean2 = scu.cleanOpenBranches2(skel.copy(), points = points, length = 3, clean=False, verbose = True) # # skel_clean2 = np.zeros_like(skelfor, dtype = bool, order = 'A'); # skel_clean2_f = 
np.reshape(skel_clean2, -1, order = 'A'); # skel_clean2_f[clean2] = True; #def cleanOpenBranches2(skeleton, points, length, clean = False, processes = cpu_count(), verbose = False): # """Remove open branches of length smaller than sepcifed # # Arguments # --------- # skeleton : array # binary 3d image of skeleton. # points : array # flat indices of non-zero entries in skeleton # length : int # maximal branch length to remove # # Returns # ------- # cleaned : array # flat indices of cleaned skeleton # """ # timer = tmr.Timer(); # timer_all = tmr.Timer(); # # #prepare # kernel = np.ones((3,3,3), dtype = bool); # kernel[1,1,1] = False; # deg = cpl.convolve_3d_indices(skeleton, kernel, points); # # if verbose: # timer.printElapsedTime('Degree calcualtion'); # timer.reset(); # # branchpoints = deg >= 3; # branchpoints = branchpoints.view(dtype = 'uint8'); # # if verbose: # timer.printElapsedTime('Branch points'); # timer.reset(); # # endpoints = np.where(deg == 1)[0]; # # if verbose: # timer.printElapsedTime('Branch and end point detection'); # timer.reset(); # # # keep = code.cleanBranchesIndex(points, strides = np.array(skeleton.strides), # length = length, # startpoints = endpoints, stoppoints = branchpoints, # processes = processes); # keep = keep.view(dtype = bool); # delete = points[np.logical_not(keep)]; # # if verbose: # timer.printElapsedTime('Cleanup detection'); # # if clean: # skeleton_flat = np.reshape(skeleton, -1); # skeleton_flat[delete] = False; # delete = skeleton, delete; # # if verbose: # timer_all.printElapsedTime('Cleanup'); # # return delete; #%% # # ss = con.extractNeighbourhood(skel, [84, 125, 58], 5); # ss = t3d.deleteBorder(ss); # pp = np.where(np.reshape(ss, -1))[0]; # # ee, ii = con.findEndpoints(ss, pp, border = None) # # img = np.asarray(ss, dtype = int); # imgf = np.reshape(img, -1); # imgf[pp] = cpl.convolve_3d)indices(ss, t3d.n26, pp) # # ee_xyz = np.array(np.unravel_index(ee, ss.shape)).T; # pp_xyz = np.array(np.unravel_index(pp, 
ss.shape)).T; # # kernel = np.ones((3,3,3), dtype = bool); # kernel[1,1,1] = False; # deg = cpl.convolve3DIndex(ss, kernel, pp); # branchpoints = deg >= 3; # # s2, p2 = scu.cleanOpenBranches(ss, pp, length = 1); # # dv.dualPlot(img, s2)