# Source code for csdl_alpha.src.operations.cross

from csdl_alpha.src.graph.operation import Operation, set_properties
from csdl_alpha.src.operations.operation_subclasses import ComposedOperation
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import csdl_alpha.utils.testing_utils as csdl_tests
from csdl_alpha.src.operations.derivatives.utils import get_uncontract_action

import numpy as np
import pytest

class Cross(Operation):
    '''
    Operation computing the cross product of two arrays of 3D vectors.

    The vectors are stored along `axis` of both inputs; the single output
    is registered with the same shape as the first input.
    '''
    def __init__(self, x, y, axis):
        '''
        Parameters
        ----------
        x : Variable
            First operand.
        y : Variable
            Second operand.
        axis : int
            Axis of `x` and `y` along which the 3D vectors are stored.
        '''
        super().__init__(x, y)
        self.name = 'cross'
        self.axis = axis
        # One output, shaped like x.
        self.set_dense_outputs((x.shape,))

    def compute_inline(self, x, y):
        '''Evaluate the cross product with NumPy along the stored axis.'''
        vector_axis = self.axis
        return np.cross(x, y, axis=vector_axis)

    def compute_jax(self, x, y):
        '''Evaluate the cross product with jax.numpy along the stored axis.'''
        # jax is imported only when this method is actually called.
        from jax import numpy as jnp
        vector_axis = self.axis
        return jnp.cross(x, y, axis=vector_axis)

    def evaluate_vjp(self, cotangents, x, y, z):
        '''
        Accumulate the vector-Jacobian products of z = x × y.

        With z_bar the cotangent of z:
            x_bar += -(z_bar × y)
            y_bar += -(x × z_bar)
        Each contribution is only built when `cotangents.check` reports
        that the corresponding input needs it.
        '''
        import csdl_alpha as csdl
        vector_axis = self.axis

        if cotangents.check(x):
            z_bar = cotangents[z]
            x_contribution = -csdl.cross(z_bar, y, axis=vector_axis)
            cotangents.accumulate(x, x_contribution)

        if cotangents.check(y):
            z_bar = cotangents[z]
            y_contribution = -csdl.cross(x, z_bar, axis=vector_axis)
            cotangents.accumulate(y, y_contribution)

# class Cross2D(Operation):
#     '''
#     Computes the cross product of two arrays of 2D vectors.
#     '''
#     def __init__(self, x, y, axis):
#         super().__init__(x, y)
#         self.name = '2Dcross'
#         self.axis = axis
#         out_shape = x.shape[:axis] + x.shape[axis+1:]
#         # Uncomment the line below to debug the blanket reshaping in set_inline_values in operation
#         # out_shape = x.shape[:axis] + (1,) + x.shape[axis+1:]
#         # Or uncomment the 2 lines below to debug the blanket reshaping in set_inline_values in operation
#         if x.size == 2:
#             out_shape = (1,)
#         out_shapes = (out_shape,)
#         self.set_dense_outputs(out_shapes)

#     def compute_inline(self, x, y):
#         return np.cross(x, y, axis=self.axis)

#     def evaluate_vjp(self, cotangents, x, y, z):
#         import csdl_alpha as csdl

#         expanded_cotangents = csdl.expand(cotangents[z], out_shape=x.shape, action = get_uncontract_action(x.shape, (self.axis, )))

#         if x.size == 2:
#             if cotangents.check(x):
#                 initial = csdl.Variable(value = np.zeros(x.shape))
#                 initial = initial.set(csdl.slice[], cotangents[z])
#                 initial = initial.set(csdl.slice[], cotangents[z])

#                 cotangents.accumulate(x, )
#             if cotangents.check(y):
#                 cotangents.accumulate(y, -csdl.cross(x, expanded_cotangents, axis=self.axis))

