# Source code for csdl_alpha.src.operations.linalg.linear_solve

from csdl_alpha.src.graph.operation import Operation, set_properties 
import csdl_alpha.utils.testing_utils as csdl_tests
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.src.operations.linalg.linear_solvers import DirectSolver
from csdl_alpha.src.operations.linalg.linear_solvers.linear_solver import LinearSolver
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
from csdl_alpha.utils.typing import VariableLike
from csdl_alpha.src.operations.linalg.utils import process_matA_vecb
import pytest
import numpy as np

@set_properties()
class LinearSolve(Operation):
    """Graph operation computing ``x`` such that ``A @ x = b``.

    The single output has the same shape as ``b``.
    """

    def __init__(self, A:Variable, b:Variable, solver:LinearSolver = DirectSolver) -> None:
        super().__init__(A,b)
        self.name = 'linear_solve'
        # The solution x has exactly the shape of the right-hand side b.
        self.set_dense_outputs((b.shape,))
        # NOTE(review): `solver` is accepted for interface compatibility but is
        # neither stored nor used; compute_inline/compute_jax below always use a
        # dense direct solve.

    def compute_inline(self, A, b):
        """Evaluate the solve eagerly with NumPy's dense solver."""
        return np.linalg.solve(A, b)

    def compute_jax(self, A, b):
        """Evaluate the solve with JAX's dense solver (imported lazily)."""
        import jax.numpy as jnp
        return jnp.linalg.solve(A, b)

    def evaluate_vjp(self, cotangents, A, b, x):
        """Accumulate reverse-mode cotangents of ``A`` and ``b``.

        For ``x = A^{-1} b`` with output cotangent ``xbar``, the adjoint
        ``lam`` solves ``A^T lam = xbar``. Then ``bbar += lam`` and
        ``Abar -= lam x^T``.
        """
        import csdl_alpha as csdl

        needs_b = cotangents.check(b)
        needs_A = cotangents.check(A)
        # Skip building the adjoint solve when no input needs a cotangent.
        if not (needs_A or needs_b):
            return

        # Adjoint system: A^T lam = xbar.
        solved_system = csdl.solve_linear(A.T(), cotangents[x])
        if needs_b:
            cotangents.accumulate(b, solved_system)
        if needs_A:
            # outer(x, lam)^T == lam x^T.
            # NOTE(review): this equals lam @ x^T only when x is a single
            # column; confirm process_matA_vecb restricts b accordingly.
            vjp = -csdl.outer(x,solved_system).T().reshape(A.shape)
            cotangents.accumulate(A, vjp)

def solve_linear(
        A:VariableLike,
        b:VariableLike,
        solver:'LinearSolver | None' = None,
    )->Variable:
    """Solve a linear system of equations Ax = b for x.

    Parameters
    ----------
    A : VariableLike
        2D square matrix
    b : VariableLike
        1D or 2D vector
    solver : LinearSolver, optional
        Linear solver instance. When omitted, a fresh ``DirectSolver()`` is
        created for this call (rather than sharing one default instance).

    Returns
    -------
    x: Variable
        1D or 2D vector; 1D when ``b`` is 1D.

    Raises
    ------
    ValueError
        If ``A`` is not square, or if ``process_matA_vecb`` rejects the
        shapes of ``A`` and ``b``.
    TypeError
        If ``solver`` is not a ``LinearSolver``.

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> A = csdl.Variable(value = np.array([[1, 2], [3, 4]]))
    >>> b = csdl.Variable(value = np.array([5, 6]))
    >>> csdl.solve_linear(A, b).value
    array([-4. ,  4.5])
    >>> recorder.stop()

    Specify different solvers:

    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> A = csdl.Variable(value = np.array([[1, 2], [3, 4]]))
    >>> b = csdl.Variable(value = np.array([5, 6]))
    >>> csdl.solve_linear(A, b, solver = csdl.linear_solvers.DirectSolver()).value
    array([-4. ,  4.5])
    >>> recorder.stop()
    """
    # Avoid a shared, import-time default solver instance.
    if solver is None:
        solver = DirectSolver()

    A = validate_and_variablize(A, raise_on_sparse=False)
    b = validate_and_variablize(b)
    A_mat, b_vec = process_matA_vecb(A, b)

    if A_mat.shape[0] != A_mat.shape[1]:
        raise ValueError(f"Matrix A must be square, but has shape {A_mat.shape}")
    if not isinstance(solver, LinearSolver):
        raise TypeError(f"Solver must be a LinearSolver. Got {type(solver)}.")

    output = LinearSolve(A_mat, b_vec, solver).finalize_and_return_outputs()

    # Give the caller back a 1D solution for a 1D right-hand side. Otherwise,
    # return the operation's output as-is (never an implicit None).
    if len(b.shape) == 1:
        return output.reshape((output.size,))
    return output
