summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-10-11 16:31:49 +0100
committerGitHub <noreply@github.com>2019-10-11 16:31:49 +0100
commitc7213fdfcd31e6ec780aab4afe1bd34374d784f5 (patch)
treeaa50d68cd0ab5dc96f858cb840a1acf4124cc369
parentb4e242471dd96d3af12d0c4c1d94a60be08dadcc (diff)
downloadframework-c7213fdfcd31e6ec780aab4afe1bd34374d784f5.tar.gz
framework-c7213fdfcd31e6ec780aab4afe1bd34374d784f5.tar.bz2
framework-c7213fdfcd31e6ec780aab4afe1bd34374d784f5.tar.xz
framework-c7213fdfcd31e6ec780aab4afe1bd34374d784f5.zip
Pass kwargs to algorithm (#380)
* add test for algorithm * fix conflict * suppress warning * pass kwargs to Algorithm class creator
-rwxr-xr-xWrappers/Python/ccpi/framework/framework.py8
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/CGLS.py30
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/FISTA.py27
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py30
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py29
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py27
-rwxr-xr-xWrappers/Python/test/test_algorithms.py47
7 files changed, 145 insertions, 53 deletions
diff --git a/Wrappers/Python/ccpi/framework/framework.py b/Wrappers/Python/ccpi/framework/framework.py
index 6d5bd1b..ed97862 100755
--- a/Wrappers/Python/ccpi/framework/framework.py
+++ b/Wrappers/Python/ccpi/framework/framework.py
@@ -722,7 +722,8 @@ class DataContainer(object):
return type(self)(out,
deep_copy=False,
dimension_labels=self.dimension_labels,
- geometry=self.geometry)
+ geometry=self.geometry,
+ suppress_warning=True)
elif issubclass(type(out), DataContainer) and issubclass(type(x2), DataContainer):
@@ -800,7 +801,8 @@ class DataContainer(object):
return type(self)(out,
deep_copy=False,
dimension_labels=self.dimension_labels,
- geometry=self.geometry)
+ geometry=self.geometry,
+ suppress_warning=True)
elif issubclass(type(out), DataContainer):
if self.check_dimensions(out):
kwargs['out'] = out.as_array()
@@ -885,7 +887,7 @@ class ImageData(DataContainer):
if not kwargs.get('suppress_warning', False):
warnings.warn('Direct invocation is deprecated and will be removed in following version. Use allocate from ImageGeometry instead',
- DeprecationWarning)
+ DeprecationWarning, stacklevel=4)
self.geometry = kwargs.get('geometry', None)
if array is None:
if self.geometry is not None:
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
index d2e5b29..57292df 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
@@ -47,20 +47,30 @@ class CGLS(Algorithm):
Reference:
https://web.stanford.edu/group/SOL/software/cgls/
'''
- def __init__(self, **kwargs):
+ def __init__(self, x_init=None, operator=None, data=None, tolerance=1e-6, **kwargs):
+ '''initialisation of the algorithm
+
+ :param operator : Linear operator for the inverse problem
+ :param x_init : Initial guess ( Default x_init = 0)
+ :param data : Acquired data to reconstruct
+ :param tolerance: Tolerance/ Stopping Criterion to end CGLS algorithm
+ '''
+ super(CGLS, self).__init__(**kwargs)
- super(CGLS, self).__init__()
- x_init = kwargs.get('x_init', None)
- operator = kwargs.get('operator', None)
- data = kwargs.get('data', None)
- tolerance = kwargs.get('tolerance', 1e-6)
if x_init is not None and operator is not None and data is not None:
- print(self.__class__.__name__ , "set_up called from creator")
self.set_up(x_init=x_init, operator=operator, data=data, tolerance=tolerance)
def set_up(self, x_init, operator, data, tolerance=1e-6):
-
+ '''initialisation of the algorithm
+
+ :param operator : Linear operator for the inverse problem
+ :param x_init : Initial guess ( Default x_init = 0)
+ :param data : Acquired data to reconstruct
+ :param tolerance: Tolerance/ Stopping Criterion to end CGLS algorithm
+ '''
+ print("{} setting up".format(self.__class__.__name__, ))
+
self.x = x_init * 0.
self.operator = operator
self.tolerance = tolerance
@@ -78,7 +88,9 @@ class CGLS(Algorithm):
self.xmax = self.normx
self.loss.append(self.r.squared_norm())
- self.configured = True
+ self.configured = True
+ print("{} configured".format(self.__class__.__name__, ))
+
def update(self):
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
index 5d79b67..8c485b7 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
@@ -53,24 +53,31 @@ class FISTA(Algorithm):
'''
- def __init__(self, **kwargs):
+ def __init__(self, x_init=None, f=None, g=ZeroFunction(), **kwargs):
- '''creator
+ '''FISTA algorithm creator
initialisation can be done at creation time if all
- proper variables are passed or later with set_up'''
+ proper variables are passed or later with set_up
+
+ :param x_init : Initial guess ( Default x_init = 0)
+ :param f : Differentiable function
+ :param g : Convex function with " simple " proximal operator'''
+
+ super(FISTA, self).__init__(**kwargs)
- super(FISTA, self).__init__()
- f = kwargs.get('f', None)
- g = kwargs.get('g', ZeroFunction())
- x_init = kwargs.get('x_init', None)
-
if x_init is not None and f is not None:
- print(self.__class__.__name__ , "set_up called from creator")
self.set_up(x_init=x_init, f=f, g=g)
def set_up(self, x_init, f, g=ZeroFunction()):
+ '''initialisation of the algorithm
+ :param x_init : Initial guess ( Default x_init = 0)
+ :param f : Differentiable function
+ :param g : Convex function with " simple " proximal operator'''
+
+ print("{} setting up".format(self.__class__.__name__, ))
+
self.y = x_init.copy()
self.x_old = x_init.copy()
self.x = x_init.copy()
@@ -84,6 +91,8 @@ class FISTA(Algorithm):
self.t = 1
self.update_objective()
self.configured = True
+ print("{} configured".format(self.__class__.__name__, ))
+
def update(self):
self.t_old = self.t
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
index f79651a..8f9c958 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
@@ -35,17 +35,20 @@ class GradientDescent(Algorithm):
'''
- def __init__(self, **kwargs):
- '''initialisation can be done at creation time if all
- proper variables are passed or later with set_up'''
- super(GradientDescent, self).__init__()
-
- x_init = kwargs.get('x_init', None)
- objective_function = kwargs.get('objective_function', None)
- rate = kwargs.get('rate', None)
+ def __init__(self, x_init=None, objective_function=None, rate=None, **kwargs):
+ '''GradientDescent algorithm creator
+
+ initialisation can be done at creation time if all
+ proper variables are passed or later with set_up
+
+ :param x_init: initial guess
+ :param objective_function: objective function to be minimised
+ :param rate: step rate
+ '''
+ super(GradientDescent, self).__init__(**kwargs)
+
if x_init is not None and objective_function is not None and rate is not None:
- print(self.__class__.__name__, "set_up called from creator")
self.set_up(x_init=x_init, objective_function=objective_function, rate=rate)
def should_stop(self):
@@ -53,7 +56,13 @@ class GradientDescent(Algorithm):
return self.iteration >= self.max_iteration
def set_up(self, x_init, objective_function, rate):
- '''initialisation of the algorithm'''
+ '''initialisation of the algorithm
+
+ :param x_init: initial guess
+ :param objective_function: objective function to be minimised
+ :param rate: step rate'''
+ print("{} setting up".format(self.__class__.__name__, ))
+
self.x = x_init.copy()
self.objective_function = objective_function
self.rate = rate
@@ -69,6 +78,7 @@ class GradientDescent(Algorithm):
self.x_update = x_init.copy()
self.configured = True
+ print("{} configured".format(self.__class__.__name__, ))
def update(self):
'''Single iteration'''
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
index 7bc4e11..7ed82b2 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
@@ -61,20 +61,31 @@ class PDHG(Algorithm):
SIAM J. Imaging Sci. 3, 1015–1046.
'''
- def __init__(self, **kwargs):
- super(PDHG, self).__init__(max_iteration=kwargs.get('max_iteration',0))
- f = kwargs.get('f', None)
- operator = kwargs.get('operator', None)
- g = kwargs.get('g', None)
- tau = kwargs.get('tau', None)
- sigma = kwargs.get('sigma', 1.)
+ def __init__(self, f=None, g=None, operator=None, tau=None, sigma=1.,**kwargs):
+ '''PDHG algorithm creator
+
+ :param operator : Linear Operator = K
+ :param f : Convex function with "simple" proximal of its conjugate.
+ :param g : Convex function with "simple" proximal
+ :param sigma : Step size parameter for Primal problem
+ :param tau : Step size parameter for Dual problem'''
+ super(PDHG, self).__init__(**kwargs)
+
if f is not None and operator is not None and g is not None:
- print(self.__class__.__name__ , "set_up called from creator")
self.set_up(f=f, g=g, operator=operator, tau=tau, sigma=sigma)
def set_up(self, f, g, operator, tau=None, sigma=1.):
+ '''initialisation of the algorithm
+
+ :param operator : Linear Operator = K
+ :param f : Convex function with "simple" proximal of its conjugate.
+ :param g : Convex function with "simple" proximal
+ :param sigma : Step size parameter for Primal problem
+ :param tau : Step size parameter for Dual problem'''
+ print("{} setting up".format(self.__class__.__name__, ))
+
# can't happen with default sigma
if sigma is None and tau is None:
raise ValueError('Need sigma*tau||K||^2<1')
@@ -108,6 +119,8 @@ class PDHG(Algorithm):
self.theta = 1
self.update_objective()
self.configured = True
+ print("{} configured".format(self.__class__.__name__, ))
+
def update(self):
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py
index 8feef87..50398f4 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py
@@ -47,19 +47,30 @@ class SIRT(Algorithm):
e.g. x\in[0, 1], IndicatorBox to enforce box constraints
Default is None).
'''
- def __init__(self, **kwargs):
- super(SIRT, self).__init__()
+ def __init__(self, x_init=None, operator=None, data=None, constraint=None, **kwargs):
+ '''SIRT algorithm creator
- x_init = kwargs.get('x_init', None)
- operator = kwargs.get('operator', None)
- data = kwargs.get('data', None)
- constraint = kwargs.get('constraint', None)
+ :param x_init : Initial guess
+ :param operator : Linear operator for the inverse problem
+ :param data : Acquired data to reconstruct
+ :param constraint : Function proximal method
+ e.g. x\in[0, 1], IndicatorBox to enforce box constraints
+ Default is None).'''
+ super(SIRT, self).__init__(**kwargs)
if x_init is not None and operator is not None and data is not None:
- print(self.__class__.__name__, "set_up called from creator")
self.set_up(x_init=x_init, operator=operator, data=data, constraint=constraint)
def set_up(self, x_init, operator, data, constraint=None):
+ '''initialisation of the algorithm
+
+ :param operator : Linear operator for the inverse problem
+ :param x_init : Initial guess
+ :param data : Acquired data to reconstruct
+ :param constraint : Function proximal method
+ e.g. x\in[0, 1], IndicatorBox to enforce box constraints
+ Default is None).'''
+ print("{} setting up".format(self.__class__.__name__, ))
self.x = x_init.copy()
self.operator = operator
@@ -75,6 +86,8 @@ class SIRT(Algorithm):
self.D = 1/self.operator.adjoint(self.operator.range_geometry().allocate(value=1.0))
self.update_objective()
self.configured = True
+ print("{} configured".format(self.__class__.__name__, ))
+
def update(self):
diff --git a/Wrappers/Python/test/test_algorithms.py b/Wrappers/Python/test/test_algorithms.py
index 15a83e8..1dd198a 100755
--- a/Wrappers/Python/test/test_algorithms.py
+++ b/Wrappers/Python/test/test_algorithms.py
@@ -75,24 +75,40 @@ class TestAlgorithms(unittest.TestCase):
alg.max_iteration = 20
alg.run(20, verbose=True)
self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array())
+ alg = GradientDescent(x_init=x_init,
+ objective_function=norm2sq,
+ rate=rate, max_iteration=20,
+ update_objective_interval=2)
+ alg.max_iteration = 20
+ self.assertTrue(alg.max_iteration == 20)
+ self.assertTrue(alg.update_objective_interval==2)
+ alg.run(20, verbose=True)
+ self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array())
def test_CGLS(self):
print ("Test CGLS")
#ig = ImageGeometry(124,153,154)
ig = ImageGeometry(10,2)
numpy.random.seed(2)
x_init = ig.allocate(0.)
+ b = ig.allocate('random')
# b = x_init.copy()
# fill with random numbers
# b.fill(numpy.random.random(x_init.shape))
- b = ig.allocate()
- bdata = numpy.reshape(numpy.asarray([i for i in range(20)]), (2,10))
- b.fill(bdata)
+ # b = ig.allocate()
+ # bdata = numpy.reshape(numpy.asarray([i for i in range(20)]), (2,10))
+ # b.fill(bdata)
identity = Identity(ig)
alg = CGLS(x_init=x_init, operator=identity, data=b)
alg.max_iteration = 200
alg.run(20, verbose=True)
- self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array(), decimal=4)
+ self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array())
+
+ alg = CGLS(x_init=x_init, operator=identity, data=b, max_iteration=200, update_objective_interval=2)
+ self.assertTrue(alg.max_iteration == 200)
+ self.assertTrue(alg.update_objective_interval==2)
+ alg.run(20, verbose=True)
+ self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array())
def test_FISTA(self):
print ("Test FISTA")
@@ -114,6 +130,15 @@ class TestAlgorithms(unittest.TestCase):
alg.max_iteration = 2
alg.run(20, verbose=True)
self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array())
+
+ alg = FISTA(x_init=x_init, f=norm2sq, g=ZeroFunction(), max_iteration=2, update_objective_interval=2)
+
+ self.assertTrue(alg.max_iteration == 2)
+ self.assertTrue(alg.update_objective_interval==2)
+
+ alg.run(20, verbose=True)
+ self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array())
+
def test_FISTA_Norm2Sq(self):
print ("Test FISTA Norm2Sq")
@@ -133,6 +158,14 @@ class TestAlgorithms(unittest.TestCase):
alg.max_iteration = 2
alg.run(20, verbose=True)
self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array())
+
+ alg = FISTA(x_init=x_init, f=norm2sq, g=ZeroFunction(), max_iteration=2, update_objective_interval=3)
+ self.assertTrue(alg.max_iteration == 2)
+ self.assertTrue(alg.update_objective_interval== 3)
+
+ alg.run(20, verbose=True)
+ self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array())
+
def test_FISTA_catch_Lipschitz(self):
print ("Test FISTA catch Lipschitz")
ig = ImageGeometry(127,139,149)
@@ -242,9 +275,9 @@ class TestAlgorithms(unittest.TestCase):
tau = 1/(sigma*normK**2)
# Setup and run the PDHG algorithm
- pdhg1 = PDHG(f=f1,g=g,operator=operator, tau=tau, sigma=sigma)
- pdhg1.max_iteration = 2000
- pdhg1.update_objective_interval = 200
+ pdhg1 = PDHG(f=f1,g=g,operator=operator, tau=tau, sigma=sigma,
+ max_iteration=2000, update_objective_interval=200)
+
pdhg1.run(1000)
rmse = (pdhg1.get_output() - data).norm() / data.as_array().size