I have some code that uses a lot of nested loops to perform some computation, and wanted to rewrite this code to use recursion for two reasons: 1) readability and 2) generalized functionality. The second reason is the most important as nested loops limit the use of this program to inputs that have a very specific length (7 in my case). However, I’m finding that the recursive implementation is slower than the iterative approach and would like to find potential solutions for this.
I have included the script to compare the two approaches below. The script is quite long but I wanted to address only a few functions here: find_combinations_recursive (called by find_combinations_arbitrary), find_combinations_iterative, _cartesian_product_arbitrary_rows_with_target_sum_recursive (called by _cartesian_product_arbitrary_rows_with_target_sum), _cartesian_product_seven_rows_with_target_sum (this is the iterative equivalent). According to my calculations, the recursive approach takes about 25 – 30% longer for the same inputs. Is there a way to speed up performance so that the recursive approach matches the speed of the iterative approach? The overall structure and algorithms used by the functions are quite similar so I’m guessing improvements to one of the recursive functions can be added to the other. Let me know if there is anything I can clarify!
# -*- coding: utf-8 -*-
import pickle
import numpy as np
from numba import njit, types, prange
from numba.typed import List
from numpy import array, int8
import time
#%% Define Functions
CACHE_FLAG = True
INT_TYPE = "int8"
NUMBA_INT_TYPE = getattr(types, INT_TYPE)
NUMPY_INT_TYPE = getattr(np, INT_TYPE)
# set type for segment length and indices
LEN_INT_TYPE = "int16"
NUMBA_LEN_INT_TYPE = getattr(types, LEN_INT_TYPE)
NUMPY_LEN_INT_TYPE = getattr(np, LEN_INT_TYPE)
@njit(cache=CACHE_FLAG)
def find_combinations_iterative(target, *arrays):
"""
Find all combinations of element indices from 7 arrays containing segment lengths that sum to the target update time.
Returns both the lengths that sum up to the target and the corresponding indices.
Args:
target (int): The target sum.
arrays (tuple of lists): Each list contains segment lengths.
Returns:
Tuple of two lists:
1. Lengths list - contains tuples of segment lengths that sum up to the target.
2. Indices list - contains tuples of indices in the original lists corresponding to the lengths.
"""
lengths_list = List()
indices_list = List()
for i in prange(len(arrays[0])):
sum_i = arrays[0][i]
if sum_i > target:
continue
for j in range(len(arrays[1])):
sum_j = sum_i + arrays[1][j]
if sum_j > target:
continue
for k in range(len(arrays[2])):
sum_k = sum_j + arrays[2][k]
if sum_k > target:
continue
for l in range(len(arrays[3])):
sum_l = sum_k + arrays[3][l]
if sum_l > target:
continue
for m in range(len(arrays[4])):
sum_m = sum_l + arrays[4][m]
if sum_m > target:
continue
for n in range(len(arrays[5])):
sum_n = sum_m + arrays[5][n]
if sum_n > target:
continue
for o in range(len(arrays[6])):
total = sum_n + arrays[6][o]
if total == target:
lengths_list.append(
(
arrays[0][i],
arrays[1][j],
arrays[2][k],
arrays[3][l],
arrays[4][m],
arrays[5][n],
arrays[6][o],
)
)
indices_list.append((i, j, k, l, m, n, o))
return lengths_list, indices_list
@njit(cache=CACHE_FLAG)
def find_combinations_arbitrary(list_, target):
'''
Find all combinations of segments lengths contained in list_ that sum to the target.
Parameters
----------
list_ : numba Typed List
numba Typed List containing arrays that contain candidate segment lengths.
target : int
target sum.
Returns
-------
results : Typed List of ListType[array(int64, 1d, C)]
Contains all combinations of segment lengths that sum up to target.
indices_results : Typed List of ListType[array(int64, 1d, C)]
Contains the indices of the segment lengths, as stored in the arrays contained in list_, that sum up to target.
'''
list_1 = List(list_)
current_combination = np.zeros(len(list_1), dtype=np.int64)
current_indices = np.zeros(len(list_1), dtype=np.int64)
results = List([current_combination])
indices_results = List([current_indices])
results.pop()
indices_results.pop() # just index out?
find_combinations_recursive(list_1, target, 0, current_combination, 0, current_indices, results, indices_results)
return results, indices_results
@njit(cache=CACHE_FLAG)
def find_combinations_recursive(lists, target, depth, current_combination, current_sum, current_indices, results, indices_results):
'''
Recursive implementation of the iterative approach used in find_combinations_iterative. Works for arbitrary number of segments.
Modifies the variables 'results' and 'indices_results' in place.
'''
if current_sum > target:
return
if depth == len(lists):
if np.sum(current_combination) == target:
results.append(current_combination.copy())
indices_results.append(current_indices.copy())
return
for i in range(len(lists[depth])):
current_combination[depth] = lists[depth][i]
current_indices[depth] = i
# if np.sum(current_combination) > target and depth < len(lists) - 1:
# if i + 1 < len(lists[depth]):
# find_combinations_recursive(lists, target, depth, current_combination, current_indices, results, indices_results, i + 1)
find_combinations_recursive(lists, target, depth + 1, current_combination, current_sum + lists[depth][i], current_indices, results, indices_results)
@njit(cache=CACHE_FLAG)
def _combine_valid_segments_arbitrary(
segments,
valid_combos_segment_lengths,
valid_combos_segment_idx,
target_IP,
target_UT,
second_from_end_voltage_choices=None,
):
"""
Generate all combinations of segments from N lists of segments, along with their lengths.
Args:
segments: The lists of segments to combine.
target_IP: The target sum that all values in a valid combination of segments must sum up to.
target_UT: The target length that all values in a valid combination of segments must produce.
second_from_end_voltage_choices: Specifies the values that the second from last element in a segment combination can assume. Defaults to None (no restriction)
Returns:
List[np.ndarray]: The combined list of segments.
List[List[int]]: The lengths of the segments in each combined array.
"""
all_valid_candidates = List()
all_valid_combo_lengths = List()
for combo_idx, valid_combo in enumerate(valid_combos_segment_idx):
seg_list = List()
for i_valid_combo_idx in range(len(valid_combo)):
# seg_list.append(segments[i_valid_combo_idx][i_valid_combo_idx])
seg_list.append(segments[i_valid_combo_idx][valid_combo[i_valid_combo_idx]])
# if valid_combos_segment_lengths[combo_idx][6] == 1 and second_from_end_voltage_choices is not None:
# combined_seg = _cartesian_product_seven_rows_with_target_sum(
# seg1, seg2, seg3, seg4, seg5, seg6, seg7, target_IP, allowed_values=second_from_end_voltage_choices
# )
combined_seg = _cartesian_product_arbitrary_rows_with_target_sum(seg_list, target_IP, target_UT, allowed_values = second_from_end_voltage_choices)
if combined_seg.size != 0:
all_valid_candidates.append(combined_seg)
all_valid_combo_lengths.append(valid_combos_segment_lengths[combo_idx])
# return all_valid_candidates
return all_valid_candidates, all_valid_combo_lengths
@njit(cache=CACHE_FLAG)
def _combine_valid_segments_seven(
s1,
s2,
s3,
s4,
s5,
s6,
s7,
valid_combos_segment_lengths,
valid_combos_segment_idx,
target_IP,
second_from_end_voltage_choices=None,
):
"""
Generate all combinations of segments from seven lists of segments, along with their lengths.
Args:
s1, s2, s3, s4, s5, s6, s7 (List[np.ndarray]): The lists of segments to combine.
target_IP: The target sum that all values in a valid combination of segments must sum up to.
valid_combos_segment_lengths: List of iterables containing valid segment length combinations
Returns:
List[np.ndarray]: The combined list of segments.
List[List[int]]: The lengths of the segments in each combined array.
"""
all_valid_candidates = List()
all_valid_combo_lengths = List()
for combo_idx, valid_combo in enumerate(valid_combos_segment_idx):
seg1 = s1[valid_combo[0]]
seg2 = s2[valid_combo[1]]
seg3 = s3[valid_combo[2]]
seg4 = s4[valid_combo[3]]
seg5 = s5[valid_combo[4]]
seg6 = s6[valid_combo[5]]
seg7 = s7[valid_combo[6]]
# if valid_combos_segment_lengths[combo_idx][6] == 1 and second_from_end_voltage_choices is not None:
combined_seg = _cartesian_product_seven_rows_with_target_sum(
seg1, seg2, seg3, seg4, seg5, seg6, seg7, target_IP, allowed_values=second_from_end_voltage_choices
)
if combined_seg.size != 0:
all_valid_candidates.append(combined_seg)
all_valid_combo_lengths.append(valid_combos_segment_lengths[combo_idx])
# return all_valid_candidates
return all_valid_candidates, all_valid_combo_lengths
@njit(cache=CACHE_FLAG)
def _cartesian_product_seven_rows_with_target_sum(
arr1, arr2, arr3, arr4, arr5, arr6, arr7, target_sum, allowed_values=None
):
rows = [arr.shape[0] for arr in [arr1, arr2, arr3, arr4, arr5, arr6, arr7]]
row_sums = [row_sum(arr) for arr in [arr2, arr4, arr6]] # only need to calculate sums for pulses and not gaps
result_list = List() # Using a Numba typed list
for i in range(rows[0]):
for j in range(rows[1]):
for k in range(rows[2]):
for l in range(rows[3]):
for m in range(rows[4]):
for n in range(rows[5]):
for o in range(rows[6]):
combined_sum = (
row_sums[0][j] + row_sums[1][l] + row_sums[2][n]
) # only need to calculate sums for pulses and not gaps
if combined_sum == target_sum:
combined_row = _concatenate_1d_arbitrary(
List([arr1[i], arr2[j], arr3[k], arr4[l], arr5[m], arr6[n], arr7[o]])
)
if allowed_values is None or combined_row[-2] in allowed_values:
result_list.append(combined_row)
# Determine the shape for the result array
if len(result_list) > 0:
result_shape = (len(result_list), len(result_list[0]))
else:
result_shape = (0, 0)
# Allocate the result array
result = np.empty(result_shape, dtype=NUMPY_INT_TYPE)
# Copy the elements from the list to the array
for idx, row in enumerate(result_list):
result[idx] = row
return result
@njit(cache=CACHE_FLAG)
def _cartesian_product_arbitrary_rows_with_target_sum(arrays, target, target_UT, allowed_values = None):
list_1 = List(arrays)
result_list = List([np.empty((target_UT,), dtype = NUMPY_INT_TYPE)])
result_list.pop()
rows = List()
row_sums = List()
for i_row in range(len(list_1)):
rows.append(list_1[i_row].shape[0])
if i_row % 2 == 0:
# gap pulse ignore
# row_sums.append(np.zeros(list_1[i_row].shape[0], dtype = NUMPY_LEN_INT_TYPE))
continue
else:
row_sums.append(row_sum(list_1[i_row]))
depth = 0
current_indices = np.zeros(len(list_1), dtype = np.int64)
_cartesian_product_arbitrary_rows_with_target_sum_recursive(list_1, depth, current_indices, rows, row_sums, result_list, target, allowed_values)
# result_list = _cartesian_product_arbitrary_rows_with_target_sum_iterative(list_1, rows, row_sums, target, allowed_values, result_list, current_indices)
# result_list = result_list[1:]
if len(result_list) > 0:
# result_list = result_list[1:]
result_shape = (len(result_list), len(result_list[0]))
else:
result_shape = (0, 0)
result = np.empty(result_shape, dtype=NUMPY_INT_TYPE)
# Copy the elements from the list to the array
for idx, row in enumerate(result_list):
result[idx] = row
return result
@njit(cache=CACHE_FLAG)
def _cartesian_product_arbitrary_rows_with_target_sum_recursive(arrs, depth, current_indices, rows, row_sums, result_list, target, allowed_values):
if depth == len(rows):
if _compute_combined_sum(row_sums, current_indices) == target:
combined_row = _compute_combined_row(arrs, current_indices)
if allowed_values is None or combined_row[-2] in allowed_values:
result_list.append(combined_row)
return
for i in range(rows[depth]):
current_indices[depth] = i
_cartesian_product_arbitrary_rows_with_target_sum_recursive(arrs, depth + 1, current_indices, rows, row_sums, result_list, target, allowed_values)
@njit(cache=CACHE_FLAG)
def _compute_combined_sum(row_sums, current_indices):
combined_sum = 0
tracker_idx = 1
for k_idx in range(len(current_indices)):
if k_idx % 2 == 0:
continue
else:
combined_sum = combined_sum + row_sums[k_idx - tracker_idx][current_indices[k_idx]]
tracker_idx += 1
return combined_sum
@njit(cache=CACHE_FLAG)
def _compute_combined_row(arrs, current_indices):
list_of_arrs = List()
for k_idx in range(len(current_indices)):
list_of_arrs.append(arrs[k_idx][current_indices[k_idx]])
combined_row = _concatenate_1d_arbitrary(list_of_arrs)
return combined_row
@njit(cache=CACHE_FLAG)
def row_sum(arr):
result = np.zeros(arr.shape[0], dtype=NUMPY_LEN_INT_TYPE)
for i in range(arr.shape[0]):
for j in range(arr.shape[1]):
result[i] += arr[i, j]
return result
@njit(cache=CACHE_FLAG)
def _concatenate_1d_arbitrary(arrays):
len_list = List()
for i in range(len(arrays)):
len_list.append(len(arrays[i]))
total_length = 0
for i in range(len(len_list)):
total_length = total_length + len_list[i]
result = np.empty(total_length, dtype=arrays[0].dtype)
current_idx = 0
for arr in arrays:
arr_len = len(arr)
result[current_idx : current_idx + arr_len] = arr
current_idx += arr_len
return result
#%% Sample data
s1_lens = array([0, 1])
s2_lens = array([0, 1, 2, 3])
s3_lens = array([0, 1])
s4_lens = array([0, 1, 2, 3])
s5_lens = array([0, 1, 2, 3, 4])
s6_lens = array([0, 1, 2])
s7_lens = array([1, 2, 3])
s1 = [array([[]], dtype=int8), array([[0]], dtype=int8)]
s2 = [array([[]], dtype=int8),
array([[-24]], dtype=int8),
array([[ -6, -24],
[-12, -24],
[-24, -24]], dtype=int8),
array([[ -6, -6, -24],
[ -6, -12, -24],
[ -6, -24, -24],
[-12, -12, -24],
[-12, -24, -24],
[-24, -24, -24]], dtype=int8)]
s3 = [array([[]], dtype=int8), array([[0]], dtype=int8)]
s4 = [array([[]], dtype=int8),
array([[24]], dtype=int8),
array([[ 6, 24],
[12, 24],
[24, 24]], dtype=int8),
array([[ 6, 24, 24],
[12, 24, 24],
[24, 24, 24]], dtype=int8)]
s5 = [array([[]], dtype=int8),
array([[0]], dtype=int8),
array([[0, 0]], dtype=int8),
array([[0, 0, 0]], dtype=int8),
array([[0, 0, 0, 0]], dtype=int8)]
s6 = [array([[]], dtype=int8),
array([[ -6],
[-12],
[-24]], dtype=int8),
array([[ -6, -6],
[ -6, -12],
[ -6, -24],
[-12, -6],
[-12, -12],
[-12, -24],
[-24, -6],
[-24, -12],
[-24, -24]], dtype=int8)]
s7 = [array([[0]], dtype=int8),
array([[0, 0]], dtype=int8),
array([[0, 0, 0]], dtype=int8)]
s_lens_list = List([s1_lens, s2_lens, s3_lens, s4_lens, s5_lens, s6_lens, s7_lens])
s_list = List([s1, s2, s3, s4, s5, s6, s7])
#%%
target_IP_Vfr = 0
update_time_fr = 7
#%%
start_time_find_combinations_recursive = time.time()
valid_combos_segment_lengths_arb, valid_combos_segment_idx_arb = find_combinations_arbitrary(s_lens_list, update_time_fr)
end_time_find_combinations_recursive = time.time()
print(fr'Time taken to generate valid combinations with recursive approach: {end_time_find_combinations_recursive - start_time_find_combinations_recursive}')
#%%
start_time_find_combinations_iterative = time.time()
valid_combos_segment_lengths_seven, valid_combos_segment_idx_seven = find_combinations_iterative(update_time_fr, s1_lens, s2_lens, s3_lens, s4_lens, s5_lens, s6_lens, s7_lens)
end_time_find_combinations_iterative = time.time()
print(fr'Time taken to generate valid combinations with iterative approach: {end_time_find_combinations_iterative - start_time_find_combinations_iterative}')
#%%
start_time_iterative_seven = time.time()
# all_valid_candidates, all_valid_combo_lengths = _combine_valid_segments_seven(
# s1,
# s2,
# s3,
# s4,
# s5,
# s6,
# s7,
# valid_combos_segment_lengths_seven,
# valid_combos_segment_idx_seven,
# target_IP_Vfr,
# second_from_end_voltage_choices=None,
# )
all_valid_candidates, all_valid_combo_lengths = _combine_valid_segments_seven(
List(s1),
List(s2),
List(s3),
List(s4),
List(s5),
List(s6),
List(s7),
valid_combos_segment_lengths_seven,
valid_combos_segment_idx_seven,
target_IP_Vfr,
second_from_end_voltage_choices=None,
)
end_time_iterative_seven = time.time()
print(fr'Time taken to generate valid candidates with iterative approach: {end_time_iterative_seven - start_time_iterative_seven}')
#%%
start_time_recursive_arbitrary = time.time()
all_valid_candidates_arb, all_valid_combo_lengths_arb = _combine_valid_segments_arbitrary(
s_list,
valid_combos_segment_lengths_arb,
valid_combos_segment_idx_arb,
target_IP_Vfr,
update_time_fr,
second_from_end_voltage_choices=None,
)
end_time_recursive_arbitrary = time.time()
print(fr'Time taken to generate valid candidates with recursive approach: {end_time_recursive_arbitrary - start_time_recursive_arbitrary}')
#%%
total_time_iterative = (end_time_iterative_seven - start_time_iterative_seven) + (end_time_find_combinations_iterative - start_time_find_combinations_iterative)
total_time_recursive = (end_time_recursive_arbitrary - start_time_recursive_arbitrary) + (end_time_find_combinations_recursive - start_time_find_combinations_recursive)
time_diff = total_time_recursive - total_time_iterative
percent_diff = ((total_time_recursive - total_time_iterative) / (total_time_iterative))*100 # assume iterative case is reference
print(fr'Time difference: {time_diff}')
print(fr'Percent difference: {percent_diff}')
4