from csdl_alpha.src.operations.operation_subclasses import ElementwiseOperation, ComposedOperation
from csdl_alpha.src.graph.operation import Operation, set_properties
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import csdl_alpha.utils.testing_utils as csdl_tests
from csdl_alpha.src.operations.set_get.slice import Slice
from csdl_alpha.src.operations.set_get.loop_slice import VarSlice
import pytest
from csdl_alpha.utils.typing import VariableLike
import numpy as np
@set_properties(linear=True,)
class SetVarIndex(Operation):
    '''
    Elementwise setting of a slice s of a tensor x with another tensor y.

    The single output has the same shape as ``x``: it is a copy of ``x`` whose
    entries selected by ``slice`` are replaced by ``y``. The CSDL variables the
    slice depends on (``slice.vars``) are registered as additional inputs of the
    operation, after ``x`` and ``y``.
    '''
    def __init__(
            self,
            x:Variable,
            y:Variable,
            slice:VarSlice):
        '''
        Slice can be a tuple of slices or a single slice or list of index sets.

        Parameters
        ----------
        x : Variable
            Tensor that is copied and updated; also defines the output shape.
        y : Variable
            Values written into the slice. ``set_index`` only builds this
            operation when ``y`` has a single element or matches the slice shape.
        slice : VarSlice
            Slice object; its ``vars`` are appended to the operation inputs.
        '''
        super().__init__(x, y, *slice.vars)
        self.name = 'set_index'
        out_shapes = (x.shape,)
        self.set_dense_outputs(out_shapes)
        self.slice = slice

    def compute_inline(self, x, y, *slice_args):
        '''
        Return a copy of ``x`` with ``x[slice] = y`` using NumPy assignment.

        ``slice_args`` are the current values of the slice variables and are
        passed to ``self.slice.evaluate`` to build the concrete index.
        '''
        x_updated = x.copy()
        x_updated[self.slice.evaluate(*slice_args)] = y
        return x_updated
        # NOTE: repeated indices currently follow NumPy "last write wins"
        # semantics. Accumulating over duplicates (zero the slice, then
        # np.add.at) has been considered for the future.

    def compute_jax(self, x, y, *slice_args):
        '''
        JAX version of ``compute_inline`` using ``x.at[index].set(...)``.

        Slices flagged with ``var_slice`` are not converted through
        ``jnpevaluate``; they fall back to the inline implementation instead.
        '''
        if self.slice.var_slice is True:
            from csdl_alpha.backends.jax.utils import fallback_to_inline_jax
            return fallback_to_inline_jax(self, x, y, *slice_args)[0]
        index = self.slice.jnpevaluate(*slice_args)
        if y.size == 1:
            # Broadcast the single value into every selected entry.
            # reshape(()) yields a true scalar for ANY size-1 shape ((1,),
            # (1, 1), ...). y[0] would keep trailing unit dimensions, which JAX
            # cannot broadcast onto a scalar target.
            return x.at[index].set(y.reshape(()))
        return x.at[index].set(y)

    def evaluate_vjp(self, cotangents, x, y, *slice_args_and_outputs):
        '''
        Accumulate vector-Jacobian products into ``cotangents``.

        ``slice_args_and_outputs`` holds the slice variables followed by the
        output variable (last element).
        '''
        import csdl_alpha as csdl
        x_updated = slice_args_and_outputs[-1]
        slice_args = slice_args_and_outputs[:-1]
        if cotangents.check(x):
            # Entries of x inside the slice are overwritten, so their cotangent is zero.
            cotangents.accumulate(x, cotangents[x_updated].set(self.slice, 0.0))
        if cotangents.check(y):
            # NOTE(review): with repeated integer indices NumPy keeps only the
            # last write, but this gathers the cotangent at every selected
            # position; confirm duplicates are rejected upstream.
            if y.size == 1:
                # A broadcast value receives the sum over every slot it filled.
                cotangents.accumulate(y, csdl.sum(cotangents[x_updated][self.slice]).reshape(y.shape))
            else:
                cotangents.accumulate(y, cotangents[x_updated][self.slice])
        # zero out the cotangents for the slice arguments... Ideally, we avoid this
        for slice_arg in slice_args:
            if cotangents.check(slice_arg):
                cotangents.accumulate(slice_arg, csdl.Variable(value = 0.0))
