Open In App

CSES Solutions – Counting Bits

Last Updated : 04 Apr, 2024
Improve
Improve
Like Article
Like
Save
Share
Report

Given an integer N (1 ≤ N ≤ 10^15), the task is to count the total number of one bits in the binary representations of all integers from 1 to N.

Examples:

Input: N = 7
Output: 12
Explanation:

Binary representations of 1: 001-> Set bits = 1
Binary representations of 2: 010-> Set bits = 1
Binary representations of 3: 011-> Set bits = 2
Binary representations of 4: 100-> Set bits = 1
Binary representations of 5: 101-> Set bits = 2
Binary representations of 6: 110-> Set bits = 2
Binary representations of 7: 111-> Set bits = 3

So, there are total of 1 + 1 + 2 + 1 + 2 + 2 + 3 = 12 one bits.

Input: N = 5
Output: 7
Explanation:

Binary representations of 1: 001-> Set bits = 1
Binary representations of 2: 010-> Set bits = 1
Binary representations of 3: 011-> Set bits = 2
Binary representations of 4: 100-> Set bits = 1
Binary representations of 5: 101-> Set bits = 2

So, there are total of 1 + 1 + 2 + 1 + 2 = 7 one bits.

Approach: To solve the problem, follow the idea below:

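If the input number N is of the form 2^b – 1 (e.g., 1, 3, 7, 15, ...), the total number of set bits in the numbers 1 to N is b * 2^(b-1). This is because the numbers 0 to 2^b – 1 are exactly all b-bit patterns: complementing every number (and reversing the list) gives back the same list, so in each of the b bit positions exactly half of the 2^b numbers, i.e. 2^(b-1) of them, have that bit set. 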
If N is not of that form, let m be the position of its leftmost set bit (positions start from 0). Bit m is set in every number from 2^m to N, so that position contributes n – (1 << m) + 1 = N – 2^m + 1 set bits. The remaining set bits are in two parts:

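  • The set bits in the lower m positions of the numbers 2^m to N. Removing bit m turns these numbers into 0 to N – 2^m, so this part equals the answer for N – 2^m and can be computed recursively, and 
  • The set bits of the 2^m numbers 0 to 2^m – 1, which is the closed form above with b = m, i.e. m * 2^(m-1).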

Let us consider the number 6: 

0|0 0
0|0 1
0|1 0
0|1 1
-|--
1|0 0
1|0 1
1|1 0

The leftmost set bit is in position 2 (positions are counted starting from 0). If we mask it off, what remains is 2 (the “1 0” in the right part of the last row). So the number of set bits in position 2 (the lower left box) is 3 (that is, 2 + 1). The number of set bits in the numbers 0–3 (the upper right box) is 2*2^(2-1) = 4. The lower right box holds the remaining bits we haven’t yet counted. It equals the number of set bits in all the numbers up to 2 (the value of the last entry in the lower right box), which can be computed recursively.  

Below is the implementation of the above algorithm:

C++
#include <bits/stdc++.h>
#define ll long long
using namespace std;

/* Index of the most significant set bit of n, where the least
   significant bit has index 0. Computed by counting how many right
   shifts it takes for n to drop to 1 or below, so any n <= 1
   (including 0 and negative values) yields 0. */
ll getLeftmostBit(ll n)
{
    ll position;
    for (position = 0; n > 1; ++position)
        n >>= 1;
    return position;
}

/* Given the position of the previous leftmost set bit in n (or any
   upper bound m on that position, 0 <= m <= 62), returns the position
   of the leftmost set bit of n, counting the rightmost position as 0.
   If m is already exact, it is returned unchanged.
   Returns -1 when n <= 0, since then no power of two 2^m (m >= 0)
   is <= n. */
long long getNextLeftmostBit(long long n, long long m)
{
    // 1LL keeps the shift in 64 bits. With a plain int 1, shifting by
    // m >= 31 is undefined behaviour, and inputs near the 10^15 limit
    // have their leftmost bit at position 49.
    // The m >= 0 check stops the walk for n <= 0 instead of looping forever.
    while (m >= 0 && n < (1LL << m))
        --m;
    return m;
}

// Returns the total number of set bits in the binary representations
// of all integers from 1 to n (0 when n <= 0).
//
// Let m be the position of the leftmost set bit of n (the rightmost
// position is 0). Then:
//   * the numbers 0 .. 2^m - 1 contribute m * 2^(m-1) set bits (each of
//     their m bit positions is set in exactly half of them);
//   * bit m is set in every number 2^m .. n, adding n - 2^m + 1;
//   * the low m bits of 2^m .. n run through 0 .. n - 2^m, so what is
//     left is the same problem for n - 2^m.
// That recursion is unrolled into a single pass over the bit positions
// from high to low: O(log n) time, O(1) extra space.
// All shifts and sums are 64-bit. The result fits in long long for
// n < 2^58, well beyond the problem limit of 10^15.
long long countSetBits(long long n)
{
    // Nothing to count below 1. This also avoids the non-terminating
    // bit search that negative values would otherwise trigger.
    if (n <= 0)
        return 0;

    // Position of the leftmost set bit of n.
    int m = 0;
    while ((n >> m) > 1)
        ++m;

    long long total = 0;
    for (; m >= 0; --m) {
        const long long bit = 1LL << m;  // 1LL: an int 1 overflows for m >= 31
        if (n & bit) {
            total += m * (bit >> 1);     // numbers 0 .. 2^m - 1 (0 when m == 0)
            n -= bit;                    // recurse on the low bits
            total += n + 1;              // bit m is set in n - 2^m + 1 numbers
        }
    }
    return total;
}

// Driver code
int main()
{
    ll n = 7;
    cout << countSetBits(n);
    return 0;
}
Java
import java.io.*;

public class GFG {

    // Returns the position of the leftmost set bit of n, counting the
    // rightmost position as 0. Values of n <= 1 yield 0.
    static long getLeftmostBit(long n) {
        long m = 0;
        while (n > 1) {
            n = n >> 1;
            m++;
        }
        return m;
    }

    // Given the position of the previous leftmost set bit in n (or any
    // upper bound m on it, 0 <= m <= 62), returns the position of the
    // leftmost set bit of n, or -1 when n <= 0.
    static long getNextLeftmostBit(long n, long m) {
        // 1L makes this a 64-bit shift. With an int literal Java masks the
        // shift distance to 5 bits, so 1 << 40 would silently equal 256.
        // The m >= 0 check stops the walk for n <= 0 instead of looping forever.
        while (m >= 0 && n < (1L << m)) {
            m--;
        }
        return m;
    }

    // Returns the total number of set bits in all numbers from 1 to n
    // (0 when n <= 0). Results fit in a long for n < 2^58.
    static long countSetBits(long n) {
        // Nothing to count below 1; also keeps non-positive values away
        // from the bit-position helpers.
        if (n <= 0)
            return 0;

        // Exact position m of the leftmost set bit of n.
        long m = getNextLeftmostBit(n, getLeftmostBit(n));

        // n = 2^(m+1) - 1 (1, 3, 7, 15, ...): closed form (m + 1) * 2^m.
        // 1L keeps both shifts 64-bit; an int shift breaks for m >= 31.
        if (n == (1L << (m + 1)) - 1)
            return (m + 1) * (1L << m);

        // Bit m is set in the n - 2^m + 1 numbers 2^m..n. Their low bits
        // repeat 0..n-2^m (recursive call), and 0..2^m-1 contribute
        // m * 2^(m-1). Here m >= 1, because m == 0 means n == 1, which the
        // closed form above already handled.
        n = n - (1L << m);
        return (n + 1) + countSetBits(n) + m * (1L << (m - 1));
    }

    // Driver code
    public static void main (String[] args) {
        long n = 7;
        System.out.println(countSetBits(n));
    }
}
Python
# Function to get the position of the leftmost set bit in n
def getLeftmostBit(n):
    """Return how many right shifts it takes for n to drop to 1 or below.

    For n >= 1 this is the index of the highest set bit, counting the
    lowest bit as position 0. Any n <= 1 (including 0 and negatives)
    returns 0.
    """
    position = 0
    while n > 1:
        n, position = n >> 1, position + 1
    return position

# Function to get the position of the next leftmost set bit in n
def getNextLeftmostBit(n, m):
    """Walk an upper bound m on n's highest set bit down until 2**m <= n.

    Returns the resulting m. If 2**m <= n already holds, m is returned
    unchanged. For n == 0 the walk ends at -1. For negative n the mask
    shrinks to 0, which stays greater than n, so the loop never ends.
    """
    mask = 1 << m
    while mask > n:
        m, mask = m - 1, mask >> 1
    return m

# Function to count the number of set bits in numbers from 1 to n
def countSetBits(n):
    """Return the total number of one bits in the binary forms of 1..n.

    Works from the most significant bit of n downwards. Whenever bit m
    of the remaining value is set:
      * the numbers 0 .. 2**m - 1 contribute m * 2**(m - 1) one bits
        (each of the m low positions is set in exactly half of them);
      * bit m is set in each of the numbers 2**m .. n, adding n - 2**m + 1;
      * the low m bits of 2**m .. n run through 0 .. n - 2**m, so the
        rest is the same problem for n - 2**m.
    Each bit position is visited once: O(log n) steps, no recursion.

    Returns 0 for n <= 0.
    """
    if n <= 0:
        return 0  # nothing to count; also keeps negative values out of the bit walk

    total = 0
    for m in range(n.bit_length() - 1, -1, -1):
        bit = 1 << m
        if n & bit:
            total += m * (bit >> 1)  # numbers 0 .. 2**m - 1 (0 when m == 0)
            n -= bit                 # continue with the low bits
            total += n + 1           # bit m of the numbers 2**m .. n
    return total

# Driver code
if __name__ == "__main__":
    # CSES compares standard output exactly, so read N without a prompt:
    # input("...") would print the prompt to stdout before the answer.
    n = int(input())
    print(countSetBits(n))  # Print the count of set bits in numbers from 1 to n
C#
using System;

class MainClass {
    // Returns the position of the leftmost set bit of n.
    // The rightmost position is considered as 0; values n <= 1 yield 0.
    static long GetLeftmostBit(long n) {
        long m = 0;
        while (n > 1) {
            n = n >> 1;
            m++;
        }
        return m;
    }

    // Given the position of the previous leftmost set bit in n (or any
    // upper bound m on it, 0 <= m <= 62), returns the position of the
    // leftmost set bit of n, or -1 when n <= 0.
    static long GetNextLeftmostBit(long n, long m) {
        // 1L keeps the shift in 64 bits. For an int operand C# masks the
        // shift count to 5 bits, so 1 << 40 would silently equal 256.
        // The m >= 0 check stops the walk for n <= 0 instead of looping forever.
        while (m >= 0 && n < (1L << (int)m)) {
            m--;
        }
        return m;
    }

    // Returns count of set bits present in all numbers from 1 to n
    // (0 when n <= 0). Results fit in a long for n < 2^58.
    static long CountSetBits(long n) {
        // Nothing to count below 1; also keeps non-positive values away
        // from the bit-position helpers.
        if (n <= 0)
            return 0;

        // Exact position m of the leftmost set bit of n.
        long m = GetNextLeftmostBit(n, GetLeftmostBit(n));

        // n = 2^(m+1) - 1 (1, 3, 7, 15, ...): closed form (m + 1) * 2^m.
        // 1L keeps both shifts 64-bit; an int shift breaks for m >= 31.
        if (n == (1L << (int)(m + 1)) - 1)
            return (m + 1) * (1L << (int)m);

        // Bit m is set in the n - 2^m + 1 numbers 2^m..n. Their low bits
        // repeat 0..n-2^m (recursive call), and 0..2^m-1 contribute
        // m * 2^(m-1). Here m >= 1, because m == 0 means n == 1, which the
        // closed form above already handled.
        n = n - (1L << (int)m);
        return (n + 1) + CountSetBits(n) + m * (1L << (int)(m - 1));
    }

    // Driver code
    public static void Main(string[] args) {
        long n = 7;
        Console.WriteLine(CountSetBits(n));
    }
}
// This code is contributed by Utkarsh.
JavaScript
function getLeftmostBit(n) {
    let m = 0;
    while (n > 1) {
        n = n >> 1;
        m++;
    }
    return m;
}

function getNextLeftmostBit(n, m) {
    let temp = 1 << m;
    while (n < temp) {
        temp = temp >> 1;
        m--;
    }
    return m;
}

function countSetBits(n) {
    if (n === 0)
        return 0;

    let m = getLeftmostBit(n);
    m = getNextLeftmostBit(n, m);

    if (n === (1 << (m + 1)) - 1)
        return (m + 1) * (1 << m);

    n = n - (1 << m);
    return (n + 1) + countSetBits(n) + m * (1 << (m - 1));
}

// Driver code: print the answer for the sample input N = 7.
const n = 7;
const answer = countSetBits(n);
console.log(answer);

Output
12

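Time Complexity: O((log N)^2) for the recursive implementations above. There are at most about log N recursive calls (one per set bit of N), and each call scans the bits of its argument to find the leftmost set bit. Walking the bit positions once from high to low instead gives O(log N).
Auxiliary Space: O(log N) for the recursion stack (O(1) for an iterative version).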



Like Article
Suggest improvement
Share your thoughts in the comments

Similar Reads