Source code for csdl_alpha.src.operations.add

from csdl_alpha.src.operations.operation_subclasses import ElementwiseOperation, ComposedOperation
from csdl_alpha.src.graph.operation import Operation, set_properties 
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import csdl_alpha.utils.testing_utils as csdl_tests
from csdl_alpha.utils.typing import VariableLike

@set_properties(linear=True)
class Add(ElementwiseOperation):
    '''
    Elementwise addition of two tensors of the same shape.
    '''

    def __init__(self,x:Variable,y:Variable):
        super().__init__(x,y)
        self.name = 'add'

    def compute_inline(self, x, y):
        return x + y
    
    def compute_jax(self, x, y):
        return self.compute_inline(x, y)

    def compute_jax(self, x, y):
        return self.compute_inline(x, y)

    def evaluate_vjp(self, cotangents, x, y, z):
        if cotangents.check(x):
            cotangents.accumulate(x, cotangents[z])
        if cotangents.check(y):
            cotangents.accumulate(y, cotangents[z])

# TODO: Do we need a broadcast add? There's a lot of code duplication b/w both classes
class BroadcastAdd(Operation):
    '''
    Addition after the first input is broadcasted to the shape of the second input.
    '''

    def __init__(self,x,y):
        super().__init__(x,y)
        self.name = 'broadcast_add'
        out_shapes = (y.shape,)
        self.set_dense_outputs(out_shapes)

    def compute_inline(self, x, y):
        return x + y
    
    def compute_jax(self, x, y):
        z  = x + y
        return z

    def evaluate_vjp(self, cotangents, x, y, z):
        if cotangents.check(x):
            import csdl_alpha as csdl
            cotangents.accumulate(x, csdl.sum(cotangents[z]))
        if cotangents.check(y):
            cotangents.accumulate(y, cotangents[z])

class SparseAdd(ComposedOperation):

    def __init__(self,x,y):
        super().__init__(x,y)
        self.name = 'sparse_add'

    def compute_inline(self, x, y):
        pass
    
class SparseBroadcastAdd(ComposedOperation):

    def __init__(self,x,y):
        super().__init__(x,y)
        self.name = 'sparse_broadcast_add'

    def compute_inline(self, x, y):
        pass

[docs]def add(x:VariableLike,y:VariableLike)->Variable: """Elementwise addition of two tensors x and y. Parameters ---------- x : Variable y : Variable Returns ------- out: Variable Examples -------- >>> recorder = csdl.Recorder(inline = True) >>> recorder.start() >>> x = csdl.Variable(value = np.array([1.0, 2.0, 3.0])) >>> y = csdl.Variable(value = np.array([4.0, 5.0, 6.0])) >>> csdl.add(x, y).value array([5., 7., 9.]) >>> (x + y).value # equivalent to the above array([5., 7., 9.]) >>> (x + 2.0).value # broadcasting is also supported array([3., 4., 5.]) """ x = validate_and_variablize(x, raise_on_sparse=False) y = validate_and_variablize(y, raise_on_sparse=False) if x.shape == y.shape: op = Add(x,y) elif x.size == 1: op = BroadcastAdd(x.flatten(),y) elif y.size == 1: op = BroadcastAdd(y.flatten(),x) else: raise ValueError(f'Shapes not compatible for add operation. x shape: {x.shape}, y shape: {y.shape}') # TODO: add later # if (not x.is_sparse) and (not y.is_sparse): # if x.shape == y.shape: # op = Add(x,y) # elif x.size == 1: # op = BroadcastAdd(x.flatten(),y) # elif y.size == 1: # op = BroadcastAdd(y.flatten(),x) # else: # raise ValueError(f'Shapes not compatible for sparse add operation. x shape: {x.shape}, y shape: {y.shape}') # elif (x.is_sparse) and (y.is_sparse): # if x.shape == y.shape: # op = SparseAdd(x,y) # else: # raise ValueError(f'Shapes not compatible for sparse add operation. x shape: {x.shape}, y shape: {y.shape}') # elif (x.is_sparse) and (not y.is_sparse): # if x.shape == y.shape: # op = SparseDenseAdd(x,y) # elif x.size == 1: # op = BroadcastAdd(x.to_dense().flatten(),y) # elif y.size == 1: # op = SparseBroadcastAdd(y.flatten(),x) # else: # raise ValueError(f'Shapes not compatible for sparse add operation. x shape: {x.shape}, y shape: {y.shape}') # elif (not x.is_sparse) and (y.is_sparse): # if x.shape == y.shape: # op = SparseDenseAdd(y,x) # elif x.size == 1: # op = BroadcastAdd(y.to_dense().flatten(),x) # elif y.size == 1: # op = SparseBroadcastAdd(x.flatten(),y) # else: # raise ValueError(f'Shapes not compatible for sparse add operation. x shape: {x.shape}, y shape: {y.shape}') return op.finalize_and_return_outputs()
class TestAdd(csdl_tests.CSDLTest): def test_functionality(self,): self.prep() import csdl_alpha as csdl import numpy as np x_val = 3.0 y_val = 2.0 x = csdl.Variable(name = 'x', value = x_val) y = csdl.Variable(name = 'y', value = y_val) compare_values = [] # add scalar variables s1 = csdl.add(x,y) t1 = np.array([x_val + y_val]) compare_values += [csdl_tests.TestingPair(s1, t1, tag = 's1')] # add scalar constants s2 = csdl.add(3.0, 2.0) compare_values += [csdl_tests.TestingPair(s2, t1, tag = 's2')] # add scalar constant and scalar variable s3 = csdl.add(3.0, y) compare_values += [csdl_tests.TestingPair(s3, t1, tag = 's3')] # add tensor constants s4 = csdl.add(3.0*np.ones((3,2)), 2.0*np.ones((3,2))) t2 = 5.0 * np.ones((3,2)) compare_values += [csdl_tests.TestingPair(s4, t2, tag = 's4')] # add scalar constant and tensor constant s5 = csdl.add(3.0, 2.0*np.ones((3,2))) compare_values += [csdl_tests.TestingPair(s5, t2, tag = 's5')] # add scalar variable and tensor constant s6 = csdl.add(x, 2.0*np.ones((3,2))) compare_values += [csdl_tests.TestingPair(s6, t2, tag = 's6')] z_val = 2.0*np.ones((3,2)) z = csdl.Variable(name = 'z', value = z_val) # add scalar variable and tensor variable s7 = csdl.add(x, z) compare_values += [csdl_tests.TestingPair(s7, t2, tag = 's7')] # add scalar constant and tensor variable s8 = csdl.add(3.0, z) compare_values += [csdl_tests.TestingPair(s8, t2, tag = 's8')] # add tensor variables s9 = csdl.add(x, z) compare_values += [csdl_tests.TestingPair(s9, t2, tag = 's9')] # add tensor variables s9 = csdl.add(csdl.Variable(value = np.ones((100,))), csdl.Variable(value = np.ones((100,)))) temp = s9[0] temp.add_name('temp') compare_values += [csdl_tests.TestingPair(temp, 2*np.ones((100,))[0].flatten(), tag = 's10')] self.run_tests(compare_values = compare_values, verify_derivatives=True) def test_errors(self): self.prep() import csdl_alpha as csdl import numpy as np x_val = np.ones((2,2)) y_val = np.ones((2,2)) x = csdl.SparseMatrix(name = 'x', value = x_val) y = csdl.SparseMatrix(name = 'y', value = y_val) x+y def test_docstring(self): self.docstest(add) if __name__ == '__main__': test = TestAdd() # test.overwrite_backend = 'jax' test.test_functionality() test.test_errors()