summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
diff options
context:
space:
mode:
Diffstat (limited to 'Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py')
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/FISTA.py39
1 files changed, 15 insertions, 24 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
index 9e40c95..fa1d8d8 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
@@ -39,40 +39,31 @@ class FISTA(Algorithm):
'''initialisation can be done at creation time if all
proper variables are passed or later with set_up'''
super(FISTA, self).__init__()
- self.f = kwargs.get('f', None)
- self.g = kwargs.get('g', ZeroFunction())
- self.x_init = kwargs.get('x_init',None)
- self.invL = None
- self.t_old = 1
- if self.x_init is not None and \
- self.f is not None and self.g is not None:
- print ("FISTA set_up called from creator")
- self.set_up(self.x_init, self.f, self.g)
+ f = kwargs.get('f', None)
+ g = kwargs.get('g', ZeroFunction())
+ x_init = kwargs.get('x_init', None)
+
+ if x_init is not None and f is not None:
+ print(self.__class__.__name__ , "set_up called from creator")
+ self.set_up(x_init=x_init, f=f, g=g)
+
+ def set_up(self, x_init, f, g=ZeroFunction()):
-
- def set_up(self, x_init, f, g, opt=None, **kwargs):
-
- self.f = f
- self.g = g
-
- # algorithmic parameters
- if opt is None:
- opt = {'tol': 1e-4}
-
self.y = x_init.copy()
self.x_old = x_init.copy()
self.x = x_init.copy()
- self.u = x_init.copy()
+ self.u = x_init.copy()
+ self.f = f
+ self.g = g
self.invL = 1/f.L
-
- self.t_old = 1
+ self.t = 1
self.update_objective()
self.configured = True
def update(self):
-
+ self.t_old = self.t
self.f.gradient(self.y, out=self.u)
self.u.__imul__( -self.invL )
self.u.__iadd__( self.y )
@@ -87,7 +78,7 @@ class FISTA(Algorithm):
self.y.__iadd__( self.x )
self.x_old.fill(self.x)
- self.t_old = self.t
+
def update_objective(self):
self.loss.append( self.f(self.x) + self.g(self.x) )