Source code for csdl_alpha.src.operations.linalg.matmat

from csdl_alpha.src.graph.operation import Operation, set_properties 
import csdl_alpha.utils.testing_utils as csdl_tests
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import pytest
import scipy.sparse as sp
from csdl_alpha.utils.typing import VariableLike

@set_properties()
class MatMat(Operation):
    """Dense matrix-matrix product operation, ``C = A @ B``.

    The operation declares a single dense output of shape
    ``(A.shape[0], B.shape[1])``. No shape validation is performed here;
    operands are expected to have been checked before construction.
    """

    def __init__(self, A: Variable, B: Variable) -> None:
        """Register ``A`` and ``B`` as inputs and declare the output shape.

        Parameters
        ----------
        A : Variable
            Left operand; ``A.shape[0]`` gives the number of output rows.
        B : Variable
            Right operand; ``B.shape[1]`` gives the number of output columns.
        """
        super().__init__(A, B)
        self.name = 'matmat'
        self.set_dense_outputs(((A.shape[0], B.shape[1]),))

    def compute_inline(self, A, B):
        """Return the product ``A @ B`` of the given operand values."""
        return A @ B

    def compute_jax(self, A, x):
        """Return the product ``A @ x``; same expression as ``compute_inline``."""
        return A @ x

    def evaluate_vjp(self, cotangents, A, B, C):
        """Accumulate the vector-Jacobian products of ``C = A @ B``.

        For each input that ``cotangents.check`` accepts, the contribution
        is accumulated as follows:

        * ``B`` receives ``A.T() @ cotangents[C]``
        * ``A`` receives ``cotangents[C] @ B.T()``

        NOTE(review): ``cotangents.check(v)`` presumably reports whether a
        cotangent for ``v`` is required -- confirm against its definition.
        """
        if cotangents.check(B):
            cotangents.accumulate(B, A.T() @ cotangents[C])
        if cotangents.check(A):
            cotangents.accumulate(A, cotangents[C] @ B.T())

def matmat(A:VariableLike, B:VariableLike) -> Variable:
    """Compute the matrix-matrix product A @ B.

    The number of columns of A must be equal to the number of rows of B.
    If A is a scipy sparse matrix, the call is forwarded to the sparse
    matmat implementation before any other validation takes place.

    Parameters
    ----------
    A : Variable
        2D matrix
    B : Variable
        2D matrix (or 1D vector, in which case matvec is called instead)

    Returns
    -------
    C: Variable
        2D matrix

    Raises
    ------
    ValueError
        If A is not 2D, if ``A.shape[1] != B.shape[0]``, or if B is
        neither 1D nor 2D.

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> A = csdl.Variable(value = np.array([[1, 2], [3, 4], [5, 6]]))
    >>> B = csdl.Variable(value = np.array([[1, 2], [3, 4]]))
    >>> (A @ B).value
    array([[ 7., 10.],
           [15., 22.],
           [23., 34.]])
    """
    # Function-scope imports (presumably to avoid import cycles -- confirm).
    from csdl_alpha.src.operations.linalg.matvec import matvec
    from csdl_alpha.src.operations.sparse.sparse_matvec import matmat as sparse_matmat

    # Sparse left operands take their own code path.
    if isinstance(A, sp.spmatrix):
        return sparse_matmat(A, B)

    A = validate_and_variablize(A, raise_on_sparse = False)
    B = validate_and_variablize(B, raise_on_sparse = False)

    # Validation order: A is 2D -> inner dimensions agree -> B is 1D or 2D.
    lhs_rank = len(A.shape)
    if lhs_rank != 2:
        raise ValueError(f"Matrix A must be 2D, but has shape {A.shape}")

    if A.shape[1] != B.shape[0]:
        raise ValueError(f"Number of columns of A must be equal to the number of rows of B. {A.shape[1]} != {B.shape[0]}")

    rhs_rank = len(B.shape)
    if rhs_rank == 1:
        # A vector right-hand side is a matrix-vector product.
        return matvec(A, B)
    if rhs_rank == 2:
        return MatMat(A, B).finalize_and_return_outputs()

    raise ValueError(f"Matrix B must be 2D, but has shape {B.shape}")
