Source code for ClearMap.ImageProcessing.Skeletonization.PK12

# -*- coding: utf-8 -*-
"""
3d Skeletonization PK12
=======================

This module implements the 3d parallel 12-subiteration thinning algorithm 
by Palagyi & Kuba via parallel convolution of the data with a base template
and lookup table matching.

Reference
---------
  Palagyi & Kuba, A Parallel 3D 12-Subiteration Thinning Algorithm, 
  Graphical Models and Image Processing 61, 199-221 (1999)
"""
__author__    = 'Christoph Kirst <christoph.kirst.ck@gmail.com>'
__license__   = 'GPLv3 - GNU General Pulic License v3 (see LICENSE.txt)'
__copyright__ = 'Copyright © 2020 by Christoph Kirst'
__webpage__   = 'http://idisco.info'
__download__  = 'http://www.github.com/ChristophKirst/ClearMap2'


import os
import numpy as np
import multiprocessing as mp

import ClearMap.ImageProcessing.Topology.Topology3d as t3d
 
import ClearMap.ParallelProcessing.DataProcessing.ArrayProcessing as ap
import ClearMap.ParallelProcessing.DataProcessing.ConvolvePointList as cpl

import ClearMap.Utils.Timer as tmr

import ClearMap.IO.FileUtils as fu


###############################################################################
### Topology
###############################################################################

