diff --git a/lib/solvers/PcgSolver.m b/lib/solvers/PcgSolver.m index 1e8cd88c10f3074eaea6cf1c8eaad9682989a789..1a278c82dda3a969218f6e1d6f730baa11c4bdcc 100644 --- a/lib/solvers/PcgSolver.m +++ b/lib/solvers/PcgSolver.m @@ -40,11 +40,11 @@ classdef PcgSolver < IterativeSolver setupRhs@IterativeSolver(obj, b, varargin{:}); % initialize residual & search direction - obj.residual = b - obj.A*obj.x; + obj.residual = b - obj.A * obj.x; obj.Cresidual = obj.C.apply(obj.residual); obj.searchDirection = obj.Cresidual; - obj.residualCNorm = sum(obj.residual.*obj.Cresidual, 1); - obj.normb = sqrt(sum(b.^2, 1)); + obj.residualCNorm = sqrt(dot(obj.residual, obj.Cresidual, 1)); + obj.normb = vecnorm(b, 2, 1); end end @@ -52,8 +52,8 @@ classdef PcgSolver < IterativeSolver methods (Access=public) function tf = isConverged(obj) tf = ((obj.iterationCount >= obj.maxIter) ... - | (sqrt(obj.residualCNorm) < obj.tol) ... - | (sqrt(obj.residualCNorm)./obj.normb < obj.tol)); + | (obj.residualCNorm < obj.tol) ... + | (obj.residualCNorm ./ obj.normb < obj.tol)); end end @@ -64,24 +64,22 @@ classdef PcgSolver < IterativeSolver % update solution AsearchDirection = obj.A * obj.searchDirection(:,idx); - if sum(obj.searchDirection(:,idx).*AsearchDirection, 1) < eps + dAd = dot(obj.searchDirection(:,idx), AsearchDirection, 1); + if dAd < eps alpha = 1; else - alpha = obj.residualCNorm(:,idx) ./ sum(obj.searchDirection(:,idx).*AsearchDirection, 1); + alpha = obj.residualCNorm(:,idx).^2 ./ dAd; end obj.x(:,idx) = obj.x(:,idx) + alpha .* obj.searchDirection(:,idx); - % DEBUG: - % disp(['alpha = ', num2str(alpha)]) - % update residual obj.residual(:,idx) = obj.residual(:,idx) - alpha .* AsearchDirection; obj.Cresidual(:,idx) = obj.C.apply(obj.residual(:,idx)); residualCNormOld = obj.residualCNorm(:,idx); - obj.residualCNorm(:,idx) = sum(obj.residual(:,idx).*obj.Cresidual(:,idx), 1); + obj.residualCNorm(:,idx) = sqrt(dot(obj.residual(:,idx), obj.Cresidual(:,idx), 1)); % update search direction - beta = obj.residualCNorm(:,idx) ./ residualCNormOld; + beta = (obj.residualCNorm(:,idx) ./ residualCNormOld).^2; obj.searchDirection(:,idx) = obj.Cresidual(:,idx) + beta .* obj.searchDirection(:,idx); end end