class BroadcastSetIndex(SetVarIndex):
    '''
    Writes a single-element value y into every entry of a slice s of a tensor x.

    Computation and derivatives are inherited unchanged from SetVarIndex; this
    subclass only gives the operation its own name.
    '''
    def __init__(self, x, y, slice):
        # Delegate construction entirely to SetVarIndex, then rename.
        super(BroadcastSetIndex, self).__init__(x, y, slice)
        self.name = 'broadcast_set'
# class SparseSetIndex(ComposedOperation):
# def __init__(self,x,y):
# super().__init__(x,y)
# self.name = 'sparse_set'
# def compute_inline(self, x, y):
# pass
# class SparseBroadcastSetIndex(ComposedOperation):
# def __init__(self,x,y):
# super().__init__(x,y)
# self.name = 'sparse_broadcast_set'
# def compute_inline(self, x, y):
# pass
def set_index(x:Variable, s:Slice, y:VariableLike) -> Variable:
    '''
    Return a new variable equal to ``x`` with the entries selected by ``s``
    replaced by ``y``.

    Parameters
    ----------
    x : Variable
        Tensor to update (passed through ``validate_and_variablize``).
    s : Slice
        Slice selecting the entries to overwrite.
    y : VariableLike
        Replacement values. A single-element ``y`` is broadcast to every
        selected entry; otherwise its shape must equal the shape of ``x[s]``.

    Returns
    -------
    Variable
        Output of the created SetVarIndex / BroadcastSetIndex operation.

    Raises
    ------
    ValueError
        If ``y`` has more than one element and its shape differs from the
        shape of the selected slice.
    '''
    x = validate_and_variablize(x)
    y = validate_and_variablize(y)
    if y.size == 1:
        # TODO: use y.flatten() later once flatten() is implemented
        op = BroadcastSetIndex(x, y, s)
    else:
        # Determine the slice shape on a zero-stride broadcast view, so no
        # array the size of x is allocated just to read a shape.
        # TODO: index out of bounds error from csdl instead of numpy
        slice_shape = np.broadcast_to(0.0, x.shape)[s.evaluate_zeros()].shape
        if slice_shape != y.shape:
            raise ValueError(f'Slice shape does not match value shape. {slice_shape} != {y.shape}')
        op = SetVarIndex(x, y, s)
    return op.finalize_and_return_outputs()
