I tried hard to complete this code for diffeomorphic image registration, but every time I ran into another problem. Could anyone please take a look, complete it, and get it running?
Most of it is done except for the convergence criteria, but the unnecessarily large amount of code and its bugs are driving me crazy. Thanks for the help.
Symmetric diffeomorphic image registration goes beyond an affine-only transformation: it is a non-linear process, so it can model local deformations that a single global affine map cannot.
import abc
from enum import Enum, IntEnum

import numpy as np
import numpy.linalg as npl
class VerbosityLevels(IntEnum):
    """Logging verbosity for the registration classes.

    Members are integers so they can be ordered: the optimizer logs a message
    when ``self.verbosity >= <level of that message>``. A plain ``Enum`` does
    not support ``>=`` and made those checks raise ``TypeError``.

    NOTE(review): with these values DEBUG (0) is the *least* chatty setting
    for ``>=`` checks, while DEFAULT (3) passes all of them; confirm this
    ordering is intended before relying on it.
    """

    DEBUG = 0
    DIAGNOSE = 1
    STATUS = 2
    DEFAULT = 3
class RegistrationStages(Enum):
    """Optimization stages passed to the optional registration callback.

    In this module the optimizer invokes ``callback(registration, stage)``
    with one of these members at the points noted below.
    """

    OPT_START = 0    # before the first pyramid level is processed
    SCALE_START = 1  # before iterating at one pyramid level
    SCALE_END = 2    # after the last iteration at one pyramid level
    ITER_START = 3   # start of an iteration, once the metric is initialised
    ITER_END = 4     # end of an iteration, after the maps are re-inverted
    OPT_END = 5      # after the final static-to-moving map is composed
class DiffeomorphicRegistration(metaclass=abc.ABCMeta):
    """Abstract base class for diffeomorphic registration algorithms.

    Previously this class was written twice and the second definition replaced
    the first, silently dropping ``__init__`` and ``set_level_iters``; both
    halves now live in this single definition.

    Subclasses must implement ``optimize`` and ``get_map``. The default
    ``optimize`` body is a template a subclass may call via ``super()``: it
    relies on the hooks ``_init_optimizer``, ``_converged``, ``_iterate`` and
    ``_finalize_optimizer`` defined by the subclass. ``_iterate`` must advance
    whatever state ``_converged`` inspects, otherwise the loop never ends.
    """

    def __init__(self, metric=None):
        """Store the similarity metric and the dimension it operates on.

        Parameters
        ----------
        metric : object or None
            Similarity metric exposing a ``dim`` attribute. ``None`` is
            accepted so subclasses that validate and assign the metric
            themselves can call ``super().__init__()`` without arguments;
            in that case ``dim`` is left as ``None``.
        """
        self.metric = metric
        self.dim = None if metric is None else metric.dim

    def set_level_iters(self, level_iters):
        """Set the per-level iteration budgets and the number of levels.

        Parameters
        ----------
        level_iters : sequence of int or None
            Iterations for each pyramid level; its length defines
            ``self.levels`` (0 for ``None`` or an empty sequence).
        """
        self.levels = len(level_iters) if level_iters else 0
        self.level_iters = level_iters

    @abc.abstractmethod
    def optimize(self, static, moving, static_grid2world=None,
                 moving_grid2world=None, prealign=None):
        """Template optimisation loop; returns ``self.get_map()``.

        Parameters
        ----------
        static, moving : ndarray
            Images to register; both are cast to float32.
        static_grid2world, moving_grid2world : ndarray or None
            Grid-to-world affines forwarded to ``_init_optimizer``.
        prealign : ndarray or None
            Optional pre-alignment forwarded to ``_init_optimizer``.
        """
        # getattr: the base class itself never sets ``verbosity``.
        if getattr(self, 'verbosity', None) == VerbosityLevels.DEBUG:
            if prealign is not None:
                logger.info("Pre-align: " + str(prealign))
        self._init_optimizer(static.astype(np.float32), moving.astype(np.float32),
                             static_grid2world, moving_grid2world, prealign)
        while not self._converged():
            self._iterate()
        self._finalize_optimizer()
        return self.get_map()

    @abc.abstractmethod
    def get_map(self):
        """Return the optimised map stored in ``static_to_ref``.

        Raises
        ------
        ValueError
            If the optimizer has not produced a map yet (including when
            ``static_to_ref`` was never assigned).
        """
        static_to_ref = getattr(self, 'static_to_ref', None)
        if static_to_ref is None:
            raise ValueError("Diffeormorphic map cannot be obtained without running the optimizer.")
        return static_to_ref
