From ec1cd3c408dd85b897cc858fdf81d4f6aace5f7b Mon Sep 17 00:00:00 2001
From: Michael Innerberger <michael.innerberger@asc.tuwien.ac.at>
Date: Thu, 13 Jul 2023 15:07:49 -0400
Subject: [PATCH] Clean up PCG solver
---
lib/solvers/PcgSolver.m | 22 ++++++++++------------
1 file changed, 10 insertions(+), 12 deletions(-)
diff --git a/lib/solvers/PcgSolver.m b/lib/solvers/PcgSolver.m
index 1e8cd88..1a278c8 100644
--- a/lib/solvers/PcgSolver.m
+++ b/lib/solvers/PcgSolver.m
@@ -40,11 +40,11 @@ classdef PcgSolver < IterativeSolver
setupRhs@IterativeSolver(obj, b, varargin{:});
% initialize residual & search direction
- obj.residual = b - obj.A*obj.x;
+ obj.residual = b - obj.A * obj.x;
obj.Cresidual = obj.C.apply(obj.residual);
obj.searchDirection = obj.Cresidual;
- obj.residualCNorm = sum(obj.residual.*obj.Cresidual, 1);
- obj.normb = sqrt(sum(b.^2, 1));
+ obj.residualCNorm = sqrt(dot(obj.residual, obj.Cresidual, 1));
+ obj.normb = vecnorm(b, 2, 1);
end
end
@@ -52,8 +52,8 @@ classdef PcgSolver < IterativeSolver
methods (Access=public)
function tf = isConverged(obj)
tf = ((obj.iterationCount >= obj.maxIter) ...
- | (sqrt(obj.residualCNorm) < obj.tol) ...
- | (sqrt(obj.residualCNorm)./obj.normb < obj.tol));
+ | (obj.residualCNorm < obj.tol) ...
+ | (obj.residualCNorm ./ obj.normb < obj.tol));
end
end
@@ -64,24 +64,22 @@ classdef PcgSolver < IterativeSolver
% update solution
AsearchDirection = obj.A * obj.searchDirection(:,idx);
- if sum(obj.searchDirection(:,idx).*AsearchDirection, 1) < eps
+ dAd = dot(obj.searchDirection(:,idx), AsearchDirection, 1);
+ if dAd < eps
alpha = 1;
else
- alpha = obj.residualCNorm(:,idx) ./ sum(obj.searchDirection(:,idx).*AsearchDirection, 1);
+ alpha = obj.residualCNorm(:,idx).^2 ./ dAd;
end
obj.x(:,idx) = obj.x(:,idx) + alpha .* obj.searchDirection(:,idx);
- % DEBUG:
- % disp(['alpha = ', num2str(alpha)])
-
% update residual
obj.residual(:,idx) = obj.residual(:,idx) - alpha .* AsearchDirection;
obj.Cresidual(:,idx) = obj.C.apply(obj.residual(:,idx));
residualCNormOld = obj.residualCNorm(:,idx);
- obj.residualCNorm(:,idx) = sum(obj.residual(:,idx).*obj.Cresidual(:,idx), 1);
+ obj.residualCNorm(:,idx) = sqrt(dot(obj.residual(:,idx), obj.Cresidual(:,idx), 1));
% update search direction
- beta = obj.residualCNorm(:,idx) ./ residualCNormOld;
+ beta = (obj.residualCNorm(:,idx) ./ residualCNormOld).^2;
obj.searchDirection(:,idx) = obj.Cresidual(:,idx) + beta .* obj.searchDirection(:,idx);
end
end
--
GitLab