summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python/wip/demo_nexus.py
blob: ec30fd5a19b74a50a3ef97c5aeca4ca0cd71edac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154

# This script demonstrates how to load a parallel-beam data set in NeXus
# format, apply dark- and flat-field correction, and reconstruct it using the
# modular optimisation framework.
#
# The data set is available from
# https://github.com/DiamondLightSource/Savu/blob/master/test_data/data/24737_fd.nxs
# Download it to a local directory and point the NexusReader path below at it.

# All own imports
from ccpi.framework import ImageData, AcquisitionData, ImageGeometry, AcquisitionGeometry
from ccpi.optimisation.algs import FISTA, FBPD, CGLS
from ccpi.optimisation.funcs import Norm2sq, Norm1
from ccpi.plugins.ops import CCPiProjectorSimple
from ccpi.processors import Normalizer, CenterOfRotationFinder 
from ccpi.plugins.processors import AcquisitionDataPadder
from ccpi.io.reader import NexusReader

# All external imports
import numpy
import matplotlib.pyplot as plt
import os

# Define utility function to average over flat and dark images.
def avg_img(image):
    """Return the pixel-wise mean of a stack of frames, averaged over axis 0.

    Parameters
    ----------
    image : array_like
        Stack of frames with the frame index on the first axis (e.g. the
        flat or dark images returned by the reader).

    Returns
    -------
    numpy.ndarray
        float64 array with the shape of a single frame.

    Raises
    ------
    ValueError
        If ``image`` has no leading axis or contains no frames.
    """
    stack = numpy.asarray(image)
    if stack.ndim == 0 or stack.shape[0] == 0:
        # Averaging zero frames is undefined; returning zeros would silently
        # produce a zero flat/dark field and break normalisation later.
        raise ValueError("avg_img needs at least one frame along axis 0")
    # dtype=float64 forces a floating-point accumulator, so integer detector
    # frames are not truncated (the old per-frame `image[i] / l` floor-divided
    # under Python 2) and no float copy of the whole stack is made.
    return stack.mean(axis=0, dtype=numpy.float64)
    
# Set up a reader object pointing to the NeXus data set. Revise data_path as
# needed; it is resolved relative to the current working directory.
data_path = os.path.join("..", "..", "..", "..",
                         "CCPi-ReconstructionFramework", "data", "24737_fd.nxs")
reader = NexusReader(data_path)

# Read and print the dimensions of the raw projections.
dims = reader.get_projection_dimensions()
print(dims)

# Load and average all flat and dark images in preparation for normalising data.
flat = avg_img(reader.load_flat())
dark = avg_img(reader.load_dark())

# Set up normaliser object for normalising data by flat and dark images, and
# pass it the raw projections as input.
norm = Normalizer(flat_field=flat, dark_field=dark)
norm.set_input(reader.get_acquisition_data())

# Obtain the normalised data once and reuse it for both the centre-of-rotation
# search and the padding step. Calling norm.get_output() twice may repeat the
# normalisation over the whole data set (whether the processor caches its
# output is not visible here — NOTE(review): confirm).
normalised_data = norm.get_output()

# Determine the centre of rotation from the normalised data.
cor = CenterOfRotationFinder()
cor.set_input(normalised_data)
center_of_rotation = cor.get_output()

# Pad the normalised data so that the computed centre of rotation lies at the
# centre, producing the data used for reconstruction.
padder = AcquisitionDataPadder(center_of_rotation=center_of_rotation)
padder.set_input(normalised_data)
padded_data = padder.get_output()

# Derive the reconstruction volume from the padded acquisition geometry:
# each slice is pixel_num_h x pixel_num_h voxels, and there is one slice per
# detector row (pixel_num_v).
ag = padded_data.geometry
ig = ImageGeometry(
    voxel_num_x=ag.pixel_num_h,
    voxel_num_y=ag.pixel_num_h,
    voxel_num_z=ag.pixel_num_v,
)

# Projector built from the image and acquisition geometries.
print("Define projector")
Cop = CCPiProjectorSimple(ig, ag)

# Least-squares objective built from the projector and the padded data,
# with c=0.5 passed straight through to Norm2sq.
print("Create least squares object instance with projector and data.")
f = Norm2sq(Cop, padded_data, c=0.5)

# Initial guess on the image geometry, with explicit axis labels.
print("Initial guess")
x_init = ImageData(
    geometry=ig,
    dimension_labels=['horizontal_x', 'horizontal_y', 'vertical'],
)
        
# Index along horizontal_x of the slice shown in every preview plot.
preview_index = 80


def show_preview(image, title):
    """Plot the horizontal_x=preview_index slice of ``image`` under ``title``."""
    plt.imshow(image.subset(horizontal_x=preview_index).array)
    plt.title(title)
    plt.show()


# Run FISTA reconstruction for least squares without regularisation.
print("Run FISTA for least squares")
opt = {'tol': 1e-4, 'iter': 10}
x_fista0, it0, timing0, criter0 = FISTA(x_init, f, None, opt=opt)
show_preview(x_fista0, 'FISTA LS')

# Set up 1-norm function for FISTA least squares plus 1-norm regularisation.
print("Run FISTA for least squares plus 1-norm regularisation")
lam = 0.1
g0 = Norm1(lam)

# Run FISTA for least squares plus 1-norm function, and show the regularised
# reconstruction x_fista1 (not the unregularised x_fista0).
x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g0, opt=opt)
show_preview(x_fista1, 'FISTA LS+1')

# Run FBPD (Forward Backward Primal Dual) on least squares plus 1-norm.
print("Run FBPD for least squares plus 1-norm regularisation")
x_fbpd1, it_fbpd1, timing_fbpd1, criter_fbpd1 = FBPD(x_init, None, f, g0, opt=opt)
show_preview(x_fbpd1, 'FBPD LS+1')

# Run CGLS, which should agree with the FISTA least-squares result.
print("Run CGLS for least squares")
x_CGLS, it_CGLS, timing_CGLS, criter_CGLS = CGLS(x_init, Cop, padded_data, opt=opt)
show_preview(x_CGLS, 'CGLS')

# Display all reconstructions side by side (one row, one column per method),
# then the objective values recorded by each method on log-log axes.
reconstructions = [('FISTA LS', x_fista0),
                   ('FISTA LS+1', x_fista1),
                   ('FBPD LS+1', x_fbpd1),
                   ('CGLS', x_CGLS)]
rows = 1
cols = len(reconstructions)
fig = plt.figure()
# Subplot positions are 1-based, hence start=1.
for position, (title, recon) in enumerate(reconstructions, start=1):
    a = fig.add_subplot(rows, cols, position)
    a.set_title(title)
    a.imshow(recon.subset(horizontal_x=80).as_array())
plt.show()

fig = plt.figure()
b = fig.add_subplot(1, 1, 1)
b.set_title('criteria')
for label, criterion in [('FISTA LS', criter0),
                         ('FISTA LS+1', criter1),
                         ('FBPD LS+1', criter_fbpd1),
                         ('CGLS', criter_CGLS)]:
    b.loglog(criterion, label=label)
b.legend(loc='right')
plt.show()