import logging

# Single module-wide logger, named after this module so applications can
# configure it via logging.getLogger("<module name>").
logger = logging.getLogger(name=__name__)
class SymmetricDiffeomorphicRegistration(DiffeomorphicRegistration):
    """Symmetric diffeomorphic registration of a moving image onto a static one.

    Two displacement-field maps are optimised jointly: ``static_to_ref`` warps
    the static image and ``moving_to_ref`` warps the moving image into a
    common reference grid (the static image's grid at the current pyramid
    level). Each iteration applies one normalised metric step to each map,
    coarse-to-fine over a Gaussian pyramid; at the end the two maps are
    composed into a single static-to-moving map.

    Convergence criteria, checked at every pyramid level (``_converged``):

    * ``level_iters[levels - 1 - level]`` iterations were run at this level
      (``level_iters[0]`` is the coarsest level);
    * ``max_iterations`` iterations were run in total over all levels;
    * the derivative of the normalised energy profile, fitted over the last
      ``energy_window`` iterations, is no longer greater than ``opt_tol``.

    The metric must provide ``set_static_image``, ``set_moving_image``,
    ``use_static_image_dynamics``, ``use_moving_image_dynamics``,
    ``initialize_iteration``, ``compute_forward``, ``compute_backward``,
    ``get_energy`` and ``free_iteration``; its ``dim`` may be ``None``, in
    which case the dimension is taken from the static image.

    NOTE(review): ``DiffeomorphicMap`` is not defined in this part of the
    file. This class assumes it exposes assignable ``forward``/``backward``
    displacement arrays plus ``transform_inverse``, ``inverse``,
    ``expand_fields``, ``warp_endomorphism`` and ``compute_inversion_error``
    — confirm against its definition.
    """

    def __init__(self, metric, max_iterations=1000, level_iters=None,
                 step_length=0.25, ss_sigma_factor=0.2, opt_tol=1e-5,
                 inv_iter=20, inv_tol=1e-3, callback=None,
                 verbosity=VerbosityLevels.DEFAULT, energy_window=12):
        """Configure the optimiser.

        Parameters
        ----------
        metric : object
            Similarity metric (see class docstring). Must not be ``None``.
        max_iterations : int
            Global iteration budget summed over all pyramid levels.
        level_iters : sequence of int, optional
            Iterations per level, coarsest first; defaults to
            ``[100, 100, 25]``. Its length is the number of pyramid levels.
        step_length : float
            Scaling applied to each normalised step when updating a map.
        ss_sigma_factor : float
            Stored for API compatibility; not used by this class.
        opt_tol : float
            Energy-derivative tolerance used as a convergence criterion.
        inv_iter, inv_tol : int, float
            Maximum iterations / tolerance (voxels) of the fixed-point
            vector-field inversion.
        callback : callable, optional
            Called as ``callback(self, RegistrationStages.<stage>)``.
        verbosity : VerbosityLevels
            Logging verbosity.
        energy_window : int
            Number of recent energies fitted to estimate the derivative;
            at least 3 (a quadratic is fitted).

        Raises
        ------
        ValueError
            If ``metric`` is None, ``level_iters`` is empty or
            ``energy_window`` < 3.
        """
        if metric is None:
            raise ValueError('The metric cannot be None')
        if level_iters is None:
            level_iters = [100, 100, 25]
        if len(level_iters) == 0:
            raise ValueError('The iterations list cannot be empty')
        if energy_window < 3:
            raise ValueError('energy_window must be at least 3 to fit a quadratic')
        # Called without arguments so it works whether or not the base class
        # defines ``__init__(metric)``; the metric is assigned right below.
        super().__init__()
        self.metric = metric
        self.dim = getattr(metric, 'dim', None)
        self.level_iters = list(level_iters)
        # Previously hard-coded to 0, which made ``_optimize`` skip every level.
        self.levels = len(self.level_iters)
        self.max_iterations = max_iterations
        self.step_length = step_length
        self.ss_sigma_factor = ss_sigma_factor
        self.opt_tol = opt_tol
        self.inv_iter = inv_iter
        self.inv_tol = inv_tol
        self.callback = callback
        self.verbosity = verbosity
        self.energy_window = energy_window
        self.energy_list = []
        self.full_energy_profile = []
        self.niter = 0        # iterations at the current level
        self.total_iter = 0   # iterations over all levels
        self.current_level = 0
        self.static_to_ref = None
        self.moving_to_ref = None
        self.static_ss = None
        self.moving_ss = None
        self.static_direction = None
        self.moving_direction = None

    def optimize(self, static, moving, static_grid2world=None,
                 moving_grid2world=None, prealign=None):
        """Register ``moving`` onto ``static`` and return the composed map.

        Parameters
        ----------
        static, moving : ndarray
            Images with the same number of dimensions; cast to float32.
        static_grid2world, moving_grid2world : ndarray or None
            Homogeneous grid-to-world affines (``None`` = identity).
        prealign : ndarray or None
            Stored on the instance but not applied (a warning is logged).

        Returns
        -------
        The static-to-moving map produced by ``_optimize``.
        """
        if self.verbosity == VerbosityLevels.DEBUG and prealign is not None:
            logger.info("Pre-align: %s", prealign)
        self._init_optimizer(static.astype(np.float32), moving.astype(np.float32),
                             static_grid2world, moving_grid2world, prealign)
        # Multi-resolution loop with the convergence criteria (replaces the
        # single-level ``while not self._converged()`` loop).
        self._optimize()
        self._end_optimizer()
        # Hand back plain ndarrays whatever the update steps produced.
        self.static_to_ref.forward = np.asarray(self.static_to_ref.forward)
        self.static_to_ref.backward = np.asarray(self.static_to_ref.backward)
        return self.static_to_ref

    def get_map(self):
        """Return the optimised static-to-moving map.

        Raises
        ------
        ValueError
            If ``optimize`` has not been run yet.
        """
        if self.static_to_ref is None:
            raise ValueError("Diffeormorphic map cannot be obtained without running the optimizer.")
        return self.static_to_ref

    def _verbose(self, level):
        """Return True when messages tagged with ``level`` should be logged.

        Compares ``.value`` so it works whether ``VerbosityLevels`` is a plain
        ``Enum`` (not orderable) or an ``IntEnum``.
        """
        current = getattr(self.verbosity, 'value', self.verbosity)
        return current >= level.value

    def _converged(self, derivative=np.inf):
        """Return True when iterating at the current pyramid level should stop.

        Parameters
        ----------
        derivative : float, optional
            Latest energy-profile derivative returned by ``_iterate``;
            ``np.inf`` means "not available yet".
        """
        level_budget = self.level_iters[self.levels - 1 - self.current_level]
        if self.niter >= level_budget or self.total_iter >= self.max_iterations:
            return True
        # Written as "not (tol < derivative)" so a NaN energy also stops.
        return not (self.opt_tol < derivative)

    def _direction_and_spacing(self, grid2world):
        """Split a grid-to-world affine into (direction matrix, voxel spacing).

        ``None`` yields an identity direction and unit spacing.
        """
        if grid2world is None:
            return np.eye(self.dim), np.ones(self.dim)
        linear = np.asarray(grid2world, dtype=float)[:self.dim, :self.dim]
        spacing = np.sqrt(np.sum(linear * linear, axis=0))
        # Dividing each column by its norm leaves only the orientation.
        return linear / spacing, spacing

    def _init_optimizer(self, static, moving, static_grid2world,
                        moving_grid2world, prealign):
        """Build the pyramids, reset counters and create both maps.

        Raises
        ------
        ValueError
            If the images' dimensions disagree with each other or the metric.
        """
        if self.dim is None:
            self.dim = static.ndim
        if static.ndim != self.dim or moving.ndim != self.dim:
            raise ValueError('Dimension mismatch between the metric and the images')
        self.static_image = static
        self.moving_image = moving
        self.static_grid2world = static_grid2world
        self.moving_grid2world = moving_grid2world
        self.prealign = prealign
        if prealign is not None:
            logger.warning('prealign is stored but not applied by this implementation')
        self.static_direction, _ = self._direction_and_spacing(static_grid2world)
        self.moving_direction, _ = self._direction_and_spacing(moving_grid2world)

        self.static_ss = MultiscaleImage(static, static_grid2world)
        self.moving_ss = MultiscaleImage(moving, moving_grid2world)
        # The pyramids must have exactly one level per entry of level_iters.
        self.static_ss.levels = self.levels
        self.moving_ss.levels = self.levels

        self.current_level = self.levels - 1
        self.niter = 0
        self.total_iter = 0
        self.energy_list = []
        self.full_energy_profile = []

        # Both maps are updated with steps computed on the static grid (see
        # ``_iterate``), so both are created on that grid.
        coarse_grid = self.static_ss.get_domain_grid(self.current_level)
        coarse_affine = self.static_ss.get_affine(self.current_level)
        self.static_to_ref = DiffeomorphicMap(coarse_grid, coarse_affine)
        self.moving_to_ref = DiffeomorphicMap(coarse_grid, coarse_affine)
        coarse_shape = tuple(self.static_ss.get_domain_shape(self.current_level))
        self._ensure_fields(self.static_to_ref, coarse_shape)
        self._ensure_fields(self.moving_to_ref, coarse_shape)

    def _ensure_fields(self, mapping, shape):
        """Give ``mapping`` zero forward/backward fields if it has none."""
        field_shape = tuple(shape) + (self.dim,)
        if getattr(mapping, 'forward', None) is None:
            mapping.forward = np.zeros(field_shape)
        if getattr(mapping, 'backward', None) is None:
            mapping.backward = np.zeros(field_shape)

    def _end_optimizer(self):
        """Release the pyramids once optimisation is finished."""
        self.moving_ss = None
        self.static_ss = None

    def _finalize_optimizer(self):
        """Hook name used by the base-class template; same as ``_end_optimizer``."""
        self._end_optimizer()

    def _iterate(self):
        """Run one symmetric iteration at ``self.current_level``.

        Warps both images into the reference grid, applies one normalised
        forward step to ``static_to_ref`` and one backward step to
        ``moving_to_ref``, records the summed energy and re-inverts the maps.
        Iteration/level bookkeeping is done by ``_optimize``.

        Returns
        -------
        float
            Derivative of the normalised energy profile over the last
            ``energy_window`` energies, or ``np.inf`` if fewer are recorded.
        """
        level = self.current_level
        current_static = self.static_ss.get_image(level)
        current_moving = self.moving_ss.get_image(level)
        # Displacement fields live on the static image's grid at this level.
        disp_shape = tuple(self.static_ss.get_domain_shape(level))
        disp_grid2world = self.static_ss.get_affine(level)
        disp_world2grid = self.static_ss.get_affine_inv(level)
        disp_spacing = np.asarray(self.static_ss.get_spacing(level), dtype=float)

        # Same call form for both maps (the original mixed two signatures).
        wstatic = self.static_to_ref.transform_inverse(
            current_static, 'linear', None, disp_shape, disp_grid2world)
        wmoving = self.moving_to_ref.transform_inverse(
            current_moving, 'linear', None, disp_shape, disp_grid2world)

        # Both warped images are on the static grid, hence static_direction.
        self.metric.set_moving_image(wmoving, disp_grid2world, disp_spacing,
                                     self.static_direction)
        self.metric.use_moving_image_dynamics(current_moving, self.moving_to_ref.inverse())
        self.metric.set_static_image(wstatic, disp_grid2world, disp_spacing,
                                     self.static_direction)
        self.metric.use_static_image_dynamics(current_static, self.static_to_ref.inverse())
        self.metric.initialize_iteration()
        if self.callback is not None:
            self.callback(self, RegistrationStages.ITER_START)

        fw_step = self._normalized_step(self.metric.compute_forward(), disp_spacing)
        self.static_to_ref.forward, _ = self.update(
            self.static_to_ref.forward, fw_step, disp_world2grid, self.step_length)
        del fw_step
        fw_energy = self.metric.get_energy()  # read before compute_backward

        bw_step = self._normalized_step(self.metric.compute_backward(), disp_spacing)
        self.moving_to_ref.forward, _ = self.update(
            self.moving_to_ref.forward, bw_step, disp_world2grid, self.step_length)
        del bw_step
        bw_energy = self.metric.get_energy()

        # The derivative uses the energies *before* this iteration's value.
        derivative = np.inf
        if len(self.energy_list) >= self.energy_window:
            derivative = self._get_energy_derivative()
        self.energy_list.append(fw_energy + bw_energy)

        self.__invert_models(disp_world2grid, disp_spacing)
        if self.callback is not None:
            self.callback(self, RegistrationStages.ITER_END)
        self.metric.free_iteration()
        return derivative

    def _normalized_step(self, step, spacing):
        """Zero ``step`` on the boundary and scale its largest norm to 1 voxel."""
        step = self.__set_no_boundary_displacement(np.array(step, dtype=float))
        norm = np.sqrt(np.sum((step / spacing) ** 2, -1)).max()
        if norm > 0:
            step /= norm
        return step

    def __set_no_boundary_displacement(self, step):
        """Zero the displacement on the first and last slice of every spatial axis.

        ``step`` has shape ``grid_shape + (dim,)``; it is modified in place
        and returned. Works for any number of spatial dimensions.
        """
        for axis in range(step.ndim - 1):
            index = [slice(None)] * step.ndim
            index[axis] = 0
            step[tuple(index)] = 0
            index[axis] = -1
            step[tuple(index)] = 0
        return step

    def __invert_models(self, current_disp_world2grid, current_disp_spacing):
        """Re-invert both maps, warm-starting each inversion from the old field."""
        self.static_to_ref.backward = np.array(self.invert_vector_field(
            self.static_to_ref.forward, current_disp_world2grid, current_disp_spacing,
            self.inv_iter, self.inv_tol, self.static_to_ref.backward))
        self.moving_to_ref.backward = np.array(self.invert_vector_field(
            self.moving_to_ref.forward, current_disp_world2grid, current_disp_spacing,
            self.inv_iter, self.inv_tol, self.moving_to_ref.backward))
        self.static_to_ref.forward = np.array(self.invert_vector_field(
            self.static_to_ref.backward, current_disp_world2grid, current_disp_spacing,
            self.inv_iter, self.inv_tol, self.static_to_ref.forward))
        self.moving_to_ref.forward = np.array(self.invert_vector_field(
            self.moving_to_ref.backward, current_disp_world2grid, current_disp_spacing,
            self.inv_iter, self.inv_tol, self.moving_to_ref.forward))

    @staticmethod
    def _displaced_coords(displacement, world2grid):
        """Grid coordinates ``x + W d(x)`` for every grid point ``x``.

        ``displacement`` has shape ``grid_shape + (dim,)`` in world units;
        ``W`` is the linear part of ``world2grid`` (identity when None).
        Returns an array of shape ``(dim,) + grid_shape``.
        """
        dim = displacement.shape[-1]
        if world2grid is not None:
            linear = np.asarray(world2grid, dtype=float)[:dim, :dim]
            displacement = displacement @ linear.T
        grid = np.indices(displacement.shape[:-1], dtype=float)
        return grid + np.moveaxis(displacement, -1, 0)

    @staticmethod
    def _sample_field(field, coords):
        """Linearly interpolate each component of ``field`` at ``coords``.

        Points outside the grid read a zero displacement.
        """
        from scipy.ndimage import map_coordinates
        return np.stack(
            [map_coordinates(field[..., k], coords, order=1, mode='constant', cval=0.0)
             for k in range(field.shape[-1])],
            axis=-1)

    def update(self, current_displacement, new_displacement,
               disp_world2grid, time_scaling):
        """Compose the current field with a scaled new step.

        Computes ``C(x) = d_cur(x) + t * d_new(x + d_cur(x))``, i.e. the new
        displacement is evaluated where the current one sends ``x``.

        Returns
        -------
        (ndarray, float)
            The updated field and the mean norm of ``new_displacement``.
        """
        current = np.asarray(current_displacement, dtype=float)
        new = np.asarray(new_displacement, dtype=float)
        coords = self._displaced_coords(current, disp_world2grid)
        updated = current + time_scaling * self._sample_field(new, coords)
        mean_norm = float(np.sqrt(np.sum(new ** 2, -1)).mean())
        return updated, mean_norm

    def invert_vector_field(self, displacement, d_world2grid, spacing,
                            max_iter, tolerance, start=None):
        """Invert a displacement field by fixed-point iteration.

        Solves ``inv(x) = -d(x + inv(x))``, starting from ``start`` (zeros
        when None), until the largest per-voxel change, measured in voxels,
        drops below ``tolerance`` or ``max_iter`` iterations have run.
        """
        field = np.asarray(displacement, dtype=float)
        inverse = np.zeros_like(field) if start is None else np.array(start, dtype=float)
        if spacing is None:
            spacing = np.ones(field.shape[-1])
        spacing = np.asarray(spacing, dtype=float)
        for _ in range(int(max_iter)):
            candidate = -self._sample_field(field, self._displaced_coords(inverse, d_world2grid))
            change = np.sqrt(np.sum(((candidate - inverse) / spacing) ** 2, -1)).max()
            inverse = candidate
            if change < tolerance:
                break
        return inverse

    def _approximate_derivative_direct(self, x, y):
        """Derivative at the window midpoint of a least-squares quadratic fit.

        Fits ``y ~ a x^2 + b x + c`` and returns ``2 a x0 + b`` with
        ``x0 = len(x) / 2``.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        # np.vstack replaces np.row_stack, deprecated in NumPy 2.0.
        X = np.vstack((x ** 2, x, np.ones_like(x)))
        XX = X.dot(X.T)
        b = X.dot(y)
        beta = npl.solve(XX, b)
        x0 = 0.5 * len(x)
        return 2.0 * beta[0] * x0 + beta[1]

    def _get_energy_derivative(self):
        """Derivative of the normalised energy over the last ``energy_window`` values.

        Raises
        ------
        ValueError
            If fewer than ``energy_window`` energies have been recorded.
        """
        n_iter = len(self.energy_list)
        if n_iter < self.energy_window:
            raise ValueError('Not enough data to fit the energy profile')
        x = range(self.energy_window)
        y = self.energy_list[(n_iter - self.energy_window):n_iter]
        ss = sum(y)
        if ss != 0:  # avoid division by zero
            # Normalise by -|sum| so the profile scale does not affect opt_tol.
            ss = -ss if ss > 0 else ss
            y = [v / ss for v in y]
        return self._approximate_derivative_direct(x, y)

    def _log_inversion_error(self, mapping, message):
        """Log mean and std of ``mapping``'s inversion error at DIAGNOSE verbosity."""
        residual, stats = mapping.compute_inversion_error()
        if self._verbose(VerbosityLevels.DIAGNOSE):
            logger.info(message, stats[1], stats[2])

    def _optimize(self):
        """Coarse-to-fine optimisation, then composition of the two maps."""
        self.full_energy_profile = []
        if self.callback is not None:
            self.callback(self, RegistrationStages.OPT_START)
        for level in range(self.levels - 1, -1, -1):
            if self._verbose(VerbosityLevels.STATUS):
                logger.info('Optimizing level %d', level)
            self.current_level = level
            if level < self.levels - 1:
                # Upsample both maps from the previous (coarser) level.
                expand_factors = self.static_ss.get_expand_factors(level + 1, level)
                new_shape = self.static_ss.get_domain_shape(level)
                self.static_to_ref.expand_fields(expand_factors, new_shape)
                self.moving_to_ref.expand_fields(expand_factors, new_shape)
            self.niter = 0
            self.energy_list = []
            derivative = np.inf
            if self.callback is not None:
                self.callback(self, RegistrationStages.SCALE_START)
            while not self._converged(derivative):
                derivative = self._iterate()
                self.niter += 1
                self.total_iter += 1
            self.full_energy_profile.extend(self.energy_list)
            if self.callback is not None:
                self.callback(self, RegistrationStages.SCALE_END)
        self._log_inversion_error(self.static_to_ref,
                                  'Static-Reference Residual error: %0.6f (%0.6f)')
        self._log_inversion_error(self.moving_to_ref,
                                  'Moving-Reference Residual error :%0.6f (%0.6f)')
        # static -> ref followed by ref -> moving gives static -> moving.
        self.static_to_ref = self.moving_to_ref.warp_endomorphism(
            self.static_to_ref.inverse()).inverse()
        self._log_inversion_error(self.static_to_ref,
                                  'Final residual error: %0.6f (%0.6f)')
        if self.callback is not None:
            self.callback(self, RegistrationStages.OPT_END)