class TestMatMat(csdl_tests.CSDLTest):
    """Tests for ``matmat`` and the ``@`` operator."""

    def test_functionality(self,):
        """Compare ``csdl.matmat`` and ``@`` against numpy/scipy products."""
        self.prep()

        import csdl_alpha as csdl
        import numpy as np
        import scipy.sparse as sp

        compare_values = []

        def ramp(shape):
            # Integer ramp 0..N-1 laid out in the requested shape.
            return np.arange(np.prod(shape)).reshape(shape)

        def record(lhs, rhs, lhs_val, rhs_val):
            # Function form first, then operator form; each is paired with
            # the reference product of the raw operand values.
            compare_values.append(
                csdl_tests.TestingPair(csdl.matmat(lhs, rhs), lhs_val @ rhs_val)
            )
            compare_values.append(
                csdl_tests.TestingPair(lhs @ rhs, lhs_val @ rhs_val)
            )

        # (3,4) @ (4,1)
        A_val = ramp((3, 4))
        B_val = ramp((4, 1))
        A = csdl.Variable(value = A_val)
        B = csdl.Variable(value = B_val)
        record(A, B, A_val, B_val)

        # (3,4) @ (4,3)
        B_val = ramp((4, 3))
        B = csdl.Variable(value = B_val)
        record(A, B, A_val, B_val)

        # (1,4) @ (4,3)
        A_val = ramp((1, 4))
        A = csdl.Variable(value = A_val)
        record(A, B, A_val, B_val)

        # one operand supplied as a plain numpy array
        B_val = ramp((4, 3))
        record(A, B_val, A_val, B_val)

        A_val = ramp((1, 4))
        record(A_val, B, A_val, B_val)

        # matvec: 1D right-hand side
        A_shape = (3, 4)
        A_val = ramp(A_shape)
        B_val = ramp((4,))
        A = csdl.Variable(value = A_val)
        B = csdl.Variable(value = B_val)
        record(A, B, A_val, B_val)

        # sparse matrix
        # NOTE: temporary
        A_data = np.array([1, 2, 3, 4, 5, 6])
        row = np.array([0, 0, 1, 1, 2, 2])
        col = np.array([0, 1, 0, 1, 0, 1])
        A = sp.csr_matrix((A_data, (row, col)), shape=A_shape)
        record(A, B, A, B_val)

        self.run_tests(compare_values = compare_values, verify_derivatives=True)

    def test_errors(self):
        """Each invalid operand pair must raise ``ValueError`` in both forms."""
        self.prep()

        import csdl_alpha as csdl
        import numpy as np

        def operand(shape, as_variable):
            # All-ones operand, optionally wrapped in a csdl Variable.
            ones = np.ones(shape)
            return csdl.Variable(value = ones) if as_variable else ones

        # (A shape, A is Variable, B shape, B is Variable)
        bad_cases = [
            ((2, 2, 3), True, (2, 1), True),
            ((2, 3), True, (2, 1), True),
            ((2, 3), True, (2,), True),
            ((2, 3), True, (3, 4, 4), True),
            # check with numpy arrays
            ((2, 2, 3), True, (2, 1), False),
            ((2, 3), True, (2, 1), False),
            ((2, 3), False, (2,), True),
            ((2, 3), True, (3, 4, 4), False),
        ]

        for a_shape, a_is_var, b_shape, b_is_var in bad_cases:
            A = operand(a_shape, a_is_var)
            B = operand(b_shape, b_is_var)
            with pytest.raises(ValueError):
                _ = csdl.matmat(A, B)
            with pytest.raises(ValueError):
                _ = A @ B

    def test_docsstrings(self):
        """Run the examples embedded in the ``matmat`` docstring."""
        self.docstest(matmat)


if __name__ == '__main__':
    t = TestMatMat()
    for run_case in (t.test_functionality, t.test_docsstrings, t.test_errors):
        run_case()