# Source code for csdl_alpha.src.operations.copyvar

from csdl_alpha.src.operations.operation_subclasses import ElementwiseOperation, ComposedOperation
from csdl_alpha.src.graph.operation import Operation, set_properties 
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import csdl_alpha.utils.testing_utils as csdl_tests
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.typing import VariableLike
import pytest

@set_properties(linear=True)
class CopyVar(ElementwiseOperation):
    """Elementwise operation whose output is a copy of its single input.

    The operation is registered with ``linear=True`` through
    ``set_properties``.
    """

    def __init__(self, x):
        """Build the operation with ``x`` as its only input."""
        super().__init__(x)
        self.name = 'copy'

    def compute_inline(self, x):
        """Return ``x.copy()`` as the inline value of the output."""
        duplicate = x.copy()
        return duplicate

    def compute_jax(self, x):
        """Return ``x + 0.0`` as the JAX expression for the output."""
        # presumably adding zero yields a fresh array rather than an alias
        # of x -- TODO confirm against the JAX backend's expectations
        return x + 0.0

    def evaluate_vjp(self, cotangents, x, y):
        """Accumulate the cotangent of output ``y`` into input ``x``.

        Nothing happens when ``cotangents.check(x)`` is falsy.
        """
        if not cotangents.check(x):
            return
        # A copy passes the output cotangent to the input unchanged.
        output_cotangent = cotangents[y]
        cotangents.accumulate(x, output_cotangent)

def copyvar(x: VariableLike) -> Variable:
    """Return a copy of the input variable x.

    Parameters
    ----------
    x : VariableLike
        The value to copy. It is passed through
        ``validate_and_variablize`` with ``raise_on_sparse=False``
        before the copy operation is built.

    Returns
    -------
    out : Variable
        A new variable that represents the same value as x.
    """
    source = validate_and_variablize(x, raise_on_sparse=False)
    copy_operation = CopyVar(source)
    return copy_operation.finalize_and_return_outputs()
@set_properties(linear=True, elementwise=True)
class CopyVarTo(Operation):
    """Operation that copies input ``x`` into an already existing variable.

    Unlike an operation that creates its own output, the output variable
    ``y`` is supplied by the caller and registered via ``set_outputs``.
    """

    def __init__(self, x, y):
        """Use ``x`` as the only input and ``y`` as the only output."""
        super().__init__(x)
        self.name = 'copyto'
        self.set_outputs([y])

    def compute_inline(self, x):
        """Return ``x.copy()`` as the inline value of the output."""
        duplicate = x.copy()
        return duplicate

    def compute_jax(self, x):
        """Return ``x + 0.0`` as the JAX expression for the output."""
        return x + 0.0

    def evaluate_vjp(self, cotangents, x, y):
        """Accumulate the cotangent of output ``y`` into input ``x``.

        Nothing happens when ``cotangents.check(x)`` is falsy.
        """
        if not cotangents.check(x):
            return
        output_cotangent = cotangents[y]
        cotangents.accumulate(x, output_cotangent)


def copyto(x: Variable, y: Variable) -> Variable:
    """Connect the existing variable x to y.

    Y MUST NOT BE COMPUTED FROM SOMEWHERE ELSE.

    Parameters
    ----------
    x : Variable
        Variable whose value is copied.
    y : Variable
        Existing variable in the active graph that receives the copy.

    Returns
    -------
    Variable
        The result of ``CopyVarTo(x, y).finalize_and_return_outputs()``.

    Raises
    ------
    ValueError
        If y is not in the active graph's node table, if y already has
        predecessors in the active graph, or if the shapes of x and y
        differ.
    """
    # Imported here rather than at module level, as in the original code.
    import csdl_alpha as csdl

    graph = csdl.get_current_recorder().active_graph

    # y has to be a node of the graph that is currently being recorded.
    if y not in graph.node_table:
        raise ValueError(f'y ({y.info()}) must be a variable in the current graph')

    # y must not already be produced by some other operation.
    predecessor_count = graph.in_degree(y)
    if predecessor_count > 0:
        raise ValueError(f'y ({y.info()}) must not be computed from an operation already. ({predecessor_count} predecessors)')

    # NOTE: the message text contains a line break before the two shapes.
    if y.shape != x.shape:
        raise ValueError(f'x and y must have the same shape. \n{x.shape} != {y.shape}')

    copy_operation = CopyVarTo(x, y)
    return copy_operation.finalize_and_return_outputs()


class TestCopy(csdl_tests.CSDLTest):
    """Tests for ``csdl.copyvar``."""

    def test_functionality(self,):
        """The copy of a (2, 5) variable must equal the original values."""
        self.prep()

        import csdl_alpha as csdl
        import numpy as np

        x_val = np.arange(10).reshape((2, 5))
        x = csdl.Variable(name='x', value=x_val)

        # Variables:
        z = csdl.copyvar(x)
        compare_values = [csdl_tests.TestingPair(z, x_val)]

        self.run_tests(compare_values=compare_values, verify_derivatives=True)


class TestCopyVarTo(csdl_tests.CSDLTest):
    """Tests for ``copyto``."""

    def test_functionality(self,):
        """After ``copyto(x, y)``, ``z = y + x`` is expected to equal 2 * x."""
        self.prep()

        import csdl_alpha as csdl
        import numpy as np

        x_val = np.arange(10).reshape((2, 5))
        x = csdl.Variable(name='x', value=x_val)
        y = csdl.Variable(name='y', value=x_val * 0.0)

        # z is recorded before y is connected to x.
        z = y + x
        copyto(x, y)

        # Re-run the recorded graph inline after the new connection.
        active_graph = csdl.get_current_recorder().active_graph
        active_graph.execute_inline()

        # Variables:
        compare_values = [csdl_tests.TestingPair(z, x_val * 2.0)]
        self.run_tests(compare_values=compare_values, verify_derivatives=True)

    def test_error(self,):
        """``copyto`` must raise ``ValueError`` for each invalid target."""
        self.prep()

        import csdl_alpha as csdl
        import numpy as np

        x_val = np.arange(10).reshape((2, 5))
        x = csdl.Variable(name='x', value=x_val)
        x_0 = csdl.Variable(name='x', value=x_val[0, 0])
        y = csdl.Variable(name='y', value=x_val * 0.0)
        z = y + x

        # z is already the output of the addition above.
        with pytest.raises(ValueError):
            copyto(x, z)

        # x_0 is built from a single element of x_val, so its shape
        # presumably differs from that of x.
        with pytest.raises(ValueError):
            copyto(x, x_0)

        # new_x is created while a different recorder is started, so it is
        # presumably not a node of the graph that is active afterwards.
        other_recorder = csdl.Recorder(inline=True)
        other_recorder.start()
        new_x = csdl.Variable(value=np.zeros((2, 5)))
        other_recorder.stop()

        with pytest.raises(ValueError):
            copyto(x, new_x)


if __name__ == '__main__':
    # Run the copyto tests directly, without the pytest runner.
    t = TestCopyVarTo()
    t.test_functionality()
    t.test_error()