class SimpleSimilarityMetric:
    """Sum-of-squared-differences (SSD) metric with a demons-style step.

    Besides ``evaluate`` (the SSD of the two images), it implements the
    per-iteration interface that ``SymmetricDiffeomorphicRegistration._iterate``
    calls:
    - image setters that also accept grid2world, spacing and direction;
    - ``use_*_image_dynamics``;
    - ``initialize_iteration``;
    - ``compute_forward``/``compute_backward``;
    - ``get_energy`` and ``free_iteration``.

    Steps have shape ``image.shape + (ndim,)`` and are expressed in the
    units of the given spacing.
    """

    def __init__(self, dim=None, smooth=1.0):
        """
        Parameters
        ----------
        dim : int or None
            Image dimension; inferred from the first image set when None.
        smooth : float
            Gaussian sigma (voxels) applied to each step component; 0 disables.
        """
        self.dim = dim
        self.smooth = smooth
        self.mask0 = None
        self.static_image = None
        self.moving_image = None
        self.static_grid2world = None
        self.moving_grid2world = None
        self.static_spacing = None
        self.moving_spacing = None
        self.static_direction = None
        self.moving_direction = None
        self.gradient_static = None
        self.gradient_moving = None
        self.energy = None

    def set_static_image(self, static_image, static_grid2world=None,
                         static_spacing=None, static_direction=None):
        """Set the static image (extra geometry arguments are optional)."""
        self.static_image = static_image
        self.static_grid2world = static_grid2world
        self.static_spacing = static_spacing
        self.static_direction = static_direction
        if self.dim is None:
            self.dim = np.ndim(static_image)

    def set_moving_image(self, moving_image, moving_grid2world=None,
                         moving_spacing=None, moving_direction=None):
        """Set the moving image (extra geometry arguments are optional)."""
        self.moving_image = moving_image
        self.moving_grid2world = moving_grid2world
        self.moving_spacing = moving_spacing
        self.moving_direction = moving_direction
        if self.dim is None:
            self.dim = np.ndim(moving_image)

    def set_mask0(self, mask0):
        """Store an optional mask (kept for API compatibility; not used here)."""
        self.mask0 = mask0

    def use_static_image_dynamics(self, original_static_image, transformation):
        """No-op: SSD needs only the warped static image."""

    def use_moving_image_dynamics(self, original_moving_image, transformation):
        """No-op: SSD needs only the warped moving image."""

    def initialize_iteration(self):
        """Precompute the image gradients used by the forward/backward steps."""
        static, moving = self._images()
        self.gradient_static = self._gradient(static, self.static_spacing)
        self.gradient_moving = self._gradient(moving, self.moving_spacing)

    def compute_forward(self):
        """Step that moves the static image towards the moving image."""
        return self._demons_step(forward=True)

    def compute_backward(self):
        """Step that moves the moving image towards the static image."""
        return self._demons_step(forward=False)

    def get_energy(self):
        """SSD energy computed by the most recent forward/backward step."""
        return self.energy

    def free_iteration(self):
        """Drop per-iteration buffers."""
        self.gradient_static = None
        self.gradient_moving = None

    def evaluate(self):
        """Return the sum of squared differences between the two images.

        Raises
        ------
        ValueError
            If either image has not been set.
        """
        static, moving = self._images()
        return np.sum((static - moving) ** 2)

    def _images(self):
        """Both images as float arrays; raises ValueError if one is missing."""
        if self.static_image is None or self.moving_image is None:
            raise ValueError('Both the static and the moving image must be set first')
        return (np.asarray(self.static_image, dtype=float),
                np.asarray(self.moving_image, dtype=float))

    @staticmethod
    def _spacing(spacing, ndim):
        """Spacing as a length-``ndim`` float array (ones when None)."""
        if spacing is None:
            return np.ones(ndim)
        return np.broadcast_to(np.asarray(spacing, dtype=float), (ndim,))

    @classmethod
    def _gradient(cls, image, spacing):
        """Image gradient stacked on the last axis: ``image.shape + (ndim,)``.

        Every axis needs at least 2 samples (``np.gradient`` requirement).
        """
        components = np.gradient(image, *cls._spacing(spacing, image.ndim))
        if image.ndim == 1:
            components = [components]  # np.gradient returns a bare array in 1-D
        return np.stack(components, axis=-1)

    def _demons_step(self, forward):
        """Demons step ``-delta * g / (|g|^2 + delta^2 / sigma^2)``.

        ``delta`` is the image difference (static - moving for the forward
        step, moving - static for the backward step) and ``g`` is the
        gradient of the image being moved. Also records the SSD energy.
        """
        static, moving = self._images()
        if self.gradient_static is None or self.gradient_moving is None:
            self.initialize_iteration()
        if forward:
            delta, gradient, spacing = static - moving, self.gradient_static, self.static_spacing
        else:
            delta, gradient, spacing = moving - static, self.gradient_moving, self.moving_spacing
        spacing = self._spacing(spacing, delta.ndim)
        sigma_sq = np.sum(spacing ** 2) / delta.ndim
        denominator = np.sum(gradient ** 2, axis=-1) + delta ** 2 / sigma_sq
        # Flat, identical regions (denominator ~ 0) get no displacement.
        scale = np.divide(-delta, denominator, out=np.zeros_like(delta),
                          where=denominator > 1e-9)
        step = scale[..., np.newaxis] * gradient
        self.energy = float(np.sum(delta ** 2))
        if self.smooth:
            from scipy.ndimage import gaussian_filter
            step = np.stack([gaussian_filter(step[..., k], self.smooth)
                             for k in range(step.shape[-1])], axis=-1)
        return step
