# Source code for csdl_alpha.src.operations.tensor.reshape

from csdl_alpha.src.operations.operation_subclasses import ElementwiseOperation, ComposedOperation
from csdl_alpha.src.graph.operation import Operation, set_properties 
from csdl_alpha.src.graph.variable import Variable 
from csdl_alpha.utils.inputs import variablize

import csdl_alpha.utils.testing_utils as csdl_tests
import csdl_alpha.utils.error_utils as error_utils
import numpy as np

@set_properties(linear=True, diagonal_jacobian = True)
class Reshape(Operation):
    '''
    Reshape a tensor to a new shape containing the same number of elements.

    The forward computation calls ``.reshape(new_shape)`` on the input value.
    The vector-Jacobian product reshapes the output cotangent back to the
    input's shape, which is why the operation is declared linear.
    '''

    def __init__(self, x, shape):
        super().__init__(x)
        self.name = 'reshape'
        # Target shape. The public ``reshape`` function validates it and
        # resolves any -1 entry before constructing this operation.
        self.new_shape = shape
        self.set_dense_outputs((self.new_shape,))

    def compute_inline(self, x):
        '''Return the input value reshaped to ``self.new_shape``.'''
        return x.reshape(self.new_shape)

    def evaluate_vjp(self, cotangents, x, out):
        '''Accumulate the output cotangent, reshaped back to ``x.shape``.'''
        if cotangents.check(x):
            cotangents.accumulate(x, cotangents[out].reshape(x.shape))

    def compute_jax(self, x):
        '''JAX forward computation; ``x.reshape`` is all that is needed.'''
        return x.reshape(self.new_shape)

def reshape(x:Variable, shape: tuple[int]) -> Variable:
    """Reshape a tensor x to a new shape.

    At most one entry of ``shape`` may be -1. That dimension is inferred
    from the size of ``x`` and the remaining entries.

    Parameters
    ----------
    x : Variable
        Variable to reshape.
    shape : tuple[int]
        New shape. It must be a tuple of integers. After a single -1 entry
        is resolved, its product must equal the size of ``x``.

    Returns
    -------
    out: Variable
        The reshaped variable.

    Raises
    ------
    TypeError
        If ``shape`` is not a valid shape (for example, not a tuple of ints).
    ValueError
        If more than one entry is -1, if the -1 entry cannot be inferred,
        if the resolved shape has a negative dimension, or if the resolved
        shape's size does not match the size of ``x``.

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> x = csdl.Variable(value = np.array([1.0, 2.0, 3.0, 4.0]))
    >>> csdl.reshape(x, (2,2)).value
    array([[1., 2.],
           [3., 4.]])
    >>> x.reshape((2,2)).value # same thing as above
    array([[1., 2.],
           [3., 4.]])
    >>> x.flatten().value # reshapes to 1 dimension
    array([1., 2., 3., 4.])
    """
    # Given shape must be a tuple of integers
    try:
        error_utils.check_if_valid_shape(shape)
    except Exception as e:
        raise TypeError(f'Error with shape argument in reshape: {e}') from e

    # Locate the (at most one) -1 entry that should be inferred.
    inferred_index = None
    for i, dim in enumerate(shape):
        if dim != -1:
            continue
        if inferred_index is not None:
            raise ValueError(f'Only one element of new shape can be -1')
        inferred_index = i

    new_shape = list(shape)
    if inferred_index is not None:
        size = int(np.prod(x.shape))
        # Integer product of the explicitly given dimensions (empty -> 1).
        size_others = int(np.prod(
            [dim for j, dim in enumerate(shape) if j != inferred_index]
        ))
        if size_others == 0:
            # The inferred dimension would be ambiguous (0 * k == 0 for any k).
            raise ValueError(
                f'Cannot infer the -1 dimension of new shape {tuple(shape)}: '
                f'the other dimensions have a product of 0'
            )
        # Integer division avoids float rounding; a non-divisible size
        # is caught by the size check below.
        new_shape[inferred_index] = size // size_others
    shape = tuple(new_shape)

    if any(dim < 0 for dim in shape):
        raise ValueError(f'Invalid negative dimension in new shape {shape}')

    # shape must be compatible with shape of variable x
    if x.size != np.prod(shape):
        raise ValueError(f'Variable size and new shape do not match: ({x.size} != {np.prod(shape)})')

    op = Reshape(x, shape = shape)
    return op.finalize_and_return_outputs()
