# Source code for csdl_alpha.src.operations.tensor.expand

from csdl_alpha.src.graph.operation import Operation, set_properties
from csdl_alpha.src.operations.operation_subclasses import ComposedOperation
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize
import csdl_alpha.utils.testing_utils as csdl_tests
from csdl_alpha.utils.typing import VariableLike
import warnings

import numpy as np

@set_properties(linear=True)
class ScalarExpand(Operation):
    '''
    Broadcasts a single-element input to a dense output of shape `out_shape`.

    Intended for single-element inputs (`expand` flattens the scalar before
    building this operation). The operation is registered as linear, and every
    output entry is a copy of the input value. The reverse-mode derivative with
    respect to the input is therefore the sum of all output cotangent entries.
    '''

    def __init__(self, x, out_shape):
        super().__init__(x)
        self.name = 'scalar_expand'
        self.out_shape = out_shape
        # A single dense output with the requested shape.
        self.set_dense_outputs((out_shape,))

    def compute_inline(self, x):
        '''Evaluate the broadcast eagerly with NumPy.'''
        target_shape = self.out_shape
        return np.broadcast_to(x, target_shape)

    def compute_jax(self, x):
        '''Evaluate the broadcast with JAX.'''
        import jax.numpy as jnp
        target_shape = self.out_shape
        return jnp.broadcast_to(x, target_shape)

    def evaluate_vjp(self, cotangents, x, y):
        '''Accumulate sum(dy) into the cotangent of `x` when `x` needs one.'''
        if not cotangents.check(x):
            return
        import csdl_alpha as csdl
        total = csdl.sum(cotangents[y])
        cotangents.accumulate(x, total)
        
@set_properties(linear=True)
class TensorExpand(Operation):
    '''
    Expands the input tensor to the specified `out_shape` by
    repeating the tensor along certain axes determined from the `action` argument.

    The expansion is evaluated as ``einsum(einsum_str, x, ones(ones_shape))``.
    Here ``einsum_str`` has the form ``'<in>,<new>-><out>'``, and ``ones_shape``
    holds the sizes of the new (repeated) axes in the order they appear in
    ``<new>``.

    Parameters
    ----------
    x : Variable
        Input tensor.
    out_shape : tuple of int
        Shape of the expanded output.
    ones_shape : tuple of int
        Sizes of the repeated axes (the shape of the second einsum operand).
    einsum_str : str
        Two-operand einsum subscript string, e.g. ``'i,j->ij'``.
    '''
    def __init__(self, x, out_shape, ones_shape, einsum_str):
        super().__init__(x)
        self.name = 'tensor_expand'
        self.out_shape = out_shape
        self.ones_shape = ones_shape
        self.einsum_str = einsum_str

        out_shapes = (out_shape,)
        self.set_dense_outputs(out_shapes)

    def compute_inline(self, x):
        # NOTE : if csdl.einsum is implemented using csdl.[sum, expand, reorder_axes, mult] later,
        # then the line below should never call csdl.einsum since it just creates recursive calls.
        return np.einsum(self.einsum_str, x, np.ones(self.ones_shape))

    def compute_jax(self, x):
        import jax.numpy as jnp
        # NOTE : if csdl.einsum is implemented using csdl.[sum, expand, reorder_axes, mult] later,
        # then the line below should never call csdl.einsum since it just creates recursive calls.
        return jnp.einsum(self.einsum_str, x, jnp.ones(self.ones_shape))

    def evaluate_vjp(self, cotangents, x, y):
        if cotangents.check(x):
            import csdl_alpha as csdl
            operands_str, out_str = self.einsum_str.split('->')
            in_str = operands_str.split(',')[0]

            # The expansion repeats x along the new axes. Its adjoint therefore
            # sums the output cotangent over those axes, and the single-operand
            # einsum returns the result in x's axis order.
            sum_str = out_str + '->' + in_str
            vjp = csdl.einsum(cotangents[y], action=sum_str)
            cotangents.accumulate(x, vjp)

    def evaluate_jacobian(self, x):
        # Imported locally, as in evaluate_vjp; `csdl` is not a module-level name here.
        import csdl_alpha as csdl

        # Each output entry (flat index = row) is a copy of exactly one input entry.
        # Pushing the input's flat indices through the same expansion gives, for
        # every output entry, the column of the input entry it copies. Integer
        # ones keep the einsum result integral so it is usable as an index array.
        # NOTE(review): assumes `x` exposes `.size`/`.shape` like an ndarray — confirm.
        rows = np.arange(np.prod(self.out_shape)).reshape(self.out_shape)
        cols = np.einsum(
            self.einsum_str,
            np.arange(x.size).reshape(x.shape),
            np.ones(self.ones_shape, dtype=int),
        )

        # NOTE(review): assumes csdl.Constant accepts sparse rows/cols/val — confirm
        # against the csdl_alpha API.
        return csdl.Constant(rows=rows.flatten(), cols=cols.flatten(), val=1.)
        