class TestSet(csdl_tests.CSDLTest):
    '''
    Tests for setting slices of CSDL variables with ``Variable.set``
    (presumably dispatching to ``set_index`` in this module -- confirm).
    '''
    def test_functionality(self,):
        '''
        Check the values produced by ``set`` for constant and CSDL-variable
        slices, and the exceptions raised for invalid slices. Also check that
        results built from variable indices are recomputed after the index
        variables change and the graph is re-executed.
        '''
        self.prep()
        import csdl_alpha as csdl
        import numpy as np
        from csdl_alpha import slice
        x_val = 3.0
        y_val = 2.0
        x = csdl.Variable(name = 'x', value = x_val)
        y = csdl.Variable(name = 'y', value = y_val)
        # Index variables. Results built from slices that use them are
        # re-checked at the end after their values are changed.
        ind_0 = csdl.Variable(name = 'ind0', value = 0)
        ind_1 = csdl.Variable(name = 'ind1', value = 1)
        compare_values = []
        # set a scalar slice with a scalar variable
        x1 = x.set(slice[0:1], y)
        x2 = x.set(slice[0], y)
        x2_v = x.set(slice[ind_0], y)
        t1 = np.array([2.])
        compare_values += [csdl_tests.TestingPair(x1, t1)]
        compare_values += [csdl_tests.TestingPair(x2, t1)]
        compare_values += [csdl_tests.TestingPair(x2_v, t1)]
        # set a scalar slice with a scalar constant
        x3 = x.set(slice[0:1], 2.0)
        x3_v = x.set(slice[[ind_0,]], 2.0)
        compare_values += [csdl_tests.TestingPair(x3, t1)]
        compare_values += [csdl_tests.TestingPair(x3_v, t1)]
        z_val = 3.0*np.ones((3,2))
        z = csdl.Variable(name = 'z', value = z_val)
        # set a tensor slice with a tensor constant
        z1 = z.set(slice[0:-1:1], 2.0*np.ones((2,2)))
        z1_v = z.set(slice[[ind_0, ind_0+1]], 2.0*np.ones((2,2)))
        # expected: first two rows of z set to 2, last row left at 3
        t2 = np.array([[2.,2.],[2.,2.],[3.,3.]])
        compare_values += [csdl_tests.TestingPair(z1, t2)]
        compare_values += [csdl_tests.TestingPair(z1_v, t2)]
        # set a tensor slice with a scalar constant
        z2 = z.set(slice[0:-1:1], 2.0)
        z2_v = z.set(slice[[ind_0, ind_0+1]], 2.0)
        compare_values += [csdl_tests.TestingPair(z2, t2)]
        compare_values += [csdl_tests.TestingPair(z2_v, t2)]
        # set a tensor slice with a scalar variable
        z3 = z.set(slice[0:-1:1], y)
        z3_v = z.set(slice[[ind_0, ind_0+1]], y)
        compare_values += [csdl_tests.TestingPair(z3, t2)]
        compare_values += [csdl_tests.TestingPair(z3_v, t2)]
        t_val = 2.0*np.ones((2,2))
        t = csdl.Variable(name = 't', value = t_val)
        # set a tensor slice with a tensor variable
        # (z4_v and z4_var pass the NumPy constant t_val instead of t)
        z4 = z.set(slice[0:-1:1], t)
        z4_v = z.set(slice[[ind_0, ind_0+1]], t_val)
        z4_var = z.set(slice[ind_0:ind_0+2:1], t_val)
        compare_values += [csdl_tests.TestingPair(z4, t2)]
        compare_values += [csdl_tests.TestingPair(z4_v, t2)]
        compare_values += [csdl_tests.TestingPair(z4_var, t2)]
        t = csdl.Variable(name = 't', value = 2.0*np.ones((2,1)))
        # set a tensor slice with a tensor variable
        z5 = z.set((slice[0:-1, 1:2]), t)
        z5_var = z.set((slice[0:-1, ind_1:ind_1+1]), t)
        # expected: column 1 of the first two rows set to 2
        t3 = np.array([[3.,2.],[3.,2.],[3.,3.]])
        compare_values += [csdl_tests.TestingPair(z5, t3)]
        compare_values += [csdl_tests.TestingPair(z5_var, t3)]
        t = csdl.Variable(name = 't', value = 2.0*np.ones((2,)))
        # set a tensor slice at specific indices with a tensor variable
        # (integer-array indexing: positions (0, 1) and (1, 1))
        z6 = z.set(slice[([0,1], [1,1])], t)
        z6_v = z.set(slice[([ind_0,1], [ind_1,ind_1])], t)
        compare_values += [csdl_tests.TestingPair(z6, t3)]
        compare_values += [csdl_tests.TestingPair(z6_v, t3)]
        # set a tensor slice at specific indices with a scalar variable
        z7 = z.set(slice[([0,1], [1,1])], y)
        z7_v = z.set(slice[([ind_0,1], [ind_1,ind_1])], y)
        compare_values += [csdl_tests.TestingPair(z7, t3)]
        compare_values += [csdl_tests.TestingPair(z7_v, t3)]
        # set a tensor slice at specific indices with a scalar constant
        z8 = z.set(slice[([0,1], [1,1])], 2.0)
        z8_v = z.set(slice[([ind_0,1], [ind_1,ind_1])], 2.0)
        compare_values += [csdl_tests.TestingPair(z8, t3)]
        compare_values += [csdl_tests.TestingPair(z8_v, t3)]
        # slicing with CSDL variables
        shape_2 = (10,9,8,7,6)
        w_val = np.arange(np.prod(shape_2)).reshape(shape_2)
        w = csdl.Variable(name = 'w', value = w_val)
        # int_1 holds the float 2.0; the expected values below treat it as index 2
        # (and int_2 = int_1 + 3 as index 5).
        int_1 = csdl.Variable(value = 2.0)
        int_2 = int_1+3
        x9 = w.set(csdl.slice[int_1:int_2, [1, 1, 1],[1, 2, 3], 0:2], 4.0)
        x9_val = w_val.copy()
        x9_val[2:5, [1, 1, 1],[1, 2, 3], 0:2] = 4.0
        compare_values += [csdl_tests.TestingPair(x9, x9_val)]
        x10 = w.set(csdl.slice[int_1:int_2, [1, 1, 1],[1, 2, 3], int_2:int_2+2], 11.0)
        x10_val = w_val.copy()
        x10_val[2:5, [1, 1, 1],[1, 2, 3], 5:7] = 11.0
        compare_values += [csdl_tests.TestingPair(x10, x10_val)]
        x11 = w.set(csdl.slice[int_1:int_2, [int_1, 1, 1],[int_2, int_2, int_1], int_2:int_2+2], 15)
        x11_val = w_val.copy()
        x11_val[2:5, [2, 1, 1],[5, 5, 2], 5:7] = 15
        compare_values += [csdl_tests.TestingPair(x11, x11_val)]
        x12 = w.set(csdl.slice[0:1, [int_1, 1, 1],[int_2, int_2, int_1], int_2:int_2+2],7.0)
        x12_val = w_val.copy()
        x12_val[0:1, [2, 1, 1],[5, 5, 2], 5:7] = 7.0
        compare_values += [csdl_tests.TestingPair(x12, x12_val)]
        # fixed/var and var step errors:
        # a variable step, or a range mixing a constant and a variable bound,
        # is expected to raise TypeError.
        with pytest.raises(TypeError):
            z.set(slice[0:-1:ind_1], t)
        with pytest.raises(TypeError):
            z.set(slice[0:ind_1:1], 2.0)
        with pytest.raises(TypeError):
            z.set(slice[ind_0:-1:1], t)
        # z[[0, 0]] selects a (2, 2) block while t has shape (2,), so this is
        # expected to raise ValueError (shape check, or possibly a
        # duplicate-index check in the slice -- confirm).
        with pytest.raises(ValueError):
            z.set(slice[[0,0]], t)
        # A plain list instead of a csdl slice is expected to raise TypeError.
        with pytest.raises(TypeError):
            z.set([0], t)
        # turn_off_recorder=False presumably keeps the recorder active so the
        # graph can be re-executed below.
        self.run_tests(compare_values = compare_values,turn_off_recorder=False, verify_derivatives=False)
        # change indices values to make sure they are updated.
        # ind_1: 1 -> 0 and int_1: 2 -> 1 (so int_2 becomes 4).
        compare_values = []
        ind_1.value = ind_1.value - 1
        int_1.value = int_1.value - 1
        current_graph = csdl.get_current_recorder().active_graph
        current_graph.execute_inline()
        # with ind_1 = 0, column 0 (not column 1) of the first two rows is set
        t3 = np.array([[2.,3.],[2.,3.],[3.,3.]])
        compare_values += [csdl_tests.TestingPair(z6_v, t3)]
        compare_values += [csdl_tests.TestingPair(z7_v, t3)]
        compare_values += [csdl_tests.TestingPair(z8_v, t3)]
        comp_val = w_val.copy().astype(float)
        comp_val[1:4, [1, 1, 1],[1, 2, 3], 0:2] = 4.0
        compare_values += [csdl_tests.TestingPair(x9, comp_val)]
        comp_val = w_val.copy()
        comp_val[1:4, [1, 1, 1],[1, 2, 3], 4:6] = 11.0
        compare_values += [csdl_tests.TestingPair(x10, comp_val)]
        comp_val = w_val.copy()
        comp_val[1:4, [1, 1, 1],[4, 4, 1], 4:6] = 15
        compare_values += [csdl_tests.TestingPair(x11, comp_val)]
        comp_val = w_val.copy()
        comp_val[0:1, [1, 1, 1],[4, 4, 1], 4:6] = 7.0
        compare_values += [csdl_tests.TestingPair(x12, comp_val)]
        self.run_tests(compare_values = compare_values, verify_derivatives=False)
    def test_derivs(self):
        '''
        Check values with derivative verification enabled for sets mixing
        range slices, integer-list indices and single-element (broadcast)
        values.
        '''
        self.prep()
        import csdl_alpha as csdl
        import numpy as np
        from csdl_alpha import slice
        compare_values = []
        shape_1 = (2,2,4)
        x_val = np.arange(np.prod(shape_1)).reshape(shape_1)
        x = csdl.Variable(name = 'x', value = x_val)
        # x[0:2, [1, 0], 3] has shape (2, 2), matching `other`
        other = csdl.Variable(name = 'other', value = 2*np.ones((2,2)))
        x6 = x.set(csdl.slice[0:2, [1, 0], 3], other)
        x_val_temp = x_val.copy()
        x_val_temp[0:2, [1, 0], 3] = 2.0
        compare_values += [csdl_tests.TestingPair(x6, x_val_temp)]
        # single-element variable broadcast over the selected rows
        other = csdl.Variable(name = 'other2', value = 2.0)
        x6 = x.set(csdl.slice[[0, 1]], other)
        x_val_temp = x_val.copy()
        x_val_temp[[0, 1]] = 2.0
        compare_values += [csdl_tests.TestingPair(x6, x_val_temp)]
        # paired integer-list indices select a (2, 3) slice
        x6 = x.set(csdl.slice[0:2, [1, 0, 0], [1, 0, 1]], value = 2*np.ones((2,3)))
        x_val_temp = x_val.copy()
        x_val_temp[0:2, [1, 0, 0], [1, 0, 1]] = 2.0
        compare_values += [csdl_tests.TestingPair(x6, x_val_temp)]
        # [1] broadcasts against [1, 0, 2], again giving a (2, 3) slice
        x6 = x.set(csdl.slice[0:2, [1], [1, 0, 2]], value = 2*np.ones((2,3)))
        x_val_temp = x_val.copy()
        x_val_temp[0:2, [1], [1, 0, 2]] = 2.0
        compare_values += [csdl_tests.TestingPair(x6, x_val_temp)]
        x6 = x.set(csdl.slice[0:2, 1:2], value = 2)
        x_val_temp = x_val.copy()
        x_val_temp[0:2, 1:2] = 2.0
        compare_values += [csdl_tests.TestingPair(x6, x_val_temp)]
        self.run_tests(compare_values = compare_values, verify_derivatives=True)
    def test_example(self,):
        '''
        Documentation example (between the docs:entry / docs:exit markers),
        followed by a value and derivative check of its results.
        '''
        self.prep()
        # docs:entry
        import csdl_alpha as csdl
        from csdl_alpha import slice
        import numpy as np
        recorder = csdl.build_new_recorder(inline = True)
        recorder.start()
        x = csdl.Variable(name = 'x', value = np.ones((3,2))*3.0)
        y = csdl.Variable(name = 'y', value = 2.0)
        z = csdl.Variable(name = 'z', value = np.ones((2,2))*2.0)
        # set a scalar slice with a scalar variable
        x1 = x.set(slice[0:-1], y)
        print(x1.value)
        # set a tensor slice with a scalar constant
        x2 = x.set(slice[0:-1], 2)
        print(x2.value)
        # set a tensor slice with a tensor variable
        x3 = x.set(slice[0:-1], z)
        print(x3.value)
        # docs:exit
        compare_values = []
        # all three examples set the first two rows of x to 2
        t = np.array([[2.,2.],[2.,2.],[3.,3.]])
        compare_values += [csdl_tests.TestingPair(x1, t)]
        compare_values += [csdl_tests.TestingPair(x2, t)]
        compare_values += [csdl_tests.TestingPair(x3, t)]
        self.run_tests(compare_values = compare_values, verify_derivatives=True)
if __name__ == '__main__':
    # Manual entry point: run the functionality test directly.
    # Uncomment the other calls to run the remaining tests as well.
    runner = TestSet()
    runner.test_functionality()
    # runner.test_example()
    # runner.test_derivs()