I am trying to solve the box stacking problem as quickly as possible (faster than O(n^2)). I came across a really nice piece of code that uses a binary indexed tree to solve it:
class BinaryIndexedTree:
    """Fenwick (binary indexed) tree that answers prefix-maximum queries.

    Callers address slots with 0-based indices; internally the tree is
    1-based, so every method shifts its index by one before walking the
    tree.  Every node starts at 0 and ``get`` never returns less than 0,
    so the structure is only meaningful for non-negative values.
    """

    def __init__(self, n):
        """Create a tree with ``n`` slots, every node initialised to 0."""
        self.n = n
        self.arr = [0] * (n + 1)

    def get(self, i):
        """Return the largest value recorded for slots ``0..i`` inclusive.

        A negative ``i`` denotes the empty prefix and yields 0.

        Raises:
            IndexError: if ``i >= n``.
        """
        if i >= self.n:
            # The walk below would index past the end of ``arr``; fail with
            # a message that names the offending index instead.
            raise IndexError(f"index {i} out of range for size {self.n}")
        best = 0
        pos = i + 1
        while pos > 0:
            best = max(best, self.arr[pos])
            pos -= pos & -pos  # drop the lowest set bit: next node of the prefix
        return best

    def set(self, i, x):
        """Record ``x`` at slot ``i``; nodes only ever grow (max-merge).

        An index ``i >= n`` touches no node and is silently ignored.

        Raises:
            IndexError: if ``i`` is negative.
        """
        if i < 0:
            # A negative index would drive ``pos`` to 0, where
            # ``pos & -pos == 0`` and the loop below never terminates.
            raise IndexError(f"negative index {i}")
        pos = i + 1
        while pos <= self.n:
            self.arr[pos] = max(self.arr[pos], x)
            pos += pos & -pos  # add the lowest set bit: next node covering slot i

    def unset(self, i):
        """Zero every node that ``set(i, ...)`` could have written.

        This is a coarse reset: the cleared nodes may also hold values that
        were recorded for other slots, and those are discarded as well.
        Unsetting every slot that was set returns the tree to all zeros.

        An index ``i >= n`` touches no node and is silently ignored.

        Raises:
            IndexError: if ``i`` is negative.
        """
        if i < 0:
            # Same non-termination hazard as in ``set``.
            raise IndexError(f"negative index {i}")
        pos = i + 1
        while pos <= self.n:
            self.arr[pos] = 0
            pos += pos & -pos
class Solution:
    """Solver for the "tallest stack of cuboids" problem."""

    def maxHeight(self, cuboids: List[List[int]]) -> int:
        """Return the greatest total height of a valid stack of cuboids.

        Each cuboid may be rotated freely.  Its dimensions are sorted so the
        largest one serves as its height, and cuboid A may rest on cuboid B
        only when every sorted dimension of A is <= the matching one of B.

        The answer is computed with a divide-and-conquer over the rows sorted
        by their smallest dimension; inside each merge step the rows are
        swept by their middle dimension while a prefix-maximum tree indexed
        by the rank of the largest dimension supplies the best stack that
        fits on top.  That is roughly O(n log^2 n) instead of the quadratic
        pairwise DP.

        Args:
            cuboids: one ``[a, b, c]`` triple per cuboid.  The input is not
                modified.  (Assumes exactly three dimensions per cuboid.)

        Returns:
            The maximum achievable height, or 0 when ``cuboids`` is empty.
        """
        if not cuboids:
            return 0

        # Work on private copies so the caller's lists stay untouched and
        # repeated calls on the same input give the same answer.
        dims = [sorted(c) for c in cuboids]

        # Compress the largest dimension to 0..k-1 so it can index the tree.
        rank_of = {v: r for r, v in enumerate(sorted({d[2] for d in dims}))}

        # Row layout: [small, middle, rank(large), large (= height), best]
        # where ``best`` is the tallest stack having this cuboid at the
        # bottom; initially the cuboid alone.
        rows = sorted([d[0], d[1], rank_of[d[2]], d[2], d[2]] for d in dims)

        bit = BinaryIndexedTree(len(rows))

        def solve(lo, hi):
            """Finalise ``best`` for rows[lo:hi]."""
            if hi - lo <= 1:
                return
            mid = (lo + hi) // 2

            # The left half must be final before it can feed the right half.
            solve(lo, mid)

            # Every left row already has small <= every right row's small
            # (global sort), so only the other two dimensions remain to be
            # checked: sweep by ``middle``, query the tree by rank(large).
            left = sorted(rows[lo:mid], key=lambda row: row[1])
            right = sorted(rows[mid:hi], key=lambda row: row[1])
            fed = 0
            for row in right:
                while fed < len(left) and left[fed][1] <= row[1]:
                    bit.set(left[fed][2], left[fed][4])
                    fed += 1
                row[4] = max(row[4], row[3] + bit.get(row[2]))

            # Wipe the tree so the next merge step starts from all zeros.
            for row in left:
                bit.unset(row[2])

            solve(mid, hi)

        solve(0, len(rows))
        return max(row[4] for row in rows)
The problem is that I do not understand the code, as I am not that familiar with Python (or, frankly speaking, with binary indexed trees either). Explain the algorithm.
2