def expand(x, out_shape, action=None):
    '''
    Expands the input scalar/tensor to the specified `out_shape` by repeating
    the tensor along certain axes determined from the `action` argument.
    For example, `action='i->ijk'` will expand a 1D tensor to a 3D tensor by
    repeating the input tensor along two new axes.
    The `action` argument is optional if the input is a scalar, since the
    scalar is simply broadcast to the specified `out_shape`.

    Parameters
    ----------
    x : VariableLike
        Input scalar/tensor that needs to be expanded.
    out_shape : tuple of int
        Desired shape of the expanded output tensor.
    action : str, default=None
        Specifies the action to be taken when expanding the tensor, e.g.,
        `'i->ij'` expands a vector to a matrix by repeating the input vector
        rowwise. Subscripts must be ASCII letters, separated by a single "->".

    Returns
    -------
    Variable
        Expanded output tensor as per the specified `out_shape` and `action`.

    Raises
    ------
    ValueError
        If `out_shape` is not a tuple, if a non-scalar input is given without
        an `action`, or if the `action` string is malformed or inconsistent
        with the input/output shapes.
    TypeError
        If `action` is given for a non-scalar input but is not a string.

    Warns
    -----
    UserWarning
        If `action` is given for a scalar input (it is ignored).

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> x = csdl.Variable(value = 3.0)
    >>> y1 = csdl.expand(x, out_shape=(2,3))
    >>> y1.value
    array([[3., 3., 3.],
           [3., 3., 3.]])
    >>> x = csdl.Variable(value = np.array([1.0, 2.0, 3.0]))
    >>> y2 = csdl.expand(x, out_shape=(2,3), action='i->ji')
    >>> y2.value
    array([[1., 2., 3.],
           [1., 2., 3.]])
    >>> y3 = csdl.expand(x, out_shape=(3,2), action='i->ij')
    >>> y3.value
    array([[1., 1.],
           [2., 2.],
           [3., 3.]])
    >>> y4 = csdl.expand(x, out_shape=(4,3,2), action='i->lij')
    >>> y4.value
    array([[[1., 1.],
            [2., 2.],
            [3., 3.]],
    <BLANKLINE>
           [[1., 1.],
            [2., 2.],
            [3., 3.]],
    <BLANKLINE>
           [[1., 1.],
            [2., 2.],
            [3., 3.]],
    <BLANKLINE>
           [[1., 1.],
            [2., 2.],
            [3., 3.]]])
    '''
    x = variablize(x)

    if not isinstance(out_shape, tuple):
        raise ValueError('"out_shape" must be a tuple.')

    if x.size != 1:
        if action is None:
            raise ValueError('Cannot expand a tensor without "action" specified.')
        if not isinstance(action, str):
            raise TypeError('"action" must be a string.')
        # Exactly one separator is allowed. Several '->' would otherwise
        # surface as an opaque tuple-unpacking error.
        if action.count('->') != 1:
            raise ValueError('Invalid action string. Use "->" to separate the input and output subscripts.')

        in_str, out_str = action.split('->')
        in_shape = x.shape

        if len(in_str) != len(in_shape):
            raise ValueError(f'Input tensor shape {in_shape} does not match the input string \'{in_str}\' in the action.')
        if len(out_str) != len(out_shape):
            raise ValueError('Output tensor shape does not match the output string in the action.')
        # einsum (NumPy and JAX) only accepts ASCII letters as subscripts;
        # reject anything else here instead of failing later during evaluation.
        if not all(char.isascii() and char.isalpha() for char in in_str + out_str):
            raise ValueError('Subscripts in the action string must be ASCII letters.')
        if not all(in_str.count(char) == 1 for char in in_str):
            raise ValueError('Each character in the input string must appear exactly once.')
        if not all(out_str.count(char) == 1 for char in out_str):
            raise ValueError('Each character in the output string must appear exactly once.')
        if not all(out_str.count(char) == 1 for char in in_str):
            raise ValueError('Each character in the input string must appear exactly once in the output string.')
        if in_shape != tuple(out_shape[out_str.index(char)] for char in in_str):
            raise ValueError(f'Input tensor shape {in_shape} is not compatible with the output shape {out_shape} specified in the action.')

        # Output axes absent from the input are the repeated ("ones") axes;
        # expansion becomes einsum('<in>,<ones>-><out>', x, ones(ones_shape)).
        ones_str = ''.join(char for char in out_str if char not in in_str)
        ones_shape = tuple(out_shape[out_str.index(char)] for char in ones_str)
        einsum_str = in_str + ',' + ones_str + '->' + out_str
        op = TensorExpand(x, out_shape, ones_shape, einsum_str)
    else:
        if action is not None:
            warnings.warn('"action" will have no effect when expanding a scalar.')
        x = x.flatten()
        op = ScalarExpand(x, out_shape)

    return op.finalize_and_return_outputs()
