Source code for csdl_alpha.src.operations.linalg.blockmat

from csdl_alpha.src.graph.operation import Operation, set_properties
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import csdl_alpha.utils.testing_utils as csdl_tests

import numpy as np
import pytest

class BlockMatrix(Operation):
    '''
    Assemble a block matrix from a list of matrices or a list of lists.
    '''

    def __init__(self, *args, num_row_blocks = None, shape=None):
        super().__init__(*args)
        self.name = 'block_matrix'
        self.num_row_blocks:list[int] = num_row_blocks
        out_shapes = (shape,)
        self.set_dense_outputs(out_shapes)

        # Build indices for each block
        self.indices = []
        if self.num_row_blocks is None:
            left_index = 0
            for arg in args:
                right_index = left_index + arg.shape[1]
                self.indices.append((left_index, right_index))
                left_index = right_index
        else:
            current_arg_index = 0
            lower_index = 0

            for cur_block_row_ind in range(len(self.num_row_blocks)):
                arg = args[current_arg_index]
                left_index = 0
                upper_index = lower_index + arg.shape[0]

                for cur_block_col_ind in range(self.num_row_blocks[cur_block_row_ind]):
                    # Current variable
                    arg = args[current_arg_index]
                    right_index = left_index + arg.shape[1]

                    # Save indices                  
                    self.indices.append((left_index, right_index, lower_index, upper_index))
                    
                    # update column indices
                    left_index = right_index
                    current_arg_index += 1 

                # update row indices
                lower_index = upper_index
            
            # print(self.indices)
            # exit()


    def compute_inline(self, *args):
        if self.num_row_blocks is None:
            return np.block([x for x in args])
        else:
            l = self.num_row_blocks
            row_idx = np.cumsum([0] + l)
            return np.block([list(args[row_idx[i]:row_idx[i+1]]) for i in range(len(l))])

    def compute_jax(self, *args):
        import jax.numpy as jnp
        if self.num_row_blocks is None:
            return jnp.block([x for x in args])
        else:
            l = self.num_row_blocks
            row_idx = np.cumsum([0] + l)
            return jnp.block([list(args[row_idx[i]:row_idx[i+1]]) for i in range(len(l))])

    def evaluate_vjp(self, cotangents, *inputs_and_block):
        inputs = inputs_and_block[:-1]
        block = inputs_and_block[-1]

        block_out = cotangents[block]
        for i, input in enumerate(inputs):
            if cotangents.check(input):
                if self.num_row_blocks is None:
                    left = self.indices[i][0]
                    right = self.indices[i][1]
                    cotangents.accumulate(input, block_out[:, left:right])
                else:
                    left = self.indices[i][0]
                    right = self.indices[i][1]
                    lower = self.indices[i][2]
                    upper = self.indices[i][3]
                    cotangents.accumulate(input, block_out[lower:upper, left:right])

