summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorWillem Jan Palenstijn <Willem.Jan.Palenstijn@cwi.nl>2020-06-23 15:19:06 +0200
committerWillem Jan Palenstijn <Willem.Jan.Palenstijn@cwi.nl>2020-07-02 15:54:12 +0200
commit492c0211608fa756ba6642ff7ae3b479765a955b (patch)
tree325b34d12f5969f7b90d6eaacb4d023e2a386213
parentecfb65a05b8ed5171ad65173581d5fe328926995 (diff)
downloadastra-492c0211608fa756ba6642ff7ae3b479765a955b.tar.gz
astra-492c0211608fa756ba6642ff7ae3b479765a955b.tar.bz2
astra-492c0211608fa756ba6642ff7ae3b479765a955b.tar.xz
astra-492c0211608fa756ba6642ff7ae3b479765a955b.zip
Check numpy array type
-rw-r--r--python/astra/data2d.py7
-rw-r--r--python/astra/data3d.py7
-rw-r--r--python/astra/experimental.pyx4
-rw-r--r--python/astra/pythonutils.py15
-rw-r--r--python/astra/utils.pyx8
5 files changed, 25 insertions, 16 deletions
diff --git a/python/astra/data2d.py b/python/astra/data2d.py
index 188ff69..6ab458f 100644
--- a/python/astra/data2d.py
+++ b/python/astra/data2d.py
@@ -65,12 +65,7 @@ def link(datatype, geometry, data):
:returns: :class:`int` -- the ID of the constructed object.
"""
- if not isinstance(data,np.ndarray):
- raise ValueError("Input should be a numpy array")
- if not data.dtype==np.float32:
- raise ValueError("Numpy array should be float32")
- if not (data.flags['C_CONTIGUOUS'] and data.flags['ALIGNED']):
- raise ValueError("Numpy array should be C_CONTIGUOUS and ALIGNED")
+ checkArrayForLink(data)
return d.create(datatype,geometry,data,True)
def store(i, data):
diff --git a/python/astra/data3d.py b/python/astra/data3d.py
index b0d54b2..3eea0e3 100644
--- a/python/astra/data3d.py
+++ b/python/astra/data3d.py
@@ -26,7 +26,7 @@
from . import data3d_c as d
import numpy as np
-from .pythonutils import GPULink
+from .pythonutils import GPULink, checkArrayForLink
def create(datatype,geometry,data=None):
"""Create a 3D object.
@@ -57,10 +57,7 @@ def link(datatype, geometry, data):
if not isinstance(data,np.ndarray) and not isinstance(data,GPULink):
raise TypeError("Input should be a numpy ndarray or GPULink object")
if isinstance(data, np.ndarray):
- if data.dtype != np.float32:
- raise ValueError("Numpy array should be float32")
- if not (data.flags['C_CONTIGUOUS'] and data.flags['ALIGNED']):
- raise ValueError("Numpy array should be C_CONTIGUOUS and ALIGNED")
+ checkArrayForLink(data)
return d.create(datatype,geometry,data,True)
diff --git a/python/astra/experimental.pyx b/python/astra/experimental.pyx
index 25ecb24..c76fcbe 100644
--- a/python/astra/experimental.pyx
+++ b/python/astra/experimental.pyx
@@ -168,9 +168,9 @@ IF HAVE_CUDA==True:
:param projector_id: A 3D projector object handle
:type datatype: :class:`int`
- :param vol: The input data, as either a numpy array, or a GPULink object
+ :param vol: The pre-allocated output data, as either a numpy array, or a GPULink object
:type datatype: :class:`numpy.ndarray` or :class:`astra.data3d.GPULink`
- :param proj: The pre-allocated output data, either numpy array or GPULink
+ :param proj: The input data, either numpy array or GPULink
:type datatype: :class:`numpy.ndarray` or :class:`astra.data3d.GPULink`
"""
direct_FPBP3D(projector_id, vol, proj, "BP")
diff --git a/python/astra/pythonutils.py b/python/astra/pythonutils.py
index 715df30..ef49f97 100644
--- a/python/astra/pythonutils.py
+++ b/python/astra/pythonutils.py
@@ -29,6 +29,8 @@
"""
+import numpy as np
+
def geom_size(geom, dim=None):
"""Returns the size of a volume or sinogram, based on the projection or volume geometry.
@@ -62,6 +64,19 @@ def geom_size(geom, dim=None):
return s
+def checkArrayForLink(data):
+ """Check if a numpy array is suitable for direct usage (contiguous, etc.)
+
+ This function raises an exception if not.
+ """
+
+ if not isinstance(data, np.ndarray):
+ raise ValueError("Numpy array should be numpy.ndarray")
+ if data.dtype != np.float32:
+ raise ValueError("Numpy array should be float32")
+ if not (data.flags['C_CONTIGUOUS'] and data.flags['ALIGNED']):
+ raise ValueError("Numpy array should be C_CONTIGUOUS and ALIGNED")
+
class GPULink(object):
"""Utility class for astra.data3d.link with a CUDA pointer
diff --git a/python/astra/utils.pyx b/python/astra/utils.pyx
index 3b6e3ff..12fc38c 100644
--- a/python/astra/utils.pyx
+++ b/python/astra/utils.pyx
@@ -45,7 +45,7 @@ from .PyXMLDocument cimport XMLDocument
from .PyXMLDocument cimport XMLNode
from .PyIncludes cimport *
-from .pythonutils import GPULink
+from .pythonutils import GPULink, checkArrayForLink
cdef extern from "CFloat32CustomPython.h":
cdef cppclass CFloat32CustomPython:
@@ -252,9 +252,10 @@ cdef CFloat32VolumeData3D* linkVolFromGeometry(CVolumeGeometry3D *pGeometry, dat
data_shape = (data.z, data.y, data.x)
if geom_shape != data_shape:
raise ValueError(
- "The dimensions of the data do not match those specified in the geometry.".format(data_shape, geom_shape))
+ "The dimensions of the data do not match those specified in the geometry: {} != {}".format(data_shape, geom_shape))
if isinstance(data, np.ndarray):
+ checkArrayForLink(data)
pCustom = <CFloat32CustomMemory*> new CFloat32CustomPython(data)
pDataObject3D = new CFloat32VolumeData3DMemory(pGeometry, pCustom)
elif isinstance(data, GPULink):
@@ -276,9 +277,10 @@ cdef CFloat32ProjectionData3D* linkProjFromGeometry(CProjectionGeometry3D *pGeom
data_shape = (data.z, data.y, data.x)
if geom_shape != data_shape:
raise ValueError(
- "The dimensions of the data do not match those specified in the geometry.".format(data_shape, geom_shape))
+ "The dimensions of the data do not match those specified in the geometry: {} != {}".format(data_shape, geom_shape))
if isinstance(data, np.ndarray):
+ checkArrayForLink(data)
pCustom = <CFloat32CustomMemory*> new CFloat32CustomPython(data)
pDataObject3D = new CFloat32ProjectionData3DMemory(pGeometry, pCustom)
elif isinstance(data, GPULink):