From 51e677d2b8cd631c70d47cd4965150969d878fb2 Mon Sep 17 00:00:00 2001 From: Carl-Martin Pfeiler <carl-martin.pfeiler@asc.tuwien.ac.at> Date: Mon, 19 Aug 2019 14:44:55 +0200 Subject: [PATCH] Count FP iterations for MPS --- _setup/outputOptions.py | 4 ++++ .../_details/interfaces/computeInterface.py | 1 + .../_details/interfaces/parameterInterface.py | 10 ++++++++++ .../_details/interfaces/writeInterface.py | 16 ++++++++++++++++ integrators/_integrator.py | 10 ++++++++++ integrators/integrator.py | 7 +++++++ integrators/mps.py | 8 ++++++-- 7 files changed, 54 insertions(+), 2 deletions(-) diff --git a/_setup/outputOptions.py b/_setup/outputOptions.py index 2847b6b..7831e99 100644 --- a/_setup/outputOptions.py +++ b/_setup/outputOptions.py @@ -28,6 +28,7 @@ class OutputOptions: self.filenameUnprojectedExchangeEnergy = "unprojectedExchangeEnergy.dat" self.filenamePlotEnergies = "plotEnergies.eps" self.filenameIterationCount = "iterations.dat" + self.filenameFixedPointIterationCount = "fp_iterations.dat" self.filenameAdaptiveAxis = "adaptiveAxis.dat" self.columnWidth = 25 @@ -73,6 +74,9 @@ class OutputOptions: self.headerIterationCount = "columns are: t,".ljust(self.columnWidth-2) + " " \ + "gmres iter count,".ljust(self.columnWidth) + self.headerFixedPointIterationCount = "columns are: t,".ljust(self.columnWidth-2) + " " \ + + "fixed-point iter count,".ljust(self.columnWidth) + self.headerAdaptiveAxis = "columns are: t,".ljust(self.columnWidth-2) + " " \ + "adaptiveAxis,".ljust(self.columnWidth) + " " \ + "adaptiveGamma,".ljust(self.columnWidth) + " " \ diff --git a/integrators/_details/interfaces/computeInterface.py b/integrators/_details/interfaces/computeInterface.py index 11b27a1..6abd485 100644 --- a/integrators/_details/interfaces/computeInterface.py +++ b/integrators/_details/interfaces/computeInterface.py @@ -142,6 +142,7 @@ class ComputeInterface(_interface._Interface): self._ExistsMagnetization() self._ResetIterationCounter() + self._ResetFixedPointIterationCounter() #------------------------------------------------------------------------------# diff --git a/integrators/_details/interfaces/parameterInterface.py b/integrators/_details/interfaces/parameterInterface.py index 1638b05..f80f603 100644 --- a/integrators/_details/interfaces/parameterInterface.py +++ b/integrators/_details/interfaces/parameterInterface.py @@ -207,6 +207,16 @@ class ParameterInterface(_interface._Interface): self._callback = None +#------------------------------------------------------------------------------# + + + def CountFixedPointIterations(self, value=True): + if value: + self._fixedPointCallback = self._fixedPointItCounter + else: + self._fixedPointCallback = None + + #------------------------------------------------------------------------------# diff --git a/integrators/_details/interfaces/writeInterface.py b/integrators/_details/interfaces/writeInterface.py index 830af11..f848248 100644 --- a/integrators/_details/interfaces/writeInterface.py +++ b/integrators/_details/interfaces/writeInterface.py @@ -103,6 +103,7 @@ class WriteInterface(_interface._Interface): self._WriteMagnetization() self._WriteMaxwell() self._WriteIterations() + self._WriteFixedPointIterations() self._WriteAdaptiveAxis() @@ -298,6 +299,21 @@ class WriteInterface(_interface._Interface): , self.Time(), itc, nrwidth=self.output.columnWidth, header=headerItc) +#------------------------------------------------------------------------------# + + + def _WriteFixedPointIterations(self): + from commics._tools.output import LiveSaveValues + from commics.integrators._details import IterationCounter + if not isinstance(self._fixedPointCallback, IterationCounter): return + + ## iterations + headerFpItc = self.output.headerFixedPointIterationCount if self._counter == 0 else None + fpitc = [self._fixedPointCallback.GetCount()] + LiveSaveValues(self.output.foldername+"/"+self.output.filenameFixedPointIterationCount \ + , self.Time(), fpitc, nrwidth=self.output.columnWidth, header=headerFpItc) + + #------------------------------------------------------------------------------# diff --git a/integrators/_integrator.py b/integrators/_integrator.py index 6b1dd0e..6e93825 100644 --- a/integrators/_integrator.py +++ b/integrators/_integrator.py @@ -69,7 +69,9 @@ class _Integrator( \ self._A_instat = None self._callback = None + self._fixedPointCallback = None self._itCounter = _details.IterationCounter() + self._fixedPointItCounter = _details.IterationCounter() self._preconditioner = None self._guess = None self._solution = None @@ -139,7 +141,15 @@ class _Integrator( \ #------------------------------------------------------------------------------# + + def _ResetFixedPointIterationCounter(self): + if isinstance(self._fixedPointItCounter, _details.IterationCounter): + self._fixedPointItCounter.Reset() + +#------------------------------------------------------------------------------# + + def _SetNumThreads(self, numthreads): import os os.environ['OMP_NUM_THREADS'] = str(numthreads) diff --git a/integrators/integrator.py b/integrators/integrator.py index 643601d..90fa13c 100644 --- a/integrators/integrator.py +++ b/integrators/integrator.py @@ -265,6 +265,13 @@ class Integrator: self._integrator.CountIterations(value) +#------------------------------------------------------------------------------# + + + def CountFixedPointIterations(self, value=True): + self._integrator.CountFixedPointIterations(value) + + #------------------------------------------------------------------------------# diff --git a/integrators/mps.py b/integrators/mps.py index b404849..b3a62cc 100644 --- a/integrators/mps.py +++ b/integrators/mps.py @@ -316,6 +316,7 @@ class MPS(_integrator._Integrator, \ import scipy.sparse.linalg from ngsolve import Integrate, GridFunction import numpy as np + from commics.integrators._details import IterationCounter self._Mag_eta.SetFromOther(self._Mag) @@ -336,11 +337,14 @@ class MPS(_integrator._Integrator, \ self._Mag_eta.m.vec.FV().NumPy()[:], succ \ = scipy.sparse.linalg.gmres(LHS, RHS , x0=None \ - , tol=self._solvetol, maxiter=4000, M=None) + , tol=self._solvetol, maxiter=4000, M=None, callback=self._callback) diff_gf.vec.FV().NumPy()[:] = self._Mag_eta.m.vec.FV().NumPy()-oldIterate itError = sqrt( Integrate(diff_cf*diff_cf, self._GetMesh()) ) - + + if isinstance(self._fixedPointItCounter, IterationCounter): + self._fixedPointCallback() + # TODO consider error in heff like in mpsPaper if (itError <= self._itertol): break -- GitLab