def cross(x, y, axis=None):
    '''
    Computes the cross product of two arrays of 3D vectors.

    Inputs storing 2D vectors along `axis` are recognized, but the 2D cross
    product is not yet implemented and raises a NotImplementedError.

    Parameters
    ----------
    x : Variable or np.ndarray
        First input tensor of shape (l,...,3,...,n).
    y : Variable or np.ndarray
        Second input tensor of the same shape as x.
    axis : int, default=None
        Axis along which the vectors are stored in the input tensors.
        Need not be specified if the input tensors are 1D vectors.
        Needs to be specified if the input tensors are 2D or higher
        dimensional tensors. Axis must be a non-negative integer less
        than the number of dimensions in the input tensors.

    Returns
    -------
    Variable
        Tensor containing the cross product (x × y) of vectors in the
        input tensors.

    Raises
    ------
    ValueError
        If either input has size 1, if the input shapes differ, if `axis`
        is not a non-negative integer smaller than the number of
        dimensions, or if the size along `axis` is neither 2 nor 3.
    NotImplementedError
        If the size along `axis` is 2.

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> x = csdl.Variable(value = np.array([3.0, 4.0, 5.0]))
    >>> y = csdl.Variable(value = np.array([4.0, 5.0, 6.0]))
    >>> csdl.cross(x, y).value
    array([-1.,  2., -1.])
    >>> x = csdl.Variable(value = 3.0 * np.ones((2,3)))
    >>> y_val = np.arange(6).reshape(2,3)
    >>> csdl.cross(x, y_val, axis=1).value
    array([[ 3., -6.,  3.],
           [ 3., -6.,  3.]])
    '''
    x = validate_and_variablize(x)
    y = validate_and_variablize(y)

    if x.size == 1 or y.size == 1:
        raise ValueError("The input tensors must be atleast 1D vectors of size 3.")
    if x.shape != y.shape:
        raise ValueError("The input tensors must have the same shape.")

    # For 1D inputs the only possible axis is 0, so it may be omitted.
    if len(x.shape) == 1 and axis is None:
        axis = 0

    # Accept NumPy integer scalars (e.g. an index computed with NumPy) and
    # normalize them to a plain int before validating.
    if isinstance(axis, np.integer):
        axis = int(axis)
    if not isinstance(axis, int):
        raise ValueError("The axis argument must be an integer.")
    if axis < 0:
        raise ValueError("The axis argument must be a non-negative integer.")
    if axis >= len(x.shape):
        raise ValueError("The axis argument must be less than the number of dimensions in the input tensors.")

    vector_size = x.shape[axis]
    if vector_size == 2:
        raise NotImplementedError("Cross product of 2D vectors is not yet implemented.")
    if vector_size != 3:
        raise ValueError("The input tensors must have a size of 3 along the specified axis.")

    op = Cross(x, y, axis)
    return op.finalize_and_return_outputs()
class TestCross(csdl_tests.CSDLTest):
    '''Tests for the `cross` operation.'''

    def test_functionality(self,):
        '''
        Compare `csdl.cross` against `np.cross` for 3D vectors stored along
        axis 1 of a 2D and a 3D tensor, and check that an axis whose size is
        neither 2 nor 3 is rejected with a ValueError.
        '''
        self.prep()

        import csdl_alpha as csdl

        compare_values = []
        recorder = csdl.build_new_recorder(inline = True)
        recorder.start()

        # Disabled: the 2D cross product currently raises NotImplementedError.
        # x_val = np.array([3.0, 4.0])
        # y_val = np.array([4.0, 5.0])
        # x = csdl.Variable(value = x_val)
        # y = csdl.Variable(value = y_val)
        # s1 = csdl.cross(x, y)
        # t1 = np.cross(x_val, y_val).flatten()
        # compare_values += [csdl_tests.TestingPair(s1, t1, tag = 's1')]

        x_val = 3.0*np.ones((2,3))
        y_val = np.arange(6).reshape(2,3)
        x = csdl.Variable(value = x_val)
        y = csdl.Variable(value = y_val)

        # Disabled: axis 0 has size 2 here, i.e. the 2D cross product.
        # s2 = csdl.cross(x, y_val, axis=0)
        # t2 = np.cross(x_val, y_val, axis=0)
        # compare_values += [csdl_tests.TestingPair(s2, t2, tag = 's2')]

        # The second operand is passed as a raw ndarray, so cross() has to
        # convert it itself.
        s3 = csdl.cross(x, y_val, axis=1)
        t3 = np.cross(x_val, y_val, axis=1)
        compare_values += [csdl_tests.TestingPair(s3, t3, tag = 's3')]

        x_val = 3.0*np.ones((4, 3, 2))
        y_val = np.arange(24).reshape(4, 3, 2)
        x = csdl.Variable(value = x_val)
        y = csdl.Variable(value = y_val)

        # Distinct tag so a failure can be attributed to the 3D-tensor case.
        s4 = csdl.cross(x, y_val, axis=1)
        t4 = np.cross(x_val, y_val, axis=1)
        compare_values += [csdl_tests.TestingPair(s4, t4, tag = 's4')]

        # Axis 0 has size 4. Only csdl.cross is called inside the context
        # manager: np.cross raises ValueError for this axis as well, which
        # would make the check pass even if csdl.cross did not validate.
        with pytest.raises(ValueError):
            csdl.cross(x, y_val, axis=0)

        self.run_tests(compare_values = compare_values, verify_derivatives=True)

    def test_example(self,):
        '''Run the examples embedded in the docstring of `cross`.'''
        self.docstest(cross)


if __name__ == '__main__':
    test = TestCross()
    test.test_functionality()
    test.test_example()