from csdl_alpha.src.graph.operation import Operation, set_properties
from csdl_alpha.src.operations.operation_subclasses import ComposedOperation
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize
import csdl_alpha.utils.testing_utils as csdl_tests
from csdl_alpha.utils.typing import VariableLike
import numpy as np
@set_properties(linear=True)
class ReorderAxes(Operation):
    '''
    Reorders the axes of the input tensor as per the specified `action`.

    The permutation is described by two subscript strings of equal length:
    output axis ``k`` is the input axis whose character in ``in_str`` equals
    ``out_str[k]``. The operation is registered as linear through
    ``set_properties(linear=True)``.
    '''

    def __init__(self, x, out_shape, in_str, out_str):
        super().__init__(x)
        self.name = 'reorder_axes'
        self.set_dense_outputs((out_shape,))
        # For every output subscript, find the input axis carrying the same
        # label; the resulting tuple is the ``axes`` argument of transpose.
        perm = tuple(in_str.index(label) for label in out_str)
        self.out_axes = perm
        self.in_str = in_str
        self.out_str = out_str

    def compute_inline(self, x):
        perm = self.out_axes
        return np.transpose(x, perm)

    def compute_jax(self, x):
        # Local import: jax is only touched when this method actually runs.
        import jax.numpy as jnp
        perm = self.out_axes
        return jnp.transpose(x, perm)

    def evaluate_vjp(self, cotangents, x, z):
        # Nothing to accumulate unless a cotangent is requested for the input.
        if not cotangents.check(x):
            return
        import csdl_alpha as csdl
        # Reordering the output cotangent back from the output subscripts to
        # the input subscripts (the inverse permutation) gives the input
        # cotangent.
        inverse_action = f'{self.out_str}->{self.in_str}'
        cotangents.accumulate(x, csdl.reorder_axes(cotangents[z], action=inverse_action))
def reorder_axes(x, action):
    '''
    Reorders the axes of the input tensor as per the specified `action`.

    `action` has the form `'<input subscripts>-><output subscripts>'`. Each
    character on the left labels one axis of `x`, and the right-hand side
    lists the same characters in the order the axes should appear in the
    output. For example, `action='ijk->kji'` reverses the axes of a 3D
    tensor.

    Parameters
    ----------
    x : Variable or np.ndarray
        Input tensor that needs to have its axes reordered. Scalars
        (single-element tensors with at most one axis) are rejected.
    action : str
        Specifies how the axes of the input tensor need to be reordered,
        e.g., `'ij->ji'` transposes the input matrix.

    Returns
    -------
    Variable
        Axes-reordered output tensor as per the specified `action`.

    Raises
    ------
    ValueError
        If `x` is a scalar, if `action` is missing, not a string or
        malformed, or if its subscripts are not a permutation with exactly
        one label per axis of `x`.

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> x_val = np.array([[1., 2., 3.], [4., 5., 6.]])
    >>> x = csdl.Variable(value = x_val)
    >>> y1 = csdl.reorder_axes(x, action='ij->ji')
    >>> y1.value
    array([[1., 4.],
           [2., 5.],
           [3., 6.]])

    Reorder the axes of a 3D tensor:

    >>> x_val = np.arange(24).reshape(2,3,4)
    >>> x_val
    array([[[ 0,  1,  2,  3],
            [ 4,  5,  6,  7],
            [ 8,  9, 10, 11]],
    <BLANKLINE>
           [[12, 13, 14, 15],
            [16, 17, 18, 19],
            [20, 21, 22, 23]]])
    >>> y2 = csdl.reorder_axes(x_val, action='ijk->kij')
    >>> y2.value
    array([[[ 0.,  4.,  8.],
            [12., 16., 20.]],
    <BLANKLINE>
           [[ 1.,  5.,  9.],
            [13., 17., 21.]],
    <BLANKLINE>
           [[ 2.,  6., 10.],
            [14., 18., 22.]],
    <BLANKLINE>
           [[ 3.,  7., 11.],
            [15., 19., 23.]]])
    '''
    x = variablize(x)
    in_shape = x.shape

    # Only a true scalar (one element, at most one axis) has nothing to
    # reorder. Multi-axis single-element tensors such as a (1, 1) matrix
    # are valid inputs; rejecting them would break generic code that
    # transposes matrices which happen to be 1x1.
    if x.size == 1 and len(in_shape) <= 1:
        raise ValueError('Cannot reorder axes of a scalar.')
    if action is None:
        raise ValueError('Cannot reorder axes of a tensor without "action" specified.')
    if not isinstance(action, str):
        raise ValueError('"action" must be a string.')
    if '->' not in action:
        raise ValueError('Invalid action string. Use "->" to separate the input and output subscripts.')
    # Without this check, an action such as 'ij->ji->ij' would fail below
    # with an opaque "too many values to unpack" error.
    if action.count('->') != 1:
        raise ValueError('Invalid action string. "->" must appear exactly once.')
    in_str, out_str = action.split('->')

    if len(in_str) != len(in_shape):
        raise ValueError('Input tensor shape does not match the input string in the action.')
    if len(out_str) != len(in_str):
        raise ValueError('Number of axes in the input and output must be the same.')
    if len(set(in_str)) != len(in_str):
        raise ValueError('Each character in the input string must appear exactly once.')
    if len(set(out_str)) != len(out_str):
        raise ValueError('Each character in the output string must appear exactly once.')
    # With equal lengths and no repeated labels on either side, the two
    # strings are permutations of each other iff their label sets agree.
    if set(out_str) != set(in_str):
        raise ValueError('Each character in the input string must appear exactly once in the output string.')

    # Output axis k takes the length of the input axis labelled out_str[k].
    out_shape = tuple(in_shape[in_str.index(label)] for label in out_str)
    op = ReorderAxes(x, out_shape, in_str, out_str)
    return op.finalize_and_return_outputs()
class TestReorderAxes(csdl_tests.CSDLTest):
    '''Checks csdl.reorder_axes values/derivatives and its docstring examples.'''

    def test_functionality(self):
        self.prep()
        import csdl_alpha as csdl

        rec = csdl.build_new_recorder(inline=True)
        rec.start()

        mat_val = np.array([[1., 2., 3.], [4., 5., 6.]])
        tensor_val = np.arange(24).reshape(2, 3, 4)
        mat = csdl.Variable(name='x', value=mat_val)

        pairs = []

        # Case 's1': plain transpose of a matrix Variable.
        transposed = csdl.reorder_axes(mat, action='ij->ji')
        pairs.append(csdl_tests.TestingPair(transposed, mat_val.T, tag='s1'))

        # Case 's2': cyclic axis permutation of a constant 3D numpy array.
        permuted = csdl.reorder_axes(tensor_val, action='ijk->kij')
        expected_perm = np.transpose(tensor_val, (2, 0, 1))
        pairs.append(csdl_tests.TestingPair(permuted, expected_perm, tag='s2'))

        self.run_tests(compare_values=pairs, verify_derivatives=True)

    def test_example(self):
        self.docstest(reorder_axes)
if __name__ == '__main__':
    # Run both test cases directly, in order, without a test runner.
    suite = TestReorderAxes()
    for run_case in (suite.test_functionality, suite.test_example):
        run_case()