summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-10-09 12:30:51 +0100
committerGitHub <noreply@github.com>2019-10-09 12:30:51 +0100
commit743d40f9b98c3b15475fafdee26c9290833f3388 (patch)
treecb9bdd2a456c875fd6ee38545b7a2d20074455d5
parent03dbf855b7ec87cf27dfc9f94c4d12eb24faf491 (diff)
downloadframework-743d40f9b98c3b15475fafdee26c9290833f3388.tar.gz
framework-743d40f9b98c3b15475fafdee26c9290833f3388.tar.bz2
framework-743d40f9b98c3b15475fafdee26c9290833f3388.tar.xz
framework-743d40f9b98c3b15475fafdee26c9290833f3388.zip
Finite diff for sirf (#367)
* python2 compatibility import future * add staticmethod dot to test LinearOperators * a little more efficient code * skips all tests if module wget is not present * removed sirf import and simplified code
-rwxr-xr-xWrappers/Python/ccpi/framework/BlockDataContainer.py42
-rwxr-xr-xWrappers/Python/ccpi/framework/TestData.py7
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/CGLS.py5
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/FISTA.py5
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py5
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py5
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py5
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py7
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/Function.py5
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py18
-rwxr-xr-xWrappers/Python/ccpi/optimisation/functions/IndicatorBox.py7
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py5
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/L1Norm.py7
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py5
-rwxr-xr-xWrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py45
-rwxr-xr-xWrappers/Python/ccpi/optimisation/functions/Norm2Sq.py28
-rwxr-xr-xWrappers/Python/ccpi/optimisation/functions/ScaledFunction.py5
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/ZeroFunction.py5
-rwxr-xr-xWrappers/Python/ccpi/optimisation/operators/BlockOperator.py42
-rw-r--r--Wrappers/Python/ccpi/optimisation/operators/BlockScaledOperator.py7
-rw-r--r--Wrappers/Python/ccpi/optimisation/operators/FiniteDifferenceOperator.py25
-rw-r--r--Wrappers/Python/ccpi/optimisation/operators/GradientOperator.py5
-rw-r--r--Wrappers/Python/ccpi/optimisation/operators/IdentityOperator.py7
-rwxr-xr-xWrappers/Python/ccpi/optimisation/operators/LinearOperator.py52
-rw-r--r--Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py6
-rwxr-xr-xWrappers/Python/ccpi/optimisation/operators/Operator.py6
-rw-r--r--Wrappers/Python/ccpi/optimisation/operators/ShrinkageOperator.py7
-rw-r--r--Wrappers/Python/ccpi/optimisation/operators/SparseFiniteDiff.py7
-rw-r--r--Wrappers/Python/ccpi/optimisation/operators/SymmetrizedGradientOperator.py11
-rw-r--r--Wrappers/Python/ccpi/optimisation/operators/ZeroOperator.py7
-rwxr-xr-xWrappers/Python/test/test_NexusReader.py149
31 files changed, 417 insertions, 125 deletions
diff --git a/Wrappers/Python/ccpi/framework/BlockDataContainer.py b/Wrappers/Python/ccpi/framework/BlockDataContainer.py
index b5116e5..8247f24 100755
--- a/Wrappers/Python/ccpi/framework/BlockDataContainer.py
+++ b/Wrappers/Python/ccpi/framework/BlockDataContainer.py
@@ -26,7 +26,8 @@ import functools
from ccpi.framework import DataContainer
#from ccpi.framework import AcquisitionData, ImageData
#from ccpi.optimisation.operators import Operator, LinearOperator
-
+
+
class BlockDataContainer(object):
'''Class to hold DataContainers as column vector
@@ -102,7 +103,10 @@ class BlockDataContainer(object):
raise ValueError('List/ numpy array can only contain numbers {}'\
.format(type(ot)))
return len(self.containers) == len(other)
- elif issubclass(other.__class__, DataContainer):
+ elif isinstance(other, BlockDataContainer):
+ return len(self.containers) == len(other.containers)
+ else:
+ # this should work for other as DataContainers and children
ret = True
for i, el in enumerate(self.containers):
if isinstance(el, BlockDataContainer):
@@ -110,9 +114,9 @@ class BlockDataContainer(object):
else:
a = el.shape == other.shape
ret = ret and a
+ # probably will raise
return ret
- #return self.get_item(0).shape == other.shape
- return len(self.containers) == len(other.containers)
+
def get_item(self, row):
if row > self.shape[0]:
@@ -180,7 +184,7 @@ class BlockDataContainer(object):
if not self.is_compatible(other):
raise ValueError('Incompatible for divide')
out = kwargs.get('out', None)
- if isinstance(other, Number) or issubclass(other.__class__, DataContainer):
+ if isinstance(other, Number):
# try to do algebra with one DataContainer. Will raise error if not compatible
kw = kwargs.copy()
res = []
@@ -240,8 +244,32 @@ class BlockDataContainer(object):
return type(self)(*res, shape=self.shape)
return type(self)(*[ operation(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
else:
- raise ValueError('Incompatible type {}'.format(type(other)))
-
+ # try to do algebra with one DataContainer. Will raise error if not compatible
+ kw = kwargs.copy()
+ res = []
+ for i,el in enumerate(self.containers):
+ if operation == BlockDataContainer.ADD:
+ op = el.add
+ elif operation == BlockDataContainer.SUBTRACT:
+ op = el.subtract
+ elif operation == BlockDataContainer.MULTIPLY:
+ op = el.multiply
+ elif operation == BlockDataContainer.DIVIDE:
+ op = el.divide
+ elif operation == BlockDataContainer.POWER:
+ op = el.power
+ else:
+ raise ValueError('Unsupported operation', operation)
+ if out is not None:
+ kw['out'] = out.get_item(i)
+ op(other, *args, **kw)
+ else:
+ res.append(op(other, *args, **kw))
+ if out is not None:
+ return
+ else:
+ return type(self)(*res, shape=self.shape)
+
def power(self, other, *args, **kwargs):
if not self.is_compatible(other):
diff --git a/Wrappers/Python/ccpi/framework/TestData.py b/Wrappers/Python/ccpi/framework/TestData.py
index 2f4c685..74d37be 100755
--- a/Wrappers/Python/ccpi/framework/TestData.py
+++ b/Wrappers/Python/ccpi/framework/TestData.py
@@ -15,6 +15,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.framework import ImageData, ImageGeometry, DataContainer
import numpy
import numpy as np
@@ -334,4 +339,4 @@ class TestData(object):
if clip:
out = np.clip(out, low_clip, 1.0)
- return out \ No newline at end of file
+ return out
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
index c6c1d4c..d2e5b29 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
@@ -17,6 +17,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.algorithms import Algorithm
import numpy
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
index e23116b..5d79b67 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
@@ -17,6 +17,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.algorithms import Algorithm
from ccpi.optimisation.functions import ZeroFunction
import numpy
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
index d060690..f79651a 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
@@ -20,6 +20,11 @@
#
#=========================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.algorithms import Algorithm
class GradientDescent(Algorithm):
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
index 0968872..7bc4e11 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
@@ -17,6 +17,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.algorithms import Algorithm
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py
index 2b49ab0..8feef87 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py
@@ -20,6 +20,11 @@
#
#=========================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.algorithms import Algorithm
class SIRT(Algorithm):
diff --git a/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py b/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py
index a6ac66c..ee3ad78 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py
@@ -20,6 +20,11 @@
#
#=========================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.functions import Function
from ccpi.framework import BlockDataContainer
from numbers import Number
@@ -232,4 +237,4 @@ if __name__ == '__main__':
- \ No newline at end of file
+
diff --git a/Wrappers/Python/ccpi/optimisation/functions/Function.py b/Wrappers/Python/ccpi/optimisation/functions/Function.py
index 7156995..48c6d30 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/Function.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/Function.py
@@ -20,6 +20,11 @@
#
#=========================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
import warnings
from ccpi.optimisation.functions.ScaledFunction import ScaledFunction
diff --git a/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py b/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py
index 58d4f27..4162134 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py
@@ -39,8 +39,11 @@ class FunctionOperatorComposition(Function):
self.function = function
self.operator = operator
- self.L = function.L * operator.norm()**2
-
+ try:
+ self.L = function.L * operator.norm()**2
+ except Error as er:
+ self.L = None
+ warnings.warn("Lipschitz constant was not calculated")
def __call__(self, x):
@@ -56,12 +59,13 @@ class FunctionOperatorComposition(Function):
'''
+ tmp = self.operator.range_geometry().allocate()
+ self.operator.direct(x, out=tmp)
+ self.function.gradient(tmp, out=tmp)
if out is None:
- return self.operator.adjoint(self.function.gradient(self.operator.direct(x)))
+ #return self.operator.adjoint(self.function.gradient(self.operator.direct(x)))
+ return self.operator.adjoint(tmp)
else:
- tmp = self.operator.range_geometry().allocate()
- self.operator.direct(x, out=tmp)
- self.function.gradient(tmp, out=tmp)
self.operator.adjoint(tmp, out=out)
@@ -122,4 +126,4 @@ if __name__ == '__main__':
- \ No newline at end of file
+
diff --git a/Wrappers/Python/ccpi/optimisation/functions/IndicatorBox.py b/Wrappers/Python/ccpi/optimisation/functions/IndicatorBox.py
index 51d08d1..fd34d96 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/IndicatorBox.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/IndicatorBox.py
@@ -16,6 +16,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.functions import Function
import numpy
@@ -145,4 +150,4 @@ if __name__ == '__main__':
- \ No newline at end of file
+
diff --git a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
index f88c339..d71f597 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
@@ -17,6 +17,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
import numpy
from ccpi.optimisation.functions import Function
from ccpi.optimisation.functions.ScaledFunction import ScaledFunction
diff --git a/Wrappers/Python/ccpi/optimisation/functions/L1Norm.py b/Wrappers/Python/ccpi/optimisation/functions/L1Norm.py
index cc4bef8..09e550e 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/L1Norm.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/L1Norm.py
@@ -17,6 +17,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.functions import Function
from ccpi.optimisation.functions.ScaledFunction import ScaledFunction
from ccpi.optimisation.operators import ShrinkageOperator
@@ -172,4 +177,4 @@ if __name__ == '__main__':
- \ No newline at end of file
+
diff --git a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
index a625f07..92e0116 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
@@ -18,6 +18,11 @@
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.functions import Function
from ccpi.optimisation.functions.ScaledFunction import ScaledFunction
diff --git a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
index 8cbed67..378cbda 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
@@ -18,10 +18,16 @@
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.functions import Function, ScaledFunction
from ccpi.framework import BlockDataContainer
import functools
+import numpy
class MixedL21Norm(Function):
@@ -45,11 +51,11 @@ class MixedL21Norm(Function):
'''
if not isinstance(x, BlockDataContainer):
raise ValueError('__call__ expected BlockDataContainer, got {}'.format(type(x)))
-
- tmp = [ el**2 for el in x.containers ]
- res = sum(tmp).sqrt().sum()
+ tmp = x.get_item(0) * 0
+ for el in x.containers:
+ tmp += el.power(2.)
+ return tmp.sqrt().sum()
- return res
def gradient(self, x, out=None):
return ValueError('Not Differentiable')
@@ -84,16 +90,29 @@ class MixedL21Norm(Function):
if out is None:
- tmp = [ el*el for el in x.containers]
- res = sum(tmp).sqrt().maximum(1.0)
- frac = [el/res for el in x.containers]
- return BlockDataContainer(*frac)
+ # tmp = [ el*el for el in x.containers]
+ # res = sum(tmp).sqrt().maximum(1.0)
+ # frac = [el/res for el in x.containers]
+ # return BlockDataContainer(*frac)
+ tmp = x.get_item(0) * 0
+ for el in x.containers:
+ tmp += el.power(2.)
+ tmp.sqrt(out=tmp)
+ tmp.maximum(1.0, out=tmp)
+ frac = [ el.divide(tmp) for el in x.containers ]
+ return BlockDataContainer(*frac)
+
else:
res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 )
- res = res1.sqrt().maximum(1.0)
- x.divide(res, out=out)
+ if False:
+ res = res1.sqrt().maximum(1.0)
+ x.divide(res, out=out)
+ else:
+ res1.sqrt(out=res1)
+ res1.maximum(1.0, out=res1)
+ x.divide(res1, out=out)
def __rmul__(self, scalar):
@@ -106,6 +125,12 @@ class MixedL21Norm(Function):
return ScaledFunction(self, scalar)
+def sqrt_maximum(x, a):
+ y = numpy.sqrt(x)
+ if y >= a:
+ return y
+ else:
+ return a
#
if __name__ == '__main__':
diff --git a/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py b/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py
index 0da6e50..01c4f38 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py
@@ -17,7 +17,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+from ccpi.optimisation.operators import LinearOperator
from ccpi.optimisation.functions import Function
import warnings
@@ -50,8 +55,13 @@ class Norm2Sq(Function):
try:
self.L = 2.0*self.c*(self.A.norm()**2)
except AttributeError as ae:
- warnings.warn('{} could not calculate Lipschitz Constant. {}'.format(
+ if self.A.is_linear():
+ Anorm = LinearOperator.PowerMethod(self.A, 10)[0]
+ self.L = 2.0 * self.c * (Anorm*Anorm)
+ else:
+ warnings.warn('{} could not calculate Lipschitz Constant. {}'.format(
self.__class__.__name__, ae))
+
except NotImplementedError as noe:
warnings.warn('{} could not calculate Lipschitz Constant. {}'.format(
self.__class__.__name__, noe))
@@ -65,22 +75,28 @@ class Norm2Sq(Function):
# return self.c*( ( (self.A.direct(x)-self.b)**2).sum() )
#else:
y = self.A.direct(x)
- y.__isub__(self.b)
+ y.subtract(self.b, out=y)
#y.__imul__(y)
#return y.sum() * self.c
try:
+ if self.c == 1:
+ return y.squared_norm()
return y.squared_norm() * self.c
except AttributeError as ae:
- # added for compatibility with SIRF
- return (y.norm()**2) * self.c
+ # added for compatibility with SIRF
+ warnings.warn('squared_norm method not found! Proceeding with norm.')
+ yn = y.norm()
+ if self.c == 1:
+ return yn * yn
+ return (yn * yn) * self.c
def gradient(self, x, out=None):
if out is not None:
#return 2.0*self.c*self.A.adjoint( self.A.direct(x) - self.b )
self.A.direct(x, out=self.range_tmp)
- self.range_tmp -= self.b
+ self.range_tmp.subtract(self.b , out=self.range_tmp)
self.A.adjoint(self.range_tmp, out=out)
#self.direct_placehold.multiply(2.0*self.c, out=out)
- out *= (self.c * 2.0)
+ out.multiply (self.c * 2.0, out=out)
else:
return (2.0*self.c)*self.A.adjoint(self.A.direct(x) - self.b)
diff --git a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
index 3e689e6..a123e8d 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
@@ -17,6 +17,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from numbers import Number
import numpy
import warnings
diff --git a/Wrappers/Python/ccpi/optimisation/functions/ZeroFunction.py b/Wrappers/Python/ccpi/optimisation/functions/ZeroFunction.py
index ca52f31..19db668 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/ZeroFunction.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/ZeroFunction.py
@@ -18,6 +18,11 @@
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.functions import Function
class ZeroFunction(Function):
diff --git a/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py b/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py
index e3a02ec..23cb799 100755
--- a/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py
@@ -17,11 +17,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
import numpy
import functools
from ccpi.framework import ImageData, BlockDataContainer, DataContainer
-from ccpi.optimisation.operators import Operator
+from ccpi.optimisation.operators import Operator, LinearOperator
from ccpi.framework import BlockGeometry
+try:
+ from sirf import SIRF
+ from sirf.SIRF import DataContainer as SIRFDataContainer
+ has_sirf = True
+except ImportError as ie:
+ has_sirf = False
class BlockOperator(Operator):
r'''A Block matrix containing Operators
@@ -115,7 +126,23 @@ class BlockOperator(Operator):
return self.operators[index]
def norm(self, **kwargs):
- norm = [op.norm(**kwargs)**2 for op in self.operators]
+ '''Returns the norm of the BlockOperator
+
+ if the operator in the block do not have method norm defined, i.e. they are SIRF
+ AcquisitionModel's we use PowerMethod if applicable, otherwise we raise an Error
+ '''
+ norm = []
+ for op in self.operators:
+ if hasattr(op, 'norm'):
+ norm.append(op.norm(**kwargs) ** 2.)
+ else:
+ # use Power method
+ if op.is_linear():
+ norm.append(
+ LinearOperator.PowerMethod(op, 20)[0]
+ )
+ else:
+ raise TypeError('Operator {} does not have a norm method and is not linear'.format(op))
return numpy.sqrt(sum(norm))
def direct(self, x, out=None):
@@ -188,7 +215,8 @@ class BlockOperator(Operator):
prod += self.get_item(row, col).adjoint(x_b.get_item(row))
res.append(prod)
if self.shape[1]==1:
- return ImageData(*res)
+ # the output is a single DataContainer, so we can take it out
+ return res[0]
else:
return BlockDataContainer(*res, shape=shape)
else:
@@ -196,7 +224,8 @@ class BlockOperator(Operator):
for col in range(self.shape[1]):
for row in range(self.shape[0]):
if row == 0:
- if issubclass(out.__class__, DataContainer):
+ if issubclass(out.__class__, DataContainer) or \
+ ( has_sirf and issubclass(out.__class__, SIRFDataContainer) ):
self.get_item(row, col).adjoint(
x_b.get_item(row),
out=out)
@@ -206,7 +235,8 @@ class BlockOperator(Operator):
x_b.get_item(row),
out=out.get_item(col))
else:
- if issubclass(out.__class__, DataContainer):
+ if issubclass(out.__class__, DataContainer) or \
+ ( has_sirf and issubclass(out.__class__, SIRFDataContainer) ):
out += self.get_item(row,col).adjoint(
x_b.get_item(row))
else:
@@ -423,4 +453,4 @@ if __name__ == '__main__':
- \ No newline at end of file
+
diff --git a/Wrappers/Python/ccpi/optimisation/operators/BlockScaledOperator.py b/Wrappers/Python/ccpi/optimisation/operators/BlockScaledOperator.py
index c23c23a..eeecee9 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/BlockScaledOperator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/BlockScaledOperator.py
@@ -18,6 +18,11 @@
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from numbers import Number
import numpy
from ccpi.optimisation.operators import ScaledOperator
@@ -82,4 +87,4 @@ class BlockScaledOperator(ScaledOperator):
@property
def T(self):
'''Return the transposed of self'''
- return type(self)(self.operator.T, self.scalar) \ No newline at end of file
+ return type(self)(self.operator.T, self.scalar)
diff --git a/Wrappers/Python/ccpi/optimisation/operators/FiniteDifferenceOperator.py b/Wrappers/Python/ccpi/optimisation/operators/FiniteDifferenceOperator.py
index 9b5ae24..3cc4309 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/FiniteDifferenceOperator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/FiniteDifferenceOperator.py
@@ -15,6 +15,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.operators import LinearOperator
import numpy as np
@@ -66,8 +71,11 @@ class FiniteDiff(LinearOperator):
x_sz = len(x.shape)
if out is None:
- out = np.zeros_like(x_asarr)
+ res = x * 0
+ #out = np.zeros_like(x_asarr)
+ out = res.as_array()
else:
+ res = out
out = out.as_array()
out[:]=0
@@ -180,7 +188,9 @@ class FiniteDiff(LinearOperator):
raise NotImplementedError
# res = out #/self.voxel_size
- return type(x)(out)
+ #return type(x)(out)
+ res.fill(out)
+ return res
def adjoint(self, x, out=None):
@@ -189,8 +199,11 @@ class FiniteDiff(LinearOperator):
x_sz = len(x.shape)
if out is None:
- out = np.zeros_like(x_asarr)
+ #out = np.zeros_like(x_asarr)
+ res = x * 0
+ out = res.as_array()
else:
+ res = out
out = out.as_array()
out[:]=0
@@ -319,7 +332,9 @@ class FiniteDiff(LinearOperator):
raise NotImplementedError
out *= -1 #/self.voxel_size
- return type(x)(out)
+ #return type(x)(out)
+ res.fill(out)
+ return res
def range_geometry(self):
@@ -389,4 +404,4 @@ if __name__ == '__main__':
- \ No newline at end of file
+
diff --git a/Wrappers/Python/ccpi/optimisation/operators/GradientOperator.py b/Wrappers/Python/ccpi/optimisation/operators/GradientOperator.py
index baebc61..3c32a93 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/GradientOperator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/GradientOperator.py
@@ -15,6 +15,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.operators import Operator, LinearOperator, ScaledOperator
from ccpi.framework import ImageData, ImageGeometry, BlockGeometry, BlockDataContainer
import numpy
diff --git a/Wrappers/Python/ccpi/optimisation/operators/IdentityOperator.py b/Wrappers/Python/ccpi/optimisation/operators/IdentityOperator.py
index d8f86a4..e95234b 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/IdentityOperator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/IdentityOperator.py
@@ -16,6 +16,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.operators import LinearOperator
import scipy.sparse as sp
import numpy as np
@@ -113,4 +118,4 @@ if __name__ == '__main__':
- \ No newline at end of file
+
diff --git a/Wrappers/Python/ccpi/optimisation/operators/LinearOperator.py b/Wrappers/Python/ccpi/optimisation/operators/LinearOperator.py
index 8514699..f4d97b8 100755
--- a/Wrappers/Python/ccpi/optimisation/operators/LinearOperator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/LinearOperator.py
@@ -16,6 +16,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.optimisation.operators import Operator
import numpy
@@ -39,7 +44,7 @@ class LinearOperator(Operator):
# Initialise random
if x_init is None:
- x0 = operator.domain_geometry().allocate(type(operator.domain_geometry()).RANDOM_INT)
+ x0 = operator.domain_geometry().allocate('random')
else:
x0 = x_init.copy()
@@ -51,7 +56,11 @@ class LinearOperator(Operator):
operator.direct(x0,out=y_tmp)
operator.adjoint(y_tmp,out=x1)
x1norm = x1.norm()
- s[it] = x1.dot(x0) / x0.squared_norm()
+ if hasattr(x0, 'squared_norm'):
+ s[it] = x1.dot(x0) / x0.squared_norm()
+ else:
+ x0norm = x0.norm()
+ s[it] = x1.dot(x0) / (x0norm * x0norm)
x1.multiply((1.0/x1norm), out=x0)
return numpy.sqrt(s[-1]), numpy.sqrt(s), x0
@@ -62,4 +71,41 @@ class LinearOperator(Operator):
s1, sall, svec = LinearOperator.PowerMethod(self, iterations, x_init=x0)
return s1
-
+ @staticmethod
+ def dot_test(operator, domain_init=None, range_init=None, verbose=False):
+ '''Does a dot linearity test on the operator
+
+ Evaluates if the following equivalence holds
+
+ :math: ..
+
+ Ax\times y = y \times A^Tx
+
+ :param operator: operator to test
+ :param range_init: optional initialisation container in the operator range
+ :param domain_init: optional initialisation container in the operator domain
+ :returns: boolean, True if the test is passed.
+ '''
+ if range_init is None:
+ y = operator.range_geometry().allocate('random_int')
+ else:
+ y = range_init
+ if domain_init is None:
+ x = operator.domain_geometry().allocate('random_int')
+ else:
+ x = domain_init
+
+ fx = operator.direct(x)
+ by = operator.adjoint(y)
+ a = fx.dot(y)
+ b = by.dot(x)
+ if verbose:
+ print ('Left hand side {}, \nRight hand side {}'.format(a, b))
+ try:
+ numpy.testing.assert_almost_equal(abs((a-b)/a), 0, decimal=4)
+ return True
+ except AssertionError as ae:
+ print (ae)
+ return False
+
+
diff --git a/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py b/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py
index bc3312d..7d18ea1 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py
@@ -15,6 +15,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
import numpy
from scipy.sparse.linalg import svds
from ccpi.framework import VectorGeometry
diff --git a/Wrappers/Python/ccpi/optimisation/operators/Operator.py b/Wrappers/Python/ccpi/optimisation/operators/Operator.py
index 2678bf2..87059e6 100755
--- a/Wrappers/Python/ccpi/optimisation/operators/Operator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/Operator.py
@@ -15,6 +15,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+
from ccpi.optimisation.operators.ScaledOperator import ScaledOperator
class Operator(object):
diff --git a/Wrappers/Python/ccpi/optimisation/operators/ShrinkageOperator.py b/Wrappers/Python/ccpi/optimisation/operators/ShrinkageOperator.py
index c1f7ca4..9239d90 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/ShrinkageOperator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/ShrinkageOperator.py
@@ -15,6 +15,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
from ccpi.framework import DataContainer
@@ -31,4 +36,4 @@ class ShrinkageOperator():
def __call__(self, x, tau, out=None):
return x.sign() * (x.abs() - tau).maximum(0)
- \ No newline at end of file
+
diff --git a/Wrappers/Python/ccpi/optimisation/operators/SparseFiniteDiff.py b/Wrappers/Python/ccpi/optimisation/operators/SparseFiniteDiff.py
index 91d5ca9..698a993 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/SparseFiniteDiff.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/SparseFiniteDiff.py
@@ -15,6 +15,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
import scipy.sparse as sp
import numpy as np
@@ -155,4 +160,4 @@ if __name__ == '__main__':
u_per_sp_adjoint3D = sFD_per3D.adjoint(arr3D)
np.testing.assert_array_almost_equal(u_per_adjoint3D.as_array(), u_per_sp_adjoint3D.as_array(), decimal=4)
- \ No newline at end of file
+
diff --git a/Wrappers/Python/ccpi/optimisation/operators/SymmetrizedGradientOperator.py b/Wrappers/Python/ccpi/optimisation/operators/SymmetrizedGradientOperator.py
index 92f8f90..8d14cf8 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/SymmetrizedGradientOperator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/SymmetrizedGradientOperator.py
@@ -15,9 +15,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
-from ccpi.optimisation.operators import Gradient, Operator, LinearOperator, ScaledOperator
-from ccpi.framework import ImageData, ImageGeometry, BlockGeometry, BlockDataContainer
+
+from ccpi.optimisation.operators import Gradient, Operator, LinearOperator,\
+ ScaledOperator
+from ccpi.framework import ImageData, ImageGeometry, BlockGeometry, \
+ BlockDataContainer
import numpy
from ccpi.optimisation.operators import FiniteDiff
diff --git a/Wrappers/Python/ccpi/optimisation/operators/ZeroOperator.py b/Wrappers/Python/ccpi/optimisation/operators/ZeroOperator.py
index 5f1de30..c37e15e 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/ZeroOperator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/ZeroOperator.py
@@ -15,6 +15,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
import numpy as np
from ccpi.framework import ImageData
@@ -84,4 +89,4 @@ class ZeroOperator(LinearOperator):
'''Returns domain_geometry of ZeroOperator'''
- return self.gm_range \ No newline at end of file
+ return self.gm_range
diff --git a/Wrappers/Python/test/test_NexusReader.py b/Wrappers/Python/test/test_NexusReader.py
index 71a05c2..992ce4f 100755
--- a/Wrappers/Python/test/test_NexusReader.py
+++ b/Wrappers/Python/test/test_NexusReader.py
@@ -17,7 +17,11 @@
# limitations under the License.
import unittest
-import wget
+has_wget = True
+try:
+ import wget
+except ImportError as ie:
+ has_wget = False
import os
from ccpi.io.reader import NexusReader
import numpy
@@ -26,80 +30,85 @@ import numpy
class TestNexusReader(unittest.TestCase):
def setUp(self):
- wget.download('https://github.com/DiamondLightSource/Savu/raw/master/test_data/data/24737_fd.nxs')
- self.filename = '24737_fd.nxs'
+ if has_wget:
+ wget.download('https://github.com/DiamondLightSource/Savu/raw/master/test_data/data/24737_fd.nxs')
+ self.filename = '24737_fd.nxs'
def tearDown(self):
- os.remove(self.filename)
+ if has_wget:
+ os.remove(self.filename)
def testAll(self):
- # def testGetDimensions(self):
- nr = NexusReader(self.filename)
- self.assertEqual(nr.get_sinogram_dimensions(), (135, 91, 160), "Sinogram dimensions are not correct")
-
- # def testGetProjectionDimensions(self):
- nr = NexusReader(self.filename)
- self.assertEqual(nr.get_projection_dimensions(), (91,135,160), "Projection dimensions are not correct")
-
- # def testLoadProjectionWithoutDimensions(self):
- nr = NexusReader(self.filename)
- projections = nr.load_projection()
- self.assertEqual(projections.shape, (91,135,160), "Loaded projection data dimensions are not correct")
-
- # def testLoadProjectionWithDimensions(self):
- nr = NexusReader(self.filename)
- projections = nr.load_projection((slice(0,1), slice(0,135), slice(0,160)))
- self.assertEqual(projections.shape, (1,135,160), "Loaded projection data dimensions are not correct")
-
- # def testLoadProjectionCompareSingle(self):
- nr = NexusReader(self.filename)
- projections_full = nr.load_projection()
- projections_part = nr.load_projection((slice(0,1), slice(0,135), slice(0,160)))
- numpy.testing.assert_array_equal(projections_part, projections_full[0:1,:,:])
-
- # def testLoadProjectionCompareMulti(self):
- nr = NexusReader(self.filename)
- projections_full = nr.load_projection()
- projections_part = nr.load_projection((slice(0,3), slice(0,135), slice(0,160)))
- numpy.testing.assert_array_equal(projections_part, projections_full[0:3,:,:])
-
- # def testLoadProjectionCompareRandom(self):
- nr = NexusReader(self.filename)
- projections_full = nr.load_projection()
- projections_part = nr.load_projection((slice(1,8), slice(5,10), slice(8,20)))
- numpy.testing.assert_array_equal(projections_part, projections_full[1:8,5:10,8:20])
-
- # def testLoadProjectionCompareFull(self):
- nr = NexusReader(self.filename)
- projections_full = nr.load_projection()
- projections_part = nr.load_projection((slice(None,None), slice(None,None), slice(None,None)))
- numpy.testing.assert_array_equal(projections_part, projections_full[:,:,:])
-
- # def testLoadFlatCompareFull(self):
- nr = NexusReader(self.filename)
- flats_full = nr.load_flat()
- flats_part = nr.load_flat((slice(None,None), slice(None,None), slice(None,None)))
- numpy.testing.assert_array_equal(flats_part, flats_full[:,:,:])
-
- # def testLoadDarkCompareFull(self):
- nr = NexusReader(self.filename)
- darks_full = nr.load_dark()
- darks_part = nr.load_dark((slice(None,None), slice(None,None), slice(None,None)))
- numpy.testing.assert_array_equal(darks_part, darks_full[:,:,:])
-
- # def testProjectionAngles(self):
- nr = NexusReader(self.filename)
- angles = nr.get_projection_angles()
- self.assertEqual(angles.shape, (91,), "Loaded projection number of angles are not correct")
-
- # def test_get_acquisition_data_subset(self):
- nr = NexusReader(self.filename)
- key = nr.get_image_keys()
- sl = nr.get_acquisition_data_subset(0,10)
- data = nr.get_acquisition_data().subset(['vertical','horizontal'])
-
- self.assertTrue(sl.shape , (10,data.shape[1]))
+ if has_wget:
+ # def testGetDimensions(self):
+ nr = NexusReader(self.filename)
+ self.assertEqual(nr.get_sinogram_dimensions(), (135, 91, 160), "Sinogram dimensions are not correct")
+ # def testGetProjectionDimensions(self):
+ nr = NexusReader(self.filename)
+ self.assertEqual(nr.get_projection_dimensions(), (91,135,160), "Projection dimensions are not correct")
+
+ # def testLoadProjectionWithoutDimensions(self):
+ nr = NexusReader(self.filename)
+ projections = nr.load_projection()
+ self.assertEqual(projections.shape, (91,135,160), "Loaded projection data dimensions are not correct")
+
+ # def testLoadProjectionWithDimensions(self):
+ nr = NexusReader(self.filename)
+ projections = nr.load_projection((slice(0,1), slice(0,135), slice(0,160)))
+ self.assertEqual(projections.shape, (1,135,160), "Loaded projection data dimensions are not correct")
+
+ # def testLoadProjectionCompareSingle(self):
+ nr = NexusReader(self.filename)
+ projections_full = nr.load_projection()
+ projections_part = nr.load_projection((slice(0,1), slice(0,135), slice(0,160)))
+ numpy.testing.assert_array_equal(projections_part, projections_full[0:1,:,:])
+
+ # def testLoadProjectionCompareMulti(self):
+ nr = NexusReader(self.filename)
+ projections_full = nr.load_projection()
+ projections_part = nr.load_projection((slice(0,3), slice(0,135), slice(0,160)))
+ numpy.testing.assert_array_equal(projections_part, projections_full[0:3,:,:])
+
+ # def testLoadProjectionCompareRandom(self):
+ nr = NexusReader(self.filename)
+ projections_full = nr.load_projection()
+ projections_part = nr.load_projection((slice(1,8), slice(5,10), slice(8,20)))
+ numpy.testing.assert_array_equal(projections_part, projections_full[1:8,5:10,8:20])
+
+ # def testLoadProjectionCompareFull(self):
+ nr = NexusReader(self.filename)
+ projections_full = nr.load_projection()
+ projections_part = nr.load_projection((slice(None,None), slice(None,None), slice(None,None)))
+ numpy.testing.assert_array_equal(projections_part, projections_full[:,:,:])
+
+ # def testLoadFlatCompareFull(self):
+ nr = NexusReader(self.filename)
+ flats_full = nr.load_flat()
+ flats_part = nr.load_flat((slice(None,None), slice(None,None), slice(None,None)))
+ numpy.testing.assert_array_equal(flats_part, flats_full[:,:,:])
+
+ # def testLoadDarkCompareFull(self):
+ nr = NexusReader(self.filename)
+ darks_full = nr.load_dark()
+ darks_part = nr.load_dark((slice(None,None), slice(None,None), slice(None,None)))
+ numpy.testing.assert_array_equal(darks_part, darks_full[:,:,:])
+
+ # def testProjectionAngles(self):
+ nr = NexusReader(self.filename)
+ angles = nr.get_projection_angles()
+ self.assertEqual(angles.shape, (91,), "Loaded projection number of angles are not correct")
+
+ # def test_get_acquisition_data_subset(self):
+ nr = NexusReader(self.filename)
+ key = nr.get_image_keys()
+ sl = nr.get_acquisition_data_subset(0,10)
+ data = nr.get_acquisition_data().subset(['vertical','horizontal'])
+
+ self.assertTrue(sl.shape , (10,data.shape[1]))
+ else:
+ # skips all tests if module wget is not present
+ self.assertFalse(has_wget)
if __name__ == '__main__':