class SegmentTemplate:
    """A single node of a segment tree that supports range flips and counting.

    The node covers the inclusive index range ``[start, end]``.  ``count`` is
    the number of set positions in that range as seen from below this node;
    while ``pendingUpdate`` is True the whole range is additionally flipped,
    so the node's effective count is ``end - start + 1 - count``.
    """

    def __init__(self) -> None:
        self.start = 0  # first index covered (inclusive)
        self.end = 0  # last index covered (inclusive)
        self.count = 0  # set positions in [start, end], ignoring this node's own pending flip
        self.pendingUpdate = False  # True while a flip of the whole range is not folded into count

    def _effectiveCount(self) -> int:
        """Return ``count`` with this node's pending flip (if any) applied."""
        if self.pendingUpdate:
            return self.end - self.start + 1 - self.count
        return self.count

    def merge(self, leftSub: "SegmentTemplate", rightSub: "SegmentTemplate") -> None:
        """Set ``count`` to the sum of the two children's effective counts.

        The previous ``count`` is overwritten rather than added to, so a node
        can be merged again after its children change.  ``start``, ``end`` and
        ``pendingUpdate`` of this node are not modified.
        """
        self.count = leftSub._effectiveCount() + rightSub._effectiveCount()

    def query(self) -> int:
        """Return the stored ``count``; any pending flip is *not* applied here."""
        return self.count

    def hasPendingUpdate(self) -> bool:
        """Return True if a flip of the whole range has not been folded into ``count``."""
        return self.pendingUpdate

    def applyPendingUpdate(self) -> None:
        """Complement ``count`` over ``[start, end]`` and clear the pending flag.

        The complement is applied unconditionally, so callers should check
        ``hasPendingUpdate()`` first.  The node holds no child references, so
        nothing is propagated downwards.
        """
        self.count = self.end - self.start + 1 - self.count
        self.pendingUpdate = False

    def addUpdate(self) -> None:
        """Record one more flip of the whole range; two flips cancel out."""
        self.pendingUpdate = not self.pendingUpdate