[docs]def blockmat(l)->Variable: """ Assemble a block matrix from a list or list of lists of matrices. Parameters ---------- l : list or list of lists of Variable or np.ndarray objects List or a list of lists of matrices to assemble into a block matrix. Returns ------- Variable Block matrix assembled from the input list. Examples -------- >>> recorder = csdl.Recorder(inline = True) >>> recorder.start() >>> x_val = 3.0*np.ones((2,3)) >>> z_val = np.ones((1,5)) >>> x = csdl.Variable(name = 'x', value = x_val) >>> z = csdl.Variable(name = 'z', value = z_val) Create a block row matrix >>> b1 = csdl.blockmat([x, np.zeros((2,2))]) >>> b1.value array([[3., 3., 3., 0., 0.], [3., 3., 3., 0., 0.]]) Create a block matrix with block rows and columns >>> b2 = csdl.blockmat([[x, np.zeros((2,2))], [z]]) >>> b2.value array([[3., 3., 3., 0., 0.], [3., 3., 3., 0., 0.], [1., 1., 1., 1., 1.]]) """ list_in_list = any(isinstance(x, list) for x in l) all_are_list = all(isinstance(x, list) for x in l) if list_in_list != all_are_list: raise ValueError('List depths are mismatched.') if list_in_list: num_rows = sum([x[0].shape[0] for x in l]) num_cols = sum([x.shape[1] for x in l[0]]) for i, elem_l in enumerate(l): num_rows_current = elem_l[0].shape[0] for x in elem_l: if x.shape[0] != num_rows_current: raise ValueError(f'Number of columns are not the same for the blocks in the {i}th row. {x.shape[0]} given, {num_rows_current} expected') current_num_cols = sum([x.shape[1] for x in elem_l]) if current_num_cols != num_cols: raise ValueError(f'Total number of columns are not the same for the 0th and {i}th block rows. {current_num_cols} given, {num_cols} expected') args = [validate_and_variablize(y) for x in l for y in x ] num_row_blocks = [len(x) for x in l] else: num_rows = l[0].shape[0] num_cols = sum([x.shape[1] for x in l]) for x in l: if x.shape[0] != num_rows: raise ValueError(f'Number of rows are not the same for all the blocks. {x.shape[0]} given, {num_rows} expected') args = [validate_and_variablize(x) for x in l] num_row_blocks = None op = BlockMatrix(*args, num_row_blocks=num_row_blocks, shape=(num_rows, num_cols)) return op.finalize_and_return_outputs()
class TestBlockMat(csdl_tests.CSDLTest): def test_functionality(self,): self.prep() import csdl_alpha as csdl import numpy as np recorder = csdl.build_new_recorder(inline = True) recorder.start() x_val = np.arange(6).reshape(2,3)+2.0 y_val = np.arange(10).reshape(2,5) z_val = np.arange(32).reshape(4,8)*0.5 w_val = np.ones((2,1)) x = csdl.Variable(name = 'x', value = x_val) y = csdl.Variable(name = 'y', value = y_val) z = csdl.Variable(name = 'z', value = z_val) w = csdl.Variable(name = 'w', value = w_val) compare_values = [] # create a SINGLE ROW block row matrix b1 = csdl.blockmat([x,y]) t1 = np.block([x_val, y_val]) compare_values += [csdl_tests.TestingPair(b1, t1, tag = 'b1')] b1 = csdl.blockmat([x,y,w]) t1 = np.block([x_val, y_val, w_val]) compare_values += [csdl_tests.TestingPair(b1, t1, tag = 'b2')] b1 = csdl.blockmat([x]) t1 = np.block([x_val]) compare_values += [csdl_tests.TestingPair(b1, t1, tag = 'b3')] # Create a block matrix with block rows and columns b2 = csdl.blockmat([[x,y], [z]]) t2 = np.block([[x_val, y_val], [z_val]]) compare_values += [csdl_tests.TestingPair(b2, t2, tag = 'b4')] b2 = csdl.blockmat([[x,y], [z], [y,x]]) t2 = np.block([[x_val, y_val], [z_val], [y_val, x_val]]) compare_values += [csdl_tests.TestingPair(b2, t2, tag = 'b4')] b2 = csdl.blockmat([[x,x], [x, x]]) t2 = np.block([[x_val, x_val], [x_val, x_val]]) compare_values += [csdl_tests.TestingPair(b2, t2, tag = 'b5')] b2 = csdl.blockmat([[x.T()], [y.T()], [w.T()]]) t2 = np.block([[x_val.T], [y_val.T], [w_val.T]]) compare_values += [csdl_tests.TestingPair(b2, t2, tag = 'b5')] self.run_tests( compare_values = compare_values, verify_derivatives=True ) def test_errors(self,): self.prep() import csdl_alpha as csdl import numpy as np recorder = csdl.build_new_recorder(inline = True) recorder.start() x_val = np.arange(6).reshape(2,3)+2.0 y_val = np.arange(10).reshape(2,5) z_val = np.arange(32).reshape(4,8)*0.5 w_val = np.ones((2,1)) x = csdl.Variable(name = 'x', value = x_val) y = csdl.Variable(name = 'y', value = y_val) z = csdl.Variable(name = 'z', value = z_val) w = csdl.Variable(name = 'w', value = w_val) with pytest.raises(ValueError): b1 = csdl.blockmat([x, y.T()]) with pytest.raises(ValueError): b1 = csdl.blockmat([[z], [x, y.T()]]) with pytest.raises(ValueError): b1 = csdl.blockmat([[z.T()], [x, y]]) with pytest.raises(ValueError): b1 = csdl.blockmat([[z], [x, y, x]]) def test_example(self,): self.docstest(blockmat) if __name__ == '__main__': test = TestBlockMat() test.test_functionality() test.test_errors() # test.test_example()