summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python/wip/demo_astra_sophiabeads.py
blob: dce3d91e090a4a4e5271b3ad9e18c8e5f8726b30 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134

# This demo shows how to load a Nikon XTek micro-CT data set and reconstruct
# the central slice using the CGLS and FISTA methods. The SophiaBeads dataset 
# with 64 projections is used as test data and can be obtained from here:
# https://zenodo.org/record/16474
# The full path to the .xtekct file must be given as a string to the
# xtek_file argument of NikonDataReader to load the data.

# Do all imports
import os

import numpy as np
import matplotlib.pyplot as plt
from ccpi.io import NikonDataReader
from ccpi.framework import ImageGeometry, AcquisitionGeometry, AcquisitionData, ImageData
from ccpi.astra.operators import AstraProjectorSimple
from ccpi.optimisation.algorithms import FISTA, CGLS
from ccpi.optimisation.functions import Norm2Sq, L1Norm

# Set up the reader object and read in only the central slice of the data.
# The dataset path defaults to a placeholder; set the environment variable
# SOPHIABEADS_XTEKCT to the full path of SophiaBeads_64_averaged.xtekct to
# run the demo without editing this file.
xtek_file = os.environ.get(
    "SOPHIABEADS_XTEKCT",
    "REPLACE_THIS_BY_PATH_TO_DATASET/SophiaBeads_64_averaged.xtekct")
# roi presumably gives (row range, column range); (1000, 1001) keeps a single
# row, which is why axis 1 of the loaded array is indexed with 0 below.
datareader = NikonDataReader(xtek_file=xtek_file,
                             roi=[(1000, 1001), (0, 2000)])
data = datareader.load_projections()

# Extract the single loaded row, scale it by 60000 and apply the negative-log
# transform. NOTE(review): 60000 appears to be the unattenuated (flat-field)
# intensity for this dataset -- confirm; zero counts would give +inf here.
sino = -np.log(data.as_array()[:, 0, :] / 60000.0)

# Apply a centre-of-rotation correction by zero padding: cor_pad zero columns
# are prepended on the left of the detector. The amount was found manually.
cor_pad = 30
sino_pad = np.zeros((sino.shape[0], sino.shape[1] + cor_pad))
sino_pad[:, cor_pad:] = sino

# Describe the central slice as a 2D fan-beam acquisition ('cone' geometry
# in '2D'), reusing the parameters of the geometry returned by the reader.
reader_geom = data.geometry
ag2d = AcquisitionGeometry(
    'cone',
    '2D',
    # Degrees -> radians, with the sign flipped. NOTE(review): the negation
    # presumably matches the projector's rotation convention -- confirm.
    angles=-np.pi / 180 * reader_geom.angles,
    # The detector is widened by the zero columns added to the sinogram.
    pixel_num_h=reader_geom.pixel_num_h + cor_pad,
    pixel_size_h=reader_geom.pixel_size_h,
    # Sign flipped relative to the reader's value -- confirm convention.
    dist_source_center=-reader_geom.dist_source_center,
    dist_center_detector=reader_geom.dist_center_detector,
)

# Pair the padded sinogram with that geometry.
data2d = AcquisitionData(sino_pad, geometry=ag2d)

# Reconstruct onto an N x N voxel grid, N being the number of detector pixels
# reported by the reader (i.e. without the centring pad).
N = data.geometry.pixel_num_h

# Geometric magnification = (centre-to-detector + source-to-centre)
#                           / source-to-centre, using absolute distances.
centre_to_detector = np.abs(data.geometry.dist_center_detector)
source_to_centre = np.abs(data.geometry.dist_source_center)
mag = (centre_to_detector + source_to_centre) / source_to_centre

# The voxel size is the detector pixel size divided by the magnification.
voxel_size_h = data.geometry.pixel_size_h / mag

# Square image geometry with isotropic in-plane voxels.
ig2d = ImageGeometry(voxel_num_x=N, voxel_num_y=N,
                     voxel_size_x=voxel_size_h, voxel_size_y=voxel_size_h)

# Projector (acquisition model) between ig2d and ag2d, run by ASTRA on the GPU.
Aop = AstraProjectorSimple(ig2d, ag2d, "gpu")

# Starting image on ig2d (presumably zero-filled by default -- confirm). It is
# used as the initial guess for CGLS here and for the FISTA runs as well.
x_init = ImageData(geometry=ig2d)

# Tolerance and iteration count shared by the reconstruction algorithms.
opt = {'tol': 1e-4, 'iter': 50}
n_iter = opt['iter']

# Plain least-squares reconstruction with CGLS. max_iteration is only an
# upper bound; run() is asked for n_iter iterations.
CGLS_alg = CGLS()
CGLS_alg.set_up(x_init, Aop, data2d)
CGLS_alg.max_iteration = 2000
CGLS_alg.run(n_iter)
x_CGLS = CGLS_alg.get_output()

# Reconstructed image with a colour bar.
plt.figure()
plt.imshow(x_CGLS.array)
plt.title('CGLS')
plt.colorbar()
plt.show()

# Objective value per iteration, on a logarithmic axis.
plt.figure()
plt.semilogy(CGLS_alg.objective)
plt.title('CGLS criterion')
plt.show()

# CGLS solves the simple least-squares problem. The same problem can be solved
# by FISTA by explicitly setting up a least-squares function object and using
# no regularisation.

# Least-squares object built from the projector and the test data, with a
# constant coefficient of 0.5 (presumably f(x) = 0.5 * ||A x - b||^2).
# The coefficient is passed explicitly instead of relying on Norm2Sq's
# default, so the code matches this documented scaling.
f = Norm2Sq(Aop, data2d, c=0.5)

# Run FISTA for least squares without constraints (no g is given).
FISTA_alg = FISTA()
FISTA_alg.set_up(x_init=x_init, f=f, opt=opt)
FISTA_alg.max_iteration = 2000
FISTA_alg.run(opt['iter'])
x_FISTA = FISTA_alg.get_output()

plt.figure()
plt.imshow(x_FISTA.array)
plt.title('FISTA Least squares reconstruction')
plt.colorbar()
plt.show()

plt.figure()
plt.semilogy(FISTA_alg.objective)
plt.title('FISTA Least squares criterion')
plt.show()

# Regularisation term: the 1-norm scaled by the weight lam.
lam = 1.0
g0 = lam * L1Norm()

# FISTA on the least-squares term f plus the weighted 1-norm g0, starting
# from the same initial guess and options as the earlier runs.
FISTA_alg1 = FISTA()
FISTA_alg1.set_up(x_init=x_init, f=f, g=g0, opt=opt)
FISTA_alg1.max_iteration = 2000
n_iter_l1 = opt['iter']
FISTA_alg1.run(n_iter_l1)
x_FISTA1 = FISTA_alg1.get_output()

# Regularised reconstruction with a colour bar.
plt.figure()
plt.imshow(x_FISTA1.array)
plt.title('FISTA LS+L1Norm reconstruction')
plt.colorbar()
plt.show()

# Objective history on a logarithmic axis.
plt.figure()
plt.semilogy(FISTA_alg1.objective)
plt.title('FISTA LS+L1norm criterion')
plt.show()