class TestExpand(csdl_tests.CSDLTest):
    '''Tests for `expand`: values, derivatives and the docstring example.'''

    def test_functionality(self,):
        self.prep()

        import csdl_alpha as csdl
        import numpy as np

        recorder = csdl.build_new_recorder(inline = True)
        recorder.start()

        x_val = 3.0
        y_val = np.array([1.0, 2.0, 3.0])
        y_tensor_val = np.arange(60).reshape(3,4,5)

        x = csdl.Variable(name = 'x', value = x_val)
        y = csdl.Variable(name = 'y', value = y_val)
        y_tensor = csdl.Variable(name = 'yt', value = y_tensor_val)

        # Each TestingPair gets a unique tag so a failure identifies its case.
        compare_values = []

        # expand a scalar constant
        s1 = csdl.expand(x_val, out_shape=(2,3,4))
        t1 = x_val * np.ones((2,3,4))
        compare_values += [csdl_tests.TestingPair(s1, t1, tag = 's1')]

        # expand a scalar variable
        s2 = csdl.expand(x, out_shape=(2,3,4))
        compare_values += [csdl_tests.TestingPair(s2, t1, tag = 's2')]

        # expand a tensor variable (new axes interleaved with the input axes)
        s3 = csdl.expand(y_tensor, out_shape=(4,3,4,2,5), action='ijk->aijbk')
        t3 = np.einsum('ijk,aijbk->aijbk', y_tensor_val, np.ones((4,3,4,2,5)))
        compare_values += [csdl_tests.TestingPair(s3, t3, tag = 's3')]

        # expand a tensor variable (input axes also reordered)
        s4 = csdl.expand(y_tensor, out_shape=(5,2,3,1,4), action='ijk->kaibj')
        t4 = np.einsum('ijk,kaibj->kaibj', y_tensor_val, np.ones((5,2,3,1,4)))
        compare_values += [csdl_tests.TestingPair(s4, t4, tag = 's4')]

        # expand a vector variable
        s5 = csdl.expand(y, out_shape=(3,4), action='j->jk')
        t5 = np.einsum('j,jk->jk', y_val, np.ones((3,4)))
        compare_values += [csdl_tests.TestingPair(s5, t5, tag = 's5')]

        # expand a vector constant
        s6 = csdl.expand(y_val, out_shape=(2,3,4), action='j->ijk')
        t6 = np.einsum('j,ijk->ijk', y_val, np.ones((2,3,4)))
        compare_values += [csdl_tests.TestingPair(s6, t6, tag = 's6')]

        self.run_tests(compare_values = compare_values, verify_derivatives=True)

    def test_example(self,):
        self.docstest(expand)


if __name__ == '__main__':
    test = TestExpand()
    test.test_functionality()
    test.test_example()