# Source code for csdl_alpha.src.operations.linear_combination

from csdl_alpha.src.operations.operation_subclasses import ComposedOperation, check_expand_subgraphs
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.typing import VariableLike
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import csdl_alpha.utils.testing_utils as csdl_tests
import csdl_alpha as csdl

from scipy import sparse as sp
import numpy as np

# TODO: make start_weights, stop_weights csdl variables
class LinearCombination(ComposedOperation):
    """Composed operation that linearly combines two variables over several steps.

    The two operands are passed to the ``ComposedOperation`` base class; the
    per-step weights and the step count are stored as attributes and read
    back by ``evaluate_composed``.
    """

    def __init__(self, start, stop, start_weights, stop_weights, num_steps):
        """Register the operands and remember the combination settings."""
        super().__init__(start, stop)
        self.name = 'linear_combination'
        self.num_steps = num_steps
        self.start_weights, self.stop_weights = start_weights, stop_weights

    def evaluate_composed(self, start, stop):
        """Evaluate the combination of ``start`` and ``stop`` with the stored settings."""
        settings = {
            'start_weights': self.start_weights,
            'stop_weights': self.stop_weights,
            'num_steps': self.num_steps,
        }
        return evaluate_linear_combination(start, stop, **settings)
    
def _build_step_map(weights, num_steps, num_per_step):
    """Return the sparse map that stacks one weighted identity block per step.

    Rows ``i*num_per_step`` to ``(i+1)*num_per_step`` of the result equal
    ``weights[i] * I``, so multiplying the map by a flattened tensor of
    length ``num_per_step`` yields ``num_steps`` scaled copies of that
    tensor stacked one after another.

    Parameters
    ----------
    weights : array-like
        One-dimensional sequence of per-step weights. Only the first
        ``num_steps`` entries are used.
    num_steps : int
        Number of stacked blocks.
    num_per_step : int
        Size of each identity block (number of tensor entries).

    Returns
    -------
    scipy.sparse.csc_matrix
        Matrix of shape ``(num_steps*num_per_step, num_per_step)``.

    Raises
    ------
    IndexError
        If fewer than ``num_steps`` weights are given.
    """
    # Cast to float so integer weights still produce a float64 map.
    column = np.asarray(weights, dtype=float)[:num_steps].reshape(-1, 1)
    if column.shape[0] < num_steps:
        raise IndexError(
            f'linear_combination needs {num_steps} weights but only {column.shape[0]} were given'
        )
    # kron(column, I) places weights[i] * I in the i-th block of rows.
    return sp.kron(column, sp.eye(num_per_step), format='csc')


def evaluate_linear_combination(start, stop, start_weights, stop_weights, num_steps):
    """Compute ``start_weights[i]*start + stop_weights[i]*stop`` for every step ``i``.

    Both tensors are flattened to column vectors, multiplied by sparse maps
    made of stacked weighted identity blocks, summed, and reshaped so that
    the step index becomes a new leading axis.

    Parameters
    ----------
    start : Variable
        First tensor; its shape defines the per-step output shape.
    stop : Variable
        Second tensor; it is reshaped to the same number of entries as
        ``start``.
    start_weights : array-like
        Per-step weights applied to ``start`` (at least ``num_steps`` entries).
    stop_weights : array-like
        Per-step weights applied to ``stop`` (at least ``num_steps`` entries).
    num_steps : int
        Number of combinations to generate.

    Returns
    -------
    Variable
        Result of shape ``(num_steps,) + start.shape``.

    Raises
    ------
    IndexError
        If either weight sequence has fewer than ``num_steps`` entries.
    """
    # int() keeps the size usable as a matrix dimension even when
    # np.prod returns a float (it does so for an empty shape tuple).
    num_per_step = int(np.prod(start.shape))

    map_start = _build_step_map(start_weights, num_steps, num_per_step)
    map_stop = _build_step_map(stop_weights, num_steps, num_per_step)

    flattened_start = start.reshape((num_per_step, 1))
    flattened_stop = stop.reshape((num_per_step, 1))
    mapped_start_array = csdl.sparse.matvec(map_start, flattened_start)
    mapped_stop_array = csdl.sparse.matvec(map_stop, flattened_stop)

    flattened_output = mapped_start_array + mapped_stop_array

    return flattened_output.reshape((num_steps,) + tuple(start.shape))

