summaryrefslogtreecommitdiffstats
path: root/matlab/algorithms/DART/MaskingGPU.m
blob: b344dfd44454a7f371ea14c38a93e3f5610f2f54 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
%--------------------------------------------------------------------------
% This file is part of the ASTRA Toolbox
%
% Copyright: 2010-2018, imec Vision Lab, University of Antwerp
%            2014-2018, CWI, Amsterdam
% License: Open Source under GPLv3
% Contact: astra@astra-toolbox.com
% Website: http://www.astra-toolbox.com/
%--------------------------------------------------------------------------

classdef MaskingGPU < matlab.mixin.Copyable

	% Policy class for masking for DART with GPU accelerated code (deprecated).
	%
	% The mask is computed by ASTRA's 'DARTMASK_CUDA' algorithm. A 2D volume
	% is processed directly; a 3D volume is processed slice by slice when
	% conn is a 2D connectivity (4 or 8). Full 3D masking is not supported
	% and raises an error.
	%
	% NOTE(review): of the settings below, only gpu_core is forwarded to the
	% CUDA algorithm (conn only selects the 3D code path; the Connectivity
	% option is commented out). radius, edge_threshold and random are not
	% passed on by this class -- confirm whether that is intended.
	
	%----------------------------------------------------------------------
	properties (Access=public)
		
		radius			= 1;			% SETTING: Radius of masking kernel.
		conn			= 8;			% SETTING: Connectivity window. For 2D: 4 or 8.  For 3D: 6 or 26.
		edge_threshold	= 1;			% SETTING: Number of pixels in the window that should be different.
		gpu_core		= 0;			% SETTING: GPU index, passed to DARTMASK_CUDA as option GPUindex.
		random			= 0.1;			% SETTING: Fraction of random points.  Between 0 and 1.
		
	end
	
	%----------------------------------------------------------------------
	methods (Access=public)
		
		%------------------------------------------------------------------
		function settings = getsettings(this)
			% Returns a structure with the masking settings radius, conn,
			% edge_threshold and random (gpu_core is not included).
			% >> settings = DART.masking.getsettings();
			settings.radius				= this.radius;
			settings.conn				= this.conn;
			settings.edge_threshold		= this.edge_threshold;
			settings.random				= this.random;
		end
		
		%------------------------------------------------------------------
		function Mask = apply(this, ~, V_in)
			% Applies masking to the segmented volume V_in.
			% >> Mask = DART.masking.apply(DART, V_in);
			%
			% The second argument (the DART object) is ignored.
			% Mask has the same size as V_in.
			
			% 2D, one slice
			if size(V_in,3) == 1
				Mask = this.apply_2D(V_in);
						
			% 3D, slice by slice (only for 2D connectivities)
			elseif this.conn == 4 || this.conn == 8
				Mask = zeros(size(V_in));
				for slice = 1:size(V_in,3)
					Mask(:,:,slice) = this.apply_2D(V_in(:,:,slice));
				end
			
			% 3D, full
			else
				error('Full 3D masking on GPU not implemented.')
			end
			
		end
		
	end
		
	%----------------------------------------------------------------------
	methods (Access=protected)
		
		%------------------------------------------------------------------
		function Mask = apply_2D(this, S)
			% Computes the mask of one 2D slice S with ASTRA's DARTMASK_CUDA.
			%
			% All ASTRA objects created here are deleted before returning,
			% including when one of the ASTRA calls throws. The error is then
			% rethrown, so objects do not leak in ASTRA's managers.
		
			vol_geom = astra_create_vol_geom(size(S));
			data_id = astra_mex_data2d('create', '-vol', vol_geom, S);
			mask_id = [];
			alg_id = [];

			try
				mask_id = astra_mex_data2d('create', '-vol', vol_geom, 0);

				cfg = astra_struct('DARTMASK_CUDA');
				cfg.SegmentationDataId = data_id;
				cfg.MaskDataId = mask_id;
				cfg.option.GPUindex = this.gpu_core;
				%cfg.option.Connectivity = this.conn;
				
				alg_id = astra_mex_algorithm('create',cfg);
				astra_mex_algorithm('iterate',alg_id,1);
				Mask = astra_mex_data2d('get', mask_id);
			catch err
				MaskingGPU.release_astra_objects(alg_id, data_id, mask_id);
				rethrow(err);
			end

			MaskingGPU.release_astra_objects(alg_id, data_id, mask_id);
			
		end
	end

	%----------------------------------------------------------------------
	methods (Static, Access=private)

		%------------------------------------------------------------------
		function release_astra_objects(alg_id, varargin)
			% Deletes the ASTRA algorithm alg_id, then each 2D data object id
			% given in varargin. Ids that are empty (never created) are skipped.
			if ~isempty(alg_id)
				astra_mex_algorithm('delete', alg_id);
			end
			for k = 1:numel(varargin)
				if ~isempty(varargin{k})
					astra_mex_data2d('delete', varargin{k});
				end
			end
		end

	end
	
end