def match(cube):
    """Match one of the masks in the algorithm.

    Evaluates the 14 base templates T1..T14 in order and returns True as soon
    as one of them matches, False if none does. Every template requires the
    voxel under test, ``cube[1,1,1]``, to be foreground.

    The templates are written for one fixed direction only; the module-level
    code builds the lookup table ``delete`` from this function (via
    ``match_index``) and handles the 12 sub-iteration directions by convolving
    with the rotated kernels in ``rotations``.

    Arguments
    ---------
    cube : 3x3x3 bool array
        The local binary image, indexed ``cube[i, j, k]`` with the voxel
        under test at ``[1, 1, 1]``.

    Returns
    -------
    match : bool
        True if one of the masks matches

    Note
    ----
    Algorithm as in Palagyi & Kuba (1999). The T1..T14 numbering is this
    file's own; presumably it follows the paper's template figure -- TODO
    confirm. Each template is a conjunction of: voxels that must be
    foreground (plain ``cube[...]`` terms), voxels that must be background
    (``not cube[...]``), groups of which at least one voxel must be
    foreground (``... or ...``), and pairs of which exactly one must be
    foreground (the ``((not a) & b) or (a & (not b))`` pattern).
    """
    # T1: [1,1,0] and centre set, at least one more voxel in the k=0/k=1
    # planes, and the whole k=2 plane background.
    #T1
    T1 = (cube[1,1,0] & cube[1,1,1] &
          (cube[0,0,0] or cube[1,0,0] or cube[2,0,0] or
           cube[0,1,0] or cube[2,1,0] or
           cube[0,2,0] or cube[1,2,0] or cube[2,2,0] or
           cube[0,0,1] or cube[1,0,1] or cube[2,0,1] or
           cube[0,1,1] or cube[2,1,1] or
           cube[0,2,1] or cube[1,2,1] or cube[2,2,1]) &
          (not cube[0,0,2]) & (not cube[1,0,2]) & (not cube[2,0,2]) &
          (not cube[0,1,2]) & (not cube[1,1,2]) & (not cube[2,1,2]) &
          (not cube[0,2,2]) & (not cube[1,2,2]) & (not cube[2,2,2]));
    if T1:
        return True;

    # T2: centre and [1,2,1] set, at least one more voxel in the j=1/j=2
    # planes, and the whole j=0 plane background.
    #T2
    T2 = (cube[1,1,1] & cube[1,2,1] &
          (cube[0,1,0] or cube[1,1,0] or cube[2,1,0] or
           cube[0,2,0] or cube[1,2,0] or cube[2,2,0] or
           cube[0,1,1] or cube[2,1,1] or
           cube[0,2,1] or cube[2,2,1] or
           cube[0,1,2] or cube[1,1,2] or cube[2,1,2] or
           cube[0,2,2] or cube[1,2,2] or cube[2,2,2]) &
          (not cube[0,0,0]) & (not cube[1,0,0]) & (not cube[2,0,0]) &
          (not cube[0,0,1]) & (not cube[1,0,1]) & (not cube[2,0,1]) &
          (not cube[0,0,2]) & (not cube[1,0,2]) & (not cube[2,0,2]));
    if T2:
        return True;

    # T3: centre and [1,2,0] set, at least one of eight listed voxels set,
    # and both the j=0 plane and the k=2 plane background.
    #T3
    T3 = (cube[1,1,1] & cube[1,2,0] &
          (cube[0,1,0] or cube[2,1,0] or cube[0,2,0] or cube[2,2,0] or
           cube[0,1,1] or cube[2,1,1] or cube[0,2,1] or cube[2,2,1]) &
          (not cube[0,0,0]) & (not cube[1,0,0]) & (not cube[2,0,0]) &
          (not cube[0,0,1]) & (not cube[1,0,1]) & (not cube[2,0,1]) &
          (not cube[0,0,2]) & (not cube[1,0,2]) & (not cube[2,0,2]) &
          (not cube[0,1,2]) & (not cube[1,1,2]) & ( not cube[2,1,2]) &
          (not cube[0,2,2]) & (not cube[1,2,2]) & (not cube[2,2,2]));
    if T3:
        return True;

    # T4..T10 share the foreground voxels [1,1,0], [1,1,1], [1,2,1] and the
    # background voxels [1,0,1], [1,0,2], [1,1,2]; they differ in the side
    # voxels and in the "not both" / "exactly one" side conditions.
    #T4
    T4 = (cube[1,1,0] & cube[1,1,1] & cube[1,2,1] &
          ((not cube[0,0,1]) or (not cube[0,1,2])) &
          ((not cube[2,0,1]) or (not cube[2,1,2])) &
          (not cube[1,0,1]) & (not cube[0,0,2]) & (not cube[1,0,2]) &
          (not cube[2,0,2]) & (not cube[1,1,2]));
    if T4:
        return True;

    #T5
    T5 = (cube[1,1,0] & cube[1,1,1] & cube[1,2,1] &
          cube[2,0,2] &
          ((not cube[0,0,1]) or (not cube[0,1,2])) &
          (((not cube[2,0,1]) & cube[2,1,2]) or (cube[2,0,1] & (not cube[2,1,2]))) &
          (not cube[1,0,1]) & (not cube[0,0,2]) & (not cube[1,0,2]) &
          (not cube[1,1,2]));
    if T5:
        return True;

    #T6
    T6 = (cube[1,1,0] & cube[1,1,1] & cube[1,2,1] & cube[0,0,2] &
          ((not cube[2,0,1]) or (not cube[2,1,2])) &
          (((not cube[0,0,1]) & cube[0,1,2]) or (cube[0,0,1] & (not cube[0,1,2]))) &
          (not cube[1,0,1]) & (not cube[1,0,2]) & (not cube[2,0,2]) &
          (not cube[1,1,2]));
    if T6:
        return True;

    #T7
    T7 = (cube[1,1,0] & cube[1,1,1] & cube[2,1,1] & cube[1,2,1] &
          ((not cube[0,0,1]) or (not cube[0,1,2])) &
          (not cube[1,0,1]) & (not cube[0,0,2]) & (not cube[1,0,2]) &
          (not cube[1,1,2]));
    if T7:
        return True;

    #T8
    T8 = (cube[1,1,0] & cube[0,1,1] & cube[1,1,1] & cube[1,2,1] &
          ((not cube[2,0,1]) or (not cube[2,1,2])) &
          (not cube[1,0,1]) & (not cube[1,0,2]) & (not cube[2,0,2]) &
          (not cube[1,1,2]));
    if T8:
        return True;

    #T9
    T9 = (cube[1,1,0] & cube[1,1,1] & cube[2,1,1] & cube[0,0,2] & cube[1,2,1] &
          (((not cube[0,0,1]) & cube[0,1,2]) or (cube[0,0,1] & (not cube[0,1,2]))) &
          (not cube[1,0,1]) & (not cube[1,0,2]) & (not cube[1,1,2]));
    if T9:
        return True;

    #T10
    T10= (cube[1,1,0] & cube[0,1,1] & cube[1,1,1] & cube[2,0,2] & cube[1,2,1] &
          (((not cube[2,0,1]) & cube[2,1,2]) or (cube[2,0,1] & (not cube[2,1,2]))) &
          (not cube[1,0,1]) & (not cube[1,0,2]) & (not cube[1,1,2]));
    if T10:
        return True;

    # T11..T14: three required foreground voxels (always including the
    # centre) and a fixed set of background voxels; no "at least one" group.
    #T11
    T11= (cube[2,1,0] & cube[1,1,1] & cube[1,2,0] &
          (not cube[0,0,0]) & (not cube[1,0,0]) &
          (not cube[0,0,1]) & (not cube[1,0,1]) &
          (not cube[0,0,2]) & (not cube[1,0,2]) & (not cube[2,0,2]) &
          (not cube[0,1,2]) & (not cube[1,1,2]) & (not cube[2,1,2]) &
          (not cube[0,2,2]) & (not cube[1,2,2]) & (not cube[2,2,2]));
    if T11:
        return True;

    #T12
    T12= (cube[0,1,0] & cube[1,2,0] & cube[1,1,1] &
          (not cube[1,0,0]) & (not cube[2,0,0]) &
          (not cube[1,0,1]) & (not cube[2,0,1]) &
          (not cube[0,0,2]) & (not cube[1,0,2]) & (not cube[2,0,2]) &
          (not cube[0,1,2]) & (not cube[1,1,2]) & (not cube[2,1,2]) &
          (not cube[0,2,2]) & (not cube[1,2,2]) & (not cube[2,2,2]));
    if T12:
        return True;

    #T13
    T13= (cube[1,2,0] & cube[1,1,1] & cube[2,2,1] &
          (not cube[0,0,0]) & (not cube[1,0,0]) & (not cube[2,0,0]) &
          (not cube[0,0,1]) & (not cube[1,0,1]) & (not cube[2,0,1]) &
          (not cube[0,0,2]) & (not cube[1,0,2]) & (not cube[2,0,2]) &
          (not cube[0,1,2]) & (not cube[1,1,2]) &
          (not cube[0,2,2]) & (not cube[1,2,2]));
    if T13:
        return True;

    #T14
    T14= (cube[1,2,0] & cube[1,1,1] & cube[0,2,1] &
          (not cube[0,0,0]) & (not cube[1,0,0]) & (not cube[2,0,0]) &
          (not cube[0,0,1]) & (not cube[1,0,1]) & (not cube[2,0,1]) &
          (not cube[0,0,2]) & (not cube[1,0,2]) & (not cube[2,0,2]) &
          (not cube[1,1,2]) & (not cube[2,1,2]) &
          (not cube[1,2,2]) & (not cube[2,2,2]));
    if T14:
        return True;

    # no template matched
    return False;
def match_index(index, verbose = True):
    """Evaluate the PK12 deletion templates for one neighbourhood configuration.

    Arguments
    ---------
    index : int
        Configuration index; decoded into a 3x3x3 cube by
        ``t3d.cube_from_index`` with ``center=True``.
    verbose : bool
        If True, print progress every 2**14 indices.

    Returns
    -------
    match : bool
        Result of :func:`match` on the decoded cube.
    """
    report = verbose and index % 2**14 == 0
    if report:
        print('PK12 LUT: %d / %d' % (index, 2**26))
    neighbourhood = t3d.cube_from_index(index=index, center=True)
    return match(neighbourhood)
def match_non_removable(index, verbose = True):
    """Decide whether the center voxel of a configuration is never removable.

    The configuration is decoded with ``center=False``. It is classified as
    non-removable when its remaining voxels number fewer than 2, or number 2
    or 3 and every pair of them is separated along some axis by a coordinate
    difference of 2 (i.e. they only connect through the center). More than 3
    voxels is never classified as non-removable.

    Arguments
    ---------
    index : int
        Configuration index passed to ``t3d.cube_from_index``.
    verbose : bool
        If True, print progress every 2**14 indices.

    Returns
    -------
    non_removable : bool
    """
    if verbose and index % 2**14 == 0:
        print('PK12 LUT non-removables: %d / %d' % (index, 2**26))

    cube = t3d.cube_from_index(index=index, center=False)
    count = cube.sum()
    if count < 2:
        return True
    if count > 3:
        return False

    # one row of (x, y, z) coordinates per foreground voxel
    coords = np.array(np.where(cube)).T

    def separated(a, b):
        # the two voxels are not adjacent: some axis differs by 2
        return bool(np.any(np.abs(coords[a] - coords[b]) == 2))

    pairs = ((0, 1),) if count == 2 else ((0, 1), (0, 2), (1, 2))
    return all(separated(a, b) for a, b in pairs)
def generate_lookup_table(function = match_index, verbose = True):
    """Generate a lookup table for the templates.

    The table has one entry per configuration index in ``range(2**26)``;
    entry ``index`` is ``function(index)``, evaluated in parallel on all CPUs.

    Arguments
    ---------
    function : callable
        Maps a configuration index (int) to a bool. It is sent to worker
        processes, so it must be picklable (e.g. a module-level function).
    verbose : bool
        If False, ``function`` is called with ``verbose=False`` to silence its
        progress output. If True, ``function`` is called with its own
        defaults, as before.

    Returns
    -------
    lut : array
        Boolean array of length 2**26.
    """
    import functools

    n_entries = 2**26
    n_processes = mp.cpu_count()

    # Only wrap when silencing is requested, so functions without a
    # `verbose` keyword keep working with the default call.
    task = function if verbose else functools.partial(function, verbose=False)

    # Pool.map needs an integer chunk size; 2**26/8/n is a float in Python 3.
    chunksize = max(1, n_entries // (8 * n_processes))

    # The context manager shuts down the worker processes once map() is done.
    with mp.Pool(n_processes) as pool:
        lut = pool.map(task, range(n_entries), chunksize=chunksize)

    return np.array(lut, dtype=bool)
filename = "PK12.npy"; """Filename for the look up table mapping a cube configuration to the deleatability of the center pixel"""
def initialize_lookup_table(function = match_index, filename = filename):
    """Initialize a lookup table.

    Loads the table from ``filename``, resolved relative to this module's
    directory. If that file does not exist, the table is generated with
    :func:`generate_lookup_table` and cached there. If caching fails (e.g.
    the package directory is read-only), a warning is issued and the freshly
    generated table is still returned instead of being discarded.

    Arguments
    ---------
    function : callable
        Per-index function used to generate the table if no cached file exists.
    filename : str
        File name of the cached table (relative to this module's directory).

    Returns
    -------
    lut : array
        The lookup table.
    """
    filename = os.path.join(os.path.dirname(os.path.abspath(__file__)), filename)

    # check if only a compressed version of the file exists (presumably
    # fu.uncompress unpacks it next to `filename` -- see FileUtils)
    fu.uncompress(filename)

    if os.path.exists(filename):
        return np.load(filename)

    lut = generate_lookup_table(function=function)
    try:
        np.save(filename, lut)
    except OSError as error:
        import warnings
        warnings.warn('Could not cache lookup table to %r: %s' % (filename, error))
    return lut
# Module-level lookup tables and kernels, built at import time.
# NOTE: initialize_lookup_table() generates a table over all 2**26
# configurations when its cached .npy file is missing, so a first import
# without the cached files can be very slow.

# Kernel whose convolution with the binary yields the configuration index of
# each voxel's neighbourhood (built with center=False).
base = t3d.cube_base_2(center=False);
"""Base kernel to multiply with cube to obtain index of cube"""

# delete[index] is match_index(index): True if the center voxel may be deleted.
delete = initialize_lookup_table();
"""Lookup table mapping cube index to its deleteability"""

# Not referenced by the functions visible in this module (they use a local
# `keep`); presumably kept as public API -- confirm before removing.
keep = np.logical_not(delete);
"""Lookup table mapping cube index to its non-deleteability"""

filename_non_removable = "PK12nr.npy";
"""Filename for the lookup table mapping a cube configuration to the non-removeability of the center pixel"""

non_removable = initialize_lookup_table(filename = filename_non_removable, function = match_non_removable);
"""Lookup table mapping cube index to its non-removeability"""

# Used by skeletonize_index every third iteration to drop points that can
# never be removed from the candidate list.
consider = np.logical_not(non_removable);
"""Lookup table mapping cube index to whether it needs to be considered further"""

# One rotated kernel per sub-iteration direction (12 in total).
rotations = t3d.rotations12(base);
"""Rotations of the base cube for the sub-iterations"""


###############################################################################
### Skeletonization
###############################################################################
def skeletonize(binary, points = None, steps = None, removals = False, radii = False, check_border = True, delete_border = False, return_points = False, verbose = True):
    """Skeletonize a binary 3d array using PK12 algorithm.

    Arguments
    ---------
    binary : array
        Binary image to skeletonize.
    points : array or None
        Optional (n, 3) array of foreground point coordinates in the binary
        to speed up processing.
    steps : int or None
        Maximal number of iterations. If None (or negative), iterate until no
        further point can be removed.
    removals : bool
        If True, also return the removal times of the pixels in the input
        (``12 * iteration + sub_iteration``; 0 for pixels never removed).
    radii : bool
        If True, also return an estimate of the local radius at each skeleton
        point (removal times convolved with the 18-neighbourhood kernel).
    check_border : bool
        If True, check that the border is empty. The algorithm requires this.
    delete_border : bool
        If True, delete the border (the border check is then skipped).
    return_points : bool
        If True, also return the point coordinates of the skeleton.
    verbose : bool
        If True print progress info.

    Returns
    -------
    skeleton : array
        The skeleton of the binary.
    points : array
        The point coordinates of the skeleton nx3 (if return_points).
    removals : array
        The removal times (if removals).
    radii : array
        The radius estimates at the skeleton points (if radii).

    Note
    ----
    The skeletonization is done in place on the binary. Copy the binary if
    needed for further processing.
    """
    if verbose:
        print('#############################################################')
        print('Skeletonization PK12 [convolution]')
        timer = tmr.Timer()

    # TODO: make this work for any memmapable source !
    if not isinstance(binary, np.ndarray):
        raise ValueError('Numpy array required for binary in skeletonization!')
    if binary.ndim != 3:
        raise ValueError('The binary array dimension is %d, 3 is required!' % binary.ndim)

    if delete_border:
        binary = t3d.delete_border(binary)
        check_border = False

    if check_border and not t3d.check_border(binary):
        raise ValueError('The binary array needs to have no points on the border!')

    # detect foreground points as (n, 3) coordinates
    if points is None:
        points = ap.where(binary).array

    if verbose:
        timer.print_elapsed_time(head='Foreground points: %d' % (points.shape[0],))

    if removals is True or radii is True:
        death = np.zeros(binary.shape, dtype='uint16')
        with_info = True
    else:
        with_info = False

    # a negative limit means: iterate until nothing can be removed
    if steps is None:
        steps = -1

    step = 1
    removed = 0
    while steps < 0 or step <= steps:
        if verbose:
            print('#############################################################')
            print('Iteration %d' % step)
            timer_iter = tmr.Timer()

        # border points: foreground points with fewer than 6 face neighbours
        border = cpl.convolve_3d_points(binary, t3d.n6, points) < 6
        borderpoints = points[border]
        borderids = np.nonzero(border)[0]
        keep = np.ones(len(border), dtype=bool)
        if verbose:
            timer_iter.print_elapsed_time('Border points: %d' % (len(borderpoints),))

        # sub iterations, one per rotated template direction
        remiter = 0
        for i in range(12):
            if verbose:
                print('-------------------------------------------------------------')
                print('Sub-Iteration %d' % i)
                timer_sub_iter = tmr.Timer()

            remborder = delete[cpl.convolve_3d_points(binary, rotations[i], borderpoints)]
            rempoints = borderpoints[remborder]
            if verbose:
                timer_sub_iter.print_elapsed_time('Matched points: %d' % (len(rempoints),))

            binary[rempoints[:, 0], rempoints[:, 1], rempoints[:, 2]] = 0
            keep[borderids[remborder]] = False
            rem = len(rempoints)
            remiter += rem
            removed += rem
            if verbose:
                print('Deleted points: %d' % (rem))
                timer_sub_iter.print_elapsed_time('Sub-Iteration %d' % (i))

            # death times
            if with_info:
                death[rempoints[:, 0], rempoints[:, 1], rempoints[:, 2]] = 12 * step + i

            # Drop the just-deleted points from the candidates. The
            # neighbourhood index is computed with a kernel built with
            # center=False, so a deleted point could otherwise be matched
            # again by a later sub-iteration (double counting it and
            # overwriting its removal time).
            remaining = np.logical_not(remborder)
            borderpoints = borderpoints[remaining]
            borderids = borderids[remaining]

        # update foreground
        points = points[keep]
        if verbose:
            print('Foreground points: %d' % points.shape[0])
            print('-------------------------------------------------------------')
            timer_iter.print_elapsed_time('Iteration %d' % (step,))

        step += 1
        if remiter == 0:
            break

    if verbose:
        print('#############################################################')
        print('Total removed: %d' % (removed))
        print('Total remaining: %d' % (len(points)))
        timer.print_elapsed_time('Skeletonization')

    result = [binary]
    if return_points:
        result.append(points)
    if removals is True:
        result.append(death)
    if radii is True:
        # radius estimate from the removal times in the 18-neighbourhood
        radii = cpl.convolve_3d(death, np.array(t3d.n18, dtype='uint16'), points)
        result.append(radii)

    if len(result) > 1:
        return tuple(result)
    return result[0]
def skeletonize_index(binary, points = None, steps = None, removals = False, radii = False, return_points = False, check_border = True, delete_border = False, verbose = True):
    """Skeletonize a binary 3d array using PK12 algorithm via index coordinates.

    Arguments
    ---------
    binary : array
        Binary image to be skeletonized.
    points : array or None
        Optional flat indices of the foreground points into
        ``binary.reshape(-1, order='A')``.
    steps : int or None
        Maximal number of iterations. If None (or negative), use maximal
        reduction.
    removals : bool
        If True, returns the steps in which the pixels in the input data were
        removed (``12 * iteration + sub_iteration``; 0 if never removed).
    radii : bool
        If True, the estimate of the local radius is returned.
    return_points : bool
        If True, the flat indices of the skeleton points are returned.
    check_border : bool
        If True, check that the border is empty. The algorithm requires this.
    delete_border : bool
        If True, delete the border (the border check is then skipped).
    verbose : bool
        If True, print progress info.

    Returns
    -------
    skeleton : array
        The skeleton of the binary input.
    points : array
        The flat indices of the skeleton points (if return_points).
    removals : array
        The removal times (if removals).
    radii : array
        The radius estimates at the skeleton points (if radii).

    Note
    ----
    The skeletonization is done in place on the binary, unless the binary is
    not contiguous in memory. In that case a contiguous copy is processed and
    returned.
    """
    if verbose:
        print('#############################################################')
        print('Skeletonization PK12 [convolution, index]')
        timer = tmr.Timer()

    # TODO: make this work for any memmapable source
    if not isinstance(binary, np.ndarray):
        raise ValueError('Numpy array required for binary in skeletonization!')
    if binary.ndim != 3:
        raise ValueError('The binary array dimension is %d, 3 is required!' % binary.ndim)

    if delete_border:
        binary = t3d.delete_border(binary)
        check_border = False

    if check_border and not t3d.check_border(binary):
        raise ValueError('The binary array needs to have no points on the border!')

    if not (binary.flags.c_contiguous or binary.flags.f_contiguous):
        # reshape() would return a copy of a non-contiguous array, so the
        # in-place deletions on binary_flat would never reach `binary` (which
        # the convolutions read). Work on a contiguous copy instead.
        binary = np.ascontiguousarray(binary)

    # flat view in memory order; all point indices refer to it
    binary_flat = binary.reshape(-1, order='A')

    # detect points
    if points is None:
        points = ap.where(binary_flat).array
    npoints = points.shape[0]

    if verbose:
        timer.print_elapsed_time('Foreground points: %d' % (points.shape[0],))

    if removals is True or radii is True:
        # same memory order as binary, so flat indices agree
        order = 'F' if binary.flags.f_contiguous else 'C'
        death = np.zeros(binary.shape, dtype='uint16', order=order)
        deathflat = death.reshape(-1, order='A')
        with_info = True
    else:
        with_info = False

    # a negative limit means: iterate until nothing can be removed
    if steps is None:
        steps = -1

    step = 1
    nnonrem = 0
    while steps < 0 or step <= steps:
        if verbose:
            print('#############################################################')
            print('Iteration %d' % step)
            timer_iter = tmr.Timer()

        # border points: foreground points with fewer than 6 face neighbours
        border = cpl.convolve_3d_indices_if_smaller_than(binary, t3d.n6, points, 6)
        borderpoints = points[border]
        borderids = ap.where(border).array
        keep = np.ones(len(border), dtype=bool)
        if verbose:
            timer_iter.print_elapsed_time('Border points: %d' % (len(borderpoints),))

        # sub iterations, one per rotated template direction
        remiter = 0
        for i in range(12):
            if verbose:
                print('-------------------------------------------------------------')
                print('Sub-Iteration %d' % i)
                timer_sub_iter = tmr.Timer()

            remborder = delete[cpl.convolve_3d_indices(binary, rotations[i], borderpoints)]
            rempoints = borderpoints[remborder]
            if verbose:
                timer_sub_iter.print_elapsed_time('Matched points : %d' % (len(rempoints),))

            binary_flat[rempoints] = 0
            keep[borderids[remborder]] = False
            remiter += len(rempoints)

            # death times
            if with_info:
                deathflat[rempoints] = 12 * step + i

            # Drop the just-deleted points from the candidates. The
            # neighbourhood index is computed with a kernel built with
            # center=False, so a deleted point could otherwise be matched
            # again later and have its removal time overwritten.
            remaining = np.logical_not(remborder)
            borderpoints = borderpoints[remaining]
            borderids = borderids[remaining]

            if verbose:
                timer_sub_iter.print_elapsed_time('Sub-Iteration %d' % (i,))

        if verbose:
            print('-------------------------------------------------------------')

        # update foreground
        points = points[keep]

        # every third iteration drop points that can never be removed
        if step % 3 == 0:
            npts = len(points)
            points = points[consider[cpl.convolve_3d_indices(binary, base, points)]]
            nnonrem += npts - len(points)
            if verbose:
                print('Non-removable points: %d' % (npts - len(points)))

        if verbose:
            print('Foreground points : %d' % points.shape[0])
            print('-------------------------------------------------------------')
            timer_iter.print_elapsed_time('Iteration %d' % (step,))

        step += 1
        if remiter == 0:
            break

    if verbose:
        print('#############################################################')
        timer.print_elapsed_time('Skeletonization done')
        print('Total removed: %d' % (npoints - (len(points) + nnonrem)))
        print('Total remaining: %d' % (len(points) + nnonrem))

    # non-removable points were dropped from `points`; recompute from binary
    if radii is True or return_points is True:
        points = ap.where(binary_flat).array

    if radii is True:
        # radius estimate from the removal times in the 18-neighbourhood
        radii = cpl.convolve_3d_indices(death, t3d.n18, points, out_dtype='uint16')
    else:
        radii = None

    result = [binary]
    if return_points:
        result.append(points)
    if removals is True:
        result.append(death)
    if radii is not None:
        result.append(radii)

    if len(result) > 1:
        return tuple(result)
    return result[0]
###############################################################################
### Tests
###############################################################################

def _test():
    """Interactive smoke test of both PK12 variants.

    Skeletonizes the ClearMap test binary with ``skeletonize`` and
    ``skeletonize_index``, plots each result, plots a 150^3 crop and writes a
    100^3 crop of input and skeleton to ``binary.mat`` and
    ``binary_skeleton.mat`` in the working directory.

    The lookup tables can be regenerated with
    ``np.save(PK12.filename, PK12.generate_lookup_table())`` and
    ``np.save(PK12.filename_non_removable,
    PK12.generate_lookup_table(function=PK12.match_non_removable))``.
    """
    import numpy as np
    import ClearMap.IO.IO as io
    import ClearMap.Visualization.Plot3d as p3d
    import ClearMap.Tests.Files as tsf
    import ClearMap.ImageProcessing.Skeletonization.PK12 as PK12
    from importlib import reload
    reload(PK12)

    # skeletonization with both implementations
    reload(PK12)
    source = np.array(io.as_source(tsf.skeleton_binary))

    for skeletonizer in (PK12.skeletonize, PK12.skeletonize_index):
        skeleton = skeletonizer(source.copy(), delete_border=True, verbose=True)
        p3d.plot([[source, skeleton]])

    # plotting (of the index-version skeleton)
    plot_crop = np.s_[:150, :150, :150]
    p3d.plot_3d(source[plot_crop], cmap=p3d.grays_alpha(0.05))
    p3d.plot_3d(skeleton[plot_crop], cmap=p3d.single_color_colormap('red', alpha=0.8))

    # save for figure
    import scipy.io as sio
    figure_crop = np.s_[:100, :100, :100]
    sio.savemat('binary.mat', {'binary': source[figure_crop]})
    sio.savemat('binary_skeleton.mat', {'skeleton': skeleton[figure_crop]})