def linear_combination(start:VariableLike, stop:VariableLike, num_steps:int, start_weights:np.ndarray=None, stop_weights:np.ndarray=None)->Variable:
    """Elementwise linear combination of two tensors.

    For every step ``i`` the result holds
    ``start_weights[i]*start + stop_weights[i]*stop``. With the default
    weights this interpolates linearly from ``start`` to ``stop``.

    Parameters
    ----------
    start : VariableLike
        First tensor to be linearly combined.
    stop : VariableLike
        Second tensor to be linearly combined.
    num_steps : int
        Number of steps in the linear combination (at least 1).
    start_weights : np.ndarray, optional
        Weights for the first tensor, one per step. Defaults to
        ``np.linspace(1, 0, num_steps)``, or ``[0.5]`` when ``num_steps`` is 1.
    stop_weights : np.ndarray, optional
        Weights for the second tensor, one per step. Defaults to
        ``np.linspace(0, 1, num_steps)``, or ``[0.5]`` when ``num_steps`` is 1.

    Returns
    -------
    Variable
        Elementwise linear combination of the two tensors, with the steps
        stacked along a new leading axis.

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> start = csdl.Variable(value = np.array([1.0, 2.0, 3.0]))
    >>> stop = csdl.Variable(value = np.array([4.0, 5.0, 6.0]))
    >>> csdl.linear_combination(start, stop, 3).value
    array([[1. , 2. , 3. ],
           [2.5, 3.5, 4.5],
           [4. , 5. , 6. ]])
    """
    start = validate_and_variablize(start)
    stop = validate_and_variablize(stop)
    csdl.check_parameter(num_steps, 'num_steps', types=int, lower=1)

    # A single step has no interval to sweep, so it defaults to the midpoint.
    if start_weights is None:
        if num_steps == 1:
            start_weights = np.array([0.5])
        else:
            start_weights = np.linspace(1, 0, num_steps)
    if stop_weights is None:
        if num_steps == 1:
            stop_weights = np.array([0.5])
        else:
            stop_weights = np.linspace(0, 1, num_steps)

    if check_expand_subgraphs():
        return evaluate_linear_combination(start, stop, start_weights, stop_weights, num_steps)
    else:
        op = LinearCombination(start, stop, start_weights, stop_weights, num_steps)
        return op.finalize_and_return_outputs()
class TestLinearCombination(csdl_tests.CSDLTest):
    """Value checks for ``csdl.linear_combination``."""

    def test_functionality(self):
        """Compare default-weight results against NumPy references."""
        self.prep()

        # Distinct endpoints: with identical ones every set of weights
        # summing to one would pass, hiding swapped or reversed weights.
        x_val = np.array([1.0, 2.0, 3.0])
        y_val = np.array([4.0, 5.0, 6.0])
        x = csdl.Variable(value = x_val)
        y = csdl.Variable(value = y_val)

        compare_values = []
        ls = csdl.linear_combination(x, y, 5)
        ls_val = np.linspace(x_val, y_val, 5)
        compare_values += [csdl_tests.TestingPair(ls, ls_val, tag = 'ls')]

        # A single step uses weights of 0.5 for both tensors (the midpoint).
        mid = csdl.linear_combination(x, y, 1)
        mid_val = (0.5*x_val + 0.5*y_val).reshape((1, 3))
        compare_values += [csdl_tests.TestingPair(mid, mid_val, tag = 'mid')]

        self.run_tests(compare_values = compare_values)

    # TODO: fix this (ask Mark?)
    # def test_example(self,):
    #     self.docstest(linear_combination)