class SegmentTree:
    """Segment tree over positions ``0 .. N-1`` supporting range flip and range count.

    ``update(start, end)`` flips every position in the inclusive range
    ``[start, end]``; ``query(start, end)`` returns how many positions in that
    range are currently set.  All positions start unset (every count is 0).

    Flips are kept as flags on the nodes whose range is exactly covered by an
    update; the nodes above them recompute their counts from their children.
    A query combines child results and applies each ancestor's flag to the
    partial answer on the way back up, so queries never modify the tree.
    """

    def __init__(self, N: int) -> None:
        """Build a tree over ``N`` positions.

        Raises:
            ValueError: if ``N`` is smaller than 1.
        """
        if N < 1:
            raise ValueError("N must be a positive integer, got {}".format(N))
        self.treeBuild = [
            SegmentTemplate() for _ in range(self._getSegmentTreeSize(N))
        ]
        self.N = N
        self._buildTree(1, 0, N - 1)

    def query(self, start: int, end: int) -> int:
        """Return the number of set positions in the inclusive range ``[start, end]``.

        Raises:
            IndexError: unless ``0 <= start <= end < N``.
        """
        self._checkRange(start, end)
        return self._query(1, start, end).query()

    def update(self, start: int, end: int) -> None:
        """Flip every position in the inclusive range ``[start, end]``.

        Raises:
            IndexError: unless ``0 <= start <= end < N``.
        """
        self._checkRange(start, end)
        self._update(1, start, end)

    def _checkRange(self, start: int, end: int) -> None:
        """Raise IndexError unless ``[start, end]`` is a non-empty range inside the tree."""
        if not 0 <= start <= end < self.N:
            raise IndexError(
                "range [{}, {}] is invalid for a tree of size {}".format(start, end, self.N)
            )

    def _buildTree(self, stIndex: int, start: int, end: int) -> None:
        """Assign ``[start, end]`` to node ``stIndex`` and recursively build its children.

        Children of node ``i`` live at ``2*i`` and ``2*i + 1``; the root is 1.
        """
        node = self.treeBuild[stIndex]
        node.start = start
        node.end = end
        if start == end:
            return
        mid = (start + end) // 2
        leftChildIndex = 2 * stIndex
        rightChildIndex = leftChildIndex + 1
        self._buildTree(leftChildIndex, start, mid)
        self._buildTree(rightChildIndex, mid + 1, end)
        node.count = 0
        node.merge(self.treeBuild[leftChildIndex], self.treeBuild[rightChildIndex])

    def _getSegmentTreeSize(self, N: int) -> int:
        """Return twice the smallest power of two that is >= ``N``.

        That is enough slots for a 1-rooted tree whose children of ``i`` are
        ``2*i`` and ``2*i + 1``.
        """
        size = 1
        while size < N:
            size <<= 1
        return size << 1

    def _query(self, stIndex: int, start: int, end: int) -> SegmentTemplate:
        """Return a fresh node describing ``[start, end]`` inside node ``stIndex``'s range.

        The returned object is never one stored in ``treeBuild``.  Its
        ``start``/``end`` equal the requested range, its ``count`` already
        includes every flip recorded at or below ``stIndex``, and its
        ``pendingUpdate`` is False.
        """
        node = self.treeBuild[stIndex]
        if node.start == start and node.end == end:
            # Copy instead of returning the stored node: folding the flag into
            # the stored node would lose the flip for later operations that
            # descend below it, and callers may flip the result further.
            result = SegmentTemplate()
            result.start = start
            result.end = end
            result.count = node.count
            if node.hasPendingUpdate():
                result.addUpdate()
                result.applyPendingUpdate()
            return result

        mid = (node.start + node.end) // 2
        leftChildIndex = 2 * stIndex
        rightChildIndex = leftChildIndex + 1
        if start > mid:
            result = self._query(rightChildIndex, start, end)
        elif end <= mid:
            result = self._query(leftChildIndex, start, end)
        else:
            leftResult = self._query(leftChildIndex, start, mid)
            rightResult = self._query(rightChildIndex, mid + 1, end)
            result = SegmentTemplate()
            result.start = leftResult.start
            result.end = rightResult.end
            result.merge(leftResult, rightResult)

        # A flip recorded on this node covers all of [start, end], so
        # complement the partial answer over that range.  ``result`` is always
        # a fresh object here, so the tree itself is not modified.
        if node.hasPendingUpdate():
            result.addUpdate()
            result.applyPendingUpdate()
        return result

    def _update(self, stIndex: int, start: int, end: int) -> None:
        """Flip ``[start, end]`` (a sub-range of node ``stIndex``) and refresh counts above it."""
        node = self.treeBuild[stIndex]
        if node.start == start and node.end == end:
            # Whole node covered: toggle its flag only (two flips cancel).
            node.addUpdate()
            return
        mid = (node.start + node.end) // 2
        leftChildIndex = 2 * stIndex
        rightChildIndex = leftChildIndex + 1
        if start > mid:
            self._update(rightChildIndex, start, end)
        elif end <= mid:
            self._update(leftChildIndex, start, end)
        else:
            self._update(leftChildIndex, start, mid)
            self._update(rightChildIndex, mid + 1, end)
        # Recompute this node's count from its children.  Reset first so the
        # new value never includes the stale total.  The node's own pending
        # flag is kept on purpose: it still applies on top of the children.
        node.count = 0
        node.merge(self.treeBuild[leftChildIndex], self.treeBuild[rightChildIndex])
if __name__ == "__main__":
    # Small demo on 7 positions: apply two range flips, and after each one
    # print how many positions in [1, 6] are set (expected output: 4, then 3).
    N = 7
    st = SegmentTree(N)
    for flip_lo, flip_hi in ((2, 5), (1, 3)):
        st.update(flip_lo, flip_hi)
        print(st.query(1, 6))