class TestReshape(csdl_tests.CSDLTest):
    """Tests for ``csdl.reshape`` and the ``Variable.reshape``/``flatten`` methods."""

    def test_functionality(self,):
        """Check reshaped values, including -1 inference and flatten."""
        self.prep()
        import csdl_alpha as csdl
        import numpy as np

        x_val = 3.0
        x = csdl.Variable(name = 'x', value = x_val)
        x_val_large = np.ones((10,10,10))
        x_large = csdl.Variable(name = 'x_large', value = x_val_large)

        compare_values = []
        # Function form: csdl.reshape
        y = csdl.reshape(x, (1,1))
        compare_values += [csdl_tests.TestingPair(y, np.array([[x_val]]))]

        y = csdl.reshape(x_large, (100,10))
        compare_values += [csdl_tests.TestingPair(y,x_val_large.reshape((100,10)))]

        y = csdl.reshape(x_large, (1000,1))
        compare_values += [csdl_tests.TestingPair(y,x_val_large.reshape((1000,1)))]

        y = csdl.reshape(y, y.shape)
        compare_values += [csdl_tests.TestingPair(y,x_val_large.reshape((1000,1)))]

        # Method form: Variable.reshape
        y = x.reshape((1,1))
        compare_values += [csdl_tests.TestingPair(y, np.array([[x_val]]))]

        y = x_large.reshape((100,10))
        compare_values += [csdl_tests.TestingPair(y,x_val_large.reshape((100,10)))]

        y = x_large.reshape((1000,1))
        compare_values += [csdl_tests.TestingPair(y,x_val_large.reshape((1000,1)))]

        y = y.reshape(y.shape)
        compare_values += [csdl_tests.TestingPair(y,x_val_large.reshape((1000,1)))]

        # -1 inference in different positions
        y = y.reshape((1000,-1))
        compare_values += [csdl_tests.TestingPair(y,x_val_large.reshape((1000,1)))]

        y = y.reshape((-1,10, 10))
        compare_values += [csdl_tests.TestingPair(y,x_val_large.reshape((10,10,10)))]

        y = y.reshape((2,-1, 50))
        compare_values += [csdl_tests.TestingPair(y,x_val_large.reshape((2,10,50)))]

        y = y.flatten()
        compare_values += [csdl_tests.TestingPair(y,x_val_large.flatten())]

        self.run_tests(compare_values = compare_values,)

    def test_derivatives(self):
        """Reshape a small tensor in several ways and compare values."""
        self.prep()
        import csdl_alpha as csdl
        import numpy as np

        x_val_large = np.ones((3,1,2))
        x_large = csdl.Variable(name = 'x_large', value = x_val_large)

        compare_values = []
        y = x_large.flatten()
        compare_values += [csdl_tests.TestingPair(y,x_val_large.flatten())]

        y = x_large.reshape((6,1))
        compare_values += [csdl_tests.TestingPair(y,x_val_large.reshape((6,1)))]

        y = x_large.reshape((2,1,3,1))
        compare_values += [csdl_tests.TestingPair(y,x_val_large.reshape((2,1,3,1)))]

        # NOTE(review): run_tests is called here exactly as in
        # test_functionality. Confirm whether CSDLTest.run_tests needs an
        # extra option to actually verify derivatives in this test.
        self.run_tests(compare_values = compare_values,)

    def test_errors(self):
        """Invalid shapes raise TypeError; incompatible shapes raise ValueError."""
        self.prep()
        import csdl_alpha as csdl
        import numpy as np
        import pytest

        x_val_large = np.ones((10,10,10))
        x_large = csdl.Variable(name = 'x_large', value = x_val_large)

        # Shape must be a tuple of integers
        with pytest.raises(TypeError):
            y = csdl.reshape(x_large, 1000)
        with pytest.raises(TypeError):
            y = csdl.reshape(x_large, [1000])
        with pytest.raises(TypeError):
            y = csdl.reshape(x_large, (1000.0,))
        with pytest.raises(TypeError):
            y = csdl.reshape(x_large, (10, 100.0,))

        # Size mismatch and multiple -1 entries
        with pytest.raises(ValueError):
            y = csdl.reshape(x_large, (10, 100, 10))
        with pytest.raises(ValueError):
            y = csdl.reshape(x_large, (10, -1, -1))
        with pytest.raises(ValueError):
            y = csdl.reshape(x_large, (-1, 10, -1))

    def test_docstring(self):
        """Run the doctest examples in ``reshape``'s docstring."""
        self.docstest(reshape)


if __name__ == '__main__':
    test = TestReshape()
    test.test_functionality()
    test.test_derivatives()
    test.test_errors()