https://leetcode.com/problems/booking-concert-tickets-in-groups/
如何把一个问题转化成用线段树解决的问题,这个题是一个很好的示范。
个人解答
class Node:
def __init__(self, v = 0):
self.left = None
self.right = None
self.max = v
self.sum = v
class SegTree:
def __init__(self, n, initVal):
self.initVal = initVal
self.n = n
self.tree = self._build(0, n - 1, initVal)
def _build(self, l, r, initVal):
if l == r:
return Node(initVal)
node = Node()
mid = (l + r) // 2
node.left = self._build(l, mid, initVal)
node.right = self._build(mid + 1, r, initVal)
node.max = max(node.left.max, node.right.max)
node.sum = node.left.sum + node.right.sum
return node
def _decrease(self, node, l, r, i, v):
if l == r:
node.max -= v
node.sum -= v
return
mid = (l + r) // 2
if i <= mid:
self._decrease(node.left, l, mid, i, v)
else:
self._decrease(node.right, mid + 1, r, i, v)
node.max = max(node.left.max, node.right.max)
node.sum = node.left.sum + node.right.sum
def decrease(self, i, v):
self._decrease(self.tree, 0, self.n - 1, i, v)
def _queryAvaliable(self, node, l, r, target, maxRow):
if l > maxRow or node.max < target:
return []
if l == r:
# 要返回所在行以及当前行第一个空的index
return [l, self.initVal - node.max]
mid = (l + r) // 2
if node.left.max >= target:
return self._queryAvaliable(node.left, l, mid, target, maxRow)
return self._queryAvaliable(node.right, mid + 1, r, target, maxRow)
def queryAvaliable(self, target, maxRow):
return self._queryAvaliable(self.tree, 0, self.n - 1, target, maxRow)
def _querySum(self, node, l, r, maxRow):
if l > maxRow:
return 0
# 区间小于要请求的区间,直接返回!
if r <= maxRow or l == r:
return node.sum
mid = (l + r) // 2
return self._querySum(node.left, l, mid, maxRow) + self._querySum(node.right, mid + 1, r, maxRow)
def querySum(self, maxRow):
return self._querySum(self.tree, 0, self.n - 1, maxRow)
class BookMyShow:
def __init__(self, n: int, m: int):
self.tree = SegTree(n, m)
# 记录每行剩下多少,以及第一个remian > 0的行,用来scatter的时候,选取合适的位置更新
self.remain = [m] * n
self.n = n
self.remainIdx = 0
def gather(self, k: int, maxRow: int) -> List[int]:
res = self.tree.queryAvaliable(k, maxRow)
if res and res[0] >= 0:
self.tree.decrease(res[0], k)
self.remain[res[0]] -= k
return res
def scatter(self, k: int, maxRow: int) -> bool:
if self.tree.querySum(maxRow) < k:
return False
# 需要从最小的地方开始更新
i = self.remainIdx
while k:
if k >= self.remain[i]:
# remain为0
self.tree.decrease(i, self.remain[i])
k -= self.remain[i]
self.remain[i] = 0
i += 1
else:
self.tree.decrease(i, k)
self.remain[i] -= k
k = 0
self.remainIdx = i
return True
# Your BookMyShow object will be instantiated and called as such:
# obj = BookMyShow(n, m)
# param_1 = obj.gather(k,maxRow)
# param_2 = obj.scatter(k,maxRow)
题目分析
问题转化
本题看似是两个互不相干的操作,并不是常见的线段树查询/修改的典型模式。但仔细一看,不管是scatter还是gatter,都对应了两个行为:
- 查询一个符合条件的区间
- 修改一个/多个值
涉及到反复的修改和查询,线段树很适合了。
再想一想怎么利用区间状态:
- gather需要一个连续的空间,其实我们要找到剩余空间大于给定值的行,那么区间信息可以存区间里的最大remain,这样可以快速提前拒绝
- cluster可以分散,但题目的要求可以得知,分配是连续的,也就是连续区间的sum满足即可,那么区间也需要记录sum值
线段树
线段树的实现,其实不用想着太多的实现trick,就记住核心的思想,然后老老实实用指针实现就可以。
- 叶节点存具体的值,根节点存区间的值,不管是max/min/sum
- 每次查询/更新都是从root开始,二分处理
- 注意提前返回,如querySum的时候,根据区间判断是否可以直接返回