class MultiscaleImage:
    """Gaussian pyramid of an image together with each level's geometry.

    Level 0 is the original image. Level ``k`` is obtained by smoothing level
    ``k - 1`` (Gaussian, sigma 1 voxel) and keeping every second voxel along
    each axis, so its voxels are ``2 ** k`` times larger. Shape, grid, affine
    and spacing are all derived from that same pyramid, so they agree with
    the images returned by ``get_image``.
    """

    def __init__(self, image, grid2world):
        """
        Parameters
        ----------
        image : ndarray
            Full-resolution image (level 0).
        grid2world : ndarray or None
            Homogeneous grid-to-world affine of the full-resolution image,
            assumed ``(ndim + 1, ndim + 1)``; ``None`` means identity.
        """
        self.image = image
        self.grid2world = grid2world
        self.levels = 3  # number of pyramid levels (indices 0 .. levels - 1)
        self._pyramid = None
        self._pyramid_source = None

    def _check_level(self, level):
        """Raise ValueError unless ``0 <= level < self.levels``."""
        if level < 0 or level >= self.levels:
            raise ValueError("Invalid level specified.")

    def _base_grid2world(self):
        """Full-resolution grid-to-world affine (identity when not given)."""
        if self.grid2world is None:
            return np.eye(np.ndim(self.image) + 1)
        return np.asarray(self.grid2world, dtype=float)

    def get_domain_grid(self, level):
        """Index grid of ``level``: array of shape ``(ndim,) + domain_shape``."""
        shape = self.get_domain_shape(level)
        return np.mgrid[tuple(slice(0, n) for n in shape)]

    def get_affine(self, level):
        """Grid-to-world affine of ``level``.

        Index ``i`` at level ``k`` is index ``2**k * i`` at level 0, so the
        linear part is scaled by ``2**k`` and the translation is unchanged.
        """
        self._check_level(level)
        affine = self._base_grid2world().copy()
        dim = affine.shape[0] - 1
        affine[:dim, :dim] *= 2 ** level
        return affine

    def get_affine_inv(self, level):
        """World-to-grid affine of ``level`` (inverse of ``get_affine``)."""
        return np.linalg.inv(self.get_affine(level))

    def get_domain_shape(self, level):
        """Shape of the image at ``level``."""
        return self.get_image(level).shape

    def get_spacing(self, level):
        """Voxel size at ``level`` as an ndarray (one entry per axis).

        Raises
        ------
        ValueError
            If ``level`` is out of range.
        """
        self._check_level(level)
        base = self._base_grid2world()
        dim = base.shape[0] - 1
        linear = base[:dim, :dim]
        return np.sqrt(np.sum(linear * linear, axis=0)) * 2 ** level

    def get_expand_factors(self, from_level, to_level):
        """Ratio of voxel sizes ``spacing[to_level] / spacing[from_level]``.

        E.g. 0.5 per axis when expanding from one level to the next finer one.
        """
        return self.get_spacing(to_level) / self.get_spacing(from_level)

    def get_image(self, level):
        """Image at ``level`` (level 0 is the original image object).

        The pyramid is built once and cached; it is rebuilt if ``image`` is
        replaced or ``levels`` is increased.

        Raises
        ------
        ValueError
            If ``level`` is out of range.
        """
        self._check_level(level)
        pyramid = getattr(self, '_pyramid', None)
        if (pyramid is None or len(pyramid) < self.levels
                or getattr(self, '_pyramid_source', None) is not self.image):
            pyramid = self.gaussian_pyramid(levels=self.levels - 1,
                                            sigmas=[1.0] * self.levels,
                                            down_facts=[2] * self.levels)
            self._pyramid = pyramid
            self._pyramid_source = self.image
        return pyramid[level]

    def gaussian_pyramid(self, levels, sigmas, down_facts):
        """Return ``[image, level1, ..., level_levels]`` (``levels + 1`` images).

        Each new level smooths the previous one with ``sigmas[i]`` and keeps
        every ``down_facts[i]``-th voxel along each axis.
        """
        from scipy.ndimage import gaussian_filter
        pyramid = [self.image]
        for i in range(levels):
            smoothed = gaussian_filter(pyramid[-1], sigma=sigmas[i])
            stride = int(down_facts[i])
            # Strided slicing replaces the undefined ``downsample`` helper.
            pyramid.append(smoothed[(slice(None, None, stride),) * smoothed.ndim])
        return pyramid
import numpy as np
from scipy.ndimage import map_coordinates
import numpy as np
from scipy.ndimage import gaussian_filter, zoom