class TestLinear(csdl_tests.CSDLTest):
    def test_functionality(self,):
        """Compare solve_linear against numpy for Variable/array inputs and solvers."""
        self.prep()
        import csdl_alpha as csdl

        n = 4
        # Tridiagonal test matrix, used in place of a dense one whose
        # conditioning was questionable.
        main_diag = np.arange(n)+1
        A_val = np.diag(main_diag) + np.diag(main_diag[:-1]+1, 1) + np.diag(main_diag[:-1]+2, -1)
        b_val = 2*np.arange(n)
        A = csdl.Variable(value = A_val)
        b = csdl.Variable(value = b_val)

        compare_values = []
        # Every combination of Variable / numpy-array inputs.
        x = csdl.solve_linear(A,b)
        compare_values += [csdl_tests.TestingPair(x, np.linalg.solve(A_val, b_val))]

        x = csdl.solve_linear(A_val,b)
        compare_values += [csdl_tests.TestingPair(x, np.linalg.solve(A_val, b_val))]

        x = csdl.solve_linear(A,b_val)
        compare_values += [csdl_tests.TestingPair(x, np.linalg.solve(A_val, b_val))]

        b_shape = (4,)
        b_val = np.arange(np.prod(b_shape)).reshape(b_shape)
        b = csdl.Variable(value = b_val)
        x = csdl.solve_linear(A,b)
        compare_values += [csdl_tests.TestingPair(x, np.linalg.solve(A_val, b_val).flatten())]

        x = csdl.solve_linear(A,b_val)
        compare_values += [csdl_tests.TestingPair(x, np.linalg.solve(A_val, b_val).flatten())]

        x = csdl.solve_linear(A_val,b)
        compare_values += [csdl_tests.TestingPair(x, np.linalg.solve(A_val, b_val).flatten())]

        # Explicit solver choices.
        x = csdl.solve_linear(A,b, solver = csdl.linear_solvers.ScipyKrylovSolver())
        compare_values += [csdl_tests.TestingPair(x, np.linalg.solve(A_val, b_val))]

        x = csdl.solve_linear(A_val,b, solver = csdl.linear_solvers.ScipyKrylovSolver())
        compare_values += [csdl_tests.TestingPair(x, np.linalg.solve(A_val, b_val))]

        x = csdl.solve_linear(A,b_val, solver = csdl.linear_solvers.DirectSolver())
        compare_values += [csdl_tests.TestingPair(x, np.linalg.solve(A_val, b_val))]

        self.run_tests(compare_values = compare_values, verify_derivatives=True)

    def test_errors(self,):
        """Invalid shapes raise ValueError; a non-LinearSolver solver raises TypeError."""
        self.prep()
        import csdl_alpha as csdl

        # A is 3D.
        A = csdl.Variable(value = np.ones((2,2,3)))
        b = csdl.Variable(value = np.ones((2,2)))
        with pytest.raises(ValueError):
            csdl.solve_linear(A,b)

        # A is not square.
        A = csdl.Variable(value = np.ones((2,3)))
        b = csdl.Variable(value = np.ones((3,)))
        with pytest.raises(ValueError):
            csdl.solve_linear(A,b)

        # Length of b does not match A.
        A = csdl.Variable(value = np.ones((2,2)))
        b = csdl.Variable(value = np.ones((3,)))
        with pytest.raises(ValueError):
            csdl.solve_linear(A,b)

        # Rows of 2D b do not match A.
        A = csdl.Variable(value = np.ones((2,2)))
        b = csdl.Variable(value = np.ones((1,1)))
        with pytest.raises(ValueError):
            csdl.solve_linear(A,b)

        # Non-square A with a 2D b.
        A = csdl.Variable(value = np.ones((2,3)))
        b = csdl.Variable(value = np.ones((3,2)))
        with pytest.raises(ValueError):
            csdl.solve_linear(A,b)

        # Solver argument is not a LinearSolver.
        A = csdl.Variable(value = np.ones((2,2)))
        b = csdl.Variable(value = np.ones((2,)))
        with pytest.raises(TypeError):
            csdl.solve_linear(A,b, 's')

    def test_docstrings(self):
        self.docstest(solve_linear)


if __name__ == '__main__':
    t = TestLinear()
    t.test_functionality()
    t.test_errors()
    t.test_docstrings()