0/1 Knapsack Problem

Recursive, memoized, tabulated and space-optimized Dynamic Programming solutions in the Rust language.

Introduction

The 0/1 Knapsack Problem is a classic Dynamic Programming problem. In this problem, you are a thief :D

You are given a list of items, each with a weight and the profit you earn by selling it. But your knapsack has a limited capacity, say W, so the total weight of the items you take cannot exceed W.

Items are indivisible: you either take an item whole (1) or leave it (0), and each item can be taken only once.

You have to maximize the total profit of the items you carry in the knapsack.

For Example :

  • Item Weight = [1, 4, 5, 7]
  • Profits = [3, 6, 9, 11]
  • W = 11

Answer : 18 (taking the items with weights [1, 4, 5], total weight 10, profits 3 + 6 + 9)
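
To double-check the example by hand, here is a minimal sketch that tries every subset of items and keeps the most profitable one that fits. The name `best_by_enumeration` is made up for this illustration, and the sketch assumes fewer than 32 items because it uses a `u32` bitmask.

```rust
/// Tries all 2^n subsets and returns (best profit, weights of the chosen items).
/// Illustration only: assumes weights.len() == profits.len() and fewer than 32 items.
fn best_by_enumeration(w: usize, weights: &[usize], profits: &[usize]) -> (usize, Vec<usize>) {
    let n = weights.len();
    let mut best = (0, Vec::new());
    for mask in 0u32..(1u32 << n) {
        // Bit i of the mask says whether item i is in the subset
        let chosen: Vec<usize> = (0..n).filter(|&i| (mask >> i) & 1 == 1).collect();
        let total_w: usize = chosen.iter().map(|&i| weights[i]).sum();
        let total_p: usize = chosen.iter().map(|&i| profits[i]).sum();
        // Keep the subset only if it fits and beats the best one so far
        if total_w <= w && total_p > best.0 {
            best = (total_p, chosen.iter().map(|&i| weights[i]).collect());
        }
    }
    best
}
```

For the example above, `best_by_enumeration(11, &[1, 4, 5, 7], &[3, 6, 9, 11])` returns `(18, vec![1, 4, 5])`.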

Recursive Solution

The recursive solution is pretty simple. For each item, we can either include it in the knapsack or leave it out.

So, we simply check all the possibilities and return the maximum profit that we can obtain. We use the following algorithm.

Algorithm

  1. If the remaining capacity of the knapsack is less than the weight of the given item, we can not carry it. Hence, skip it and check the other items.
  2. Otherwise, we first try to include the item in the knapsack, and recursively compute the profit of the remaining items with the remaining capacity.
  3. Then we exclude the item, and recursively compute the profit of the remaining items with the same capacity.
  4. We return the maximum of the two profits as the answer.

Function

Here is the function using the above algorithm:

use std::cmp::max;
fn knapsack(w: usize, weights: &Vec<usize>, profits: &Vec<usize>, n: usize) -> usize {
    // If we have 0 elements remaining or the knapsack is full, return 0
    if n == 0 || w == 0 {
        return 0;
    }
    // If the nth element has higher weight than available capacity,
    // it can not be carried. So, return without including item
    if weights[n-1] > w {
        return knapsack(w, weights, profits, n-1);
    }
    // Else, we check by including and excluding the given item
    // And return max of it
    return max(
        // If we exclude item, simply return function for n-1 items
        knapsack(w, weights, profits, n-1),
        // If we include item, return profit of given item +
        // maximum value from the remaining capacity for remaining items
        profits[n-1] + knapsack(w-weights[n-1], weights, profits, n-1),
    );
}

Input

1 4 5 7
3 6 9 11
11

Output

18

Time Complexity : O( 2^n )
Space Complexity : O( n )

( Space complexity includes recursive stack space )
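
To make the exponential growth visible, here is a small sketch that mirrors the branching of the recursive function but only counts calls. The name `count_calls` is hypothetical; profits are left out because they do not affect how many calls are made.

```rust
/// Counts how many times the plain recursion would be invoked.
/// It follows the same three cases as the recursive knapsack function.
fn count_calls(w: usize, weights: &[usize], n: usize) -> u64 {
    // Base case: one call, no further recursion
    if n == 0 || w == 0 {
        return 1;
    }
    // Item does not fit: only the "exclude" branch is explored
    if weights[n - 1] > w {
        return 1 + count_calls(w, weights, n - 1);
    }
    // Item fits: both the "exclude" and the "include" branch are explored
    1 + count_calls(w, weights, n - 1) + count_calls(w - weights[n - 1], weights, n - 1)
}
```

With ten items of weight 1 and a capacity of 11, every call branches twice, and `count_calls(11, &[1; 10], 10)` returns 2047, which is 2^11 - 1.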

Overlapping Sub-problems

If we look carefully at the recursive approach, we see that it computes the same results many times.

For Example : If weights are [1, 2, 3, 4, 5, 6, 7, 8] and W = 20,

We have to compute the best answer for n = 2 and W = 5 multiple times: it is reached after selecting [8, 7], [8, 4, 3], [7, 5, 3] or [6, 5, 4], since each of these selections weighs 15 and leaves a capacity of 5. Each time, the recursion explores including and excluding items 2 and 1 all over again. If we store the result in a matrix, we do not have to calculate it again and again.

These are called Overlapping Sub-problems, because they are smaller parts of the larger problem that are computed again and again.

So, we simply calculate each of them once, store the result in a matrix, and retrieve it when necessary.

This is called the Dynamic Programming approach to the problem.
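
The repetition can be measured with a short sketch. The helper `count_visits` below is an assumption made for illustration: it follows the same branching as the recursive function, but only counts how often one particular sub-problem (n, w) is reached.

```rust
/// Counts how often the plain recursion reaches the sub-problem `target = (n, w)`.
/// Profits are omitted because they do not influence which states are visited.
fn count_visits(w: usize, weights: &[usize], n: usize, target: (usize, usize)) -> u32 {
    // 1 if this call is the sub-problem we are looking for, else 0
    let here = if (n, w) == target { 1 } else { 0 };
    if n == 0 || w == 0 {
        return here;
    }
    // Item does not fit: only the "exclude" branch
    if weights[n - 1] > w {
        return here + count_visits(w, weights, n - 1, target);
    }
    // Item fits: "exclude" branch plus "include" branch
    here + count_visits(w, weights, n - 1, target)
        + count_visits(w - weights[n - 1], weights, n - 1, target)
}
```

For the weights [1, 2, 3, 4, 5, 6, 7, 8] and W = 20, `count_visits(20, &[1, 2, 3, 4, 5, 6, 7, 8], 8, (2, 5))` returns 4, one visit for each of the four selections listed above.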

Memoization ( Top-down DP ) Method

In the memoization method, we keep a DP matrix indexed by the number of items n and the remaining capacity w, and store each computed result in it.

Algorithm

  1. Initially, set all the elements of the DP matrix to -1. (Do not use 0 as the marker, because the result of some sub-problems is 0 itself, so they would be computed again and again.)
  2. If the profit for the given weight and the given n is not -1 in the DP matrix, return the stored value.
  3. Else, compute the profit using recursion and store it in the DP matrix at the given weight and the given n.

Function

Here is the function using the above algorithm:

use std::cmp::max;
fn knapsack(w: usize, weights: &Vec<usize>, profits: &Vec<usize>, n: usize, dp: &mut Vec<Vec<i64>>) -> i64 {
    // If we have 0 elements remaining or the knapsack is already full, return 0
    if n == 0 || w == 0 {
        dp[n][w] = 0;
        return 0;
    }
    // If already calculated result earlier, return it
    if dp[n][w] != -1 {
        return dp[n][w];
    }
    // If the nth element has higher weight than available capacity,
    // it can not be carried. So, return without including item
    if weights[n-1] > w {
        dp[n][w] = knapsack(w, weights, profits, n-1, dp);
        return dp[n][w];
    }
    // Else, we check by including and excluding the given item
    // And return max of it
    dp[n][w] = max(
        // If we exclude item, simply return function for n-1 items
        knapsack(w, weights, profits, n-1, dp),
        // If we include item, return profit of given item +
        // maximum value from the remaining capacity for remaining items
        profits[n-1] as i64 + knapsack(w-weights[n-1], weights, profits, n-1, dp),
    );
    return dp[n][w];
}

With Driver code

use std::cmp::max;
fn knapsack(w: usize, weights: &Vec<usize>, profits: &Vec<usize>, n: usize, dp: &mut Vec<Vec<i64>>) -> i64 {
    // If we have 0 elements remaining or knapsack filled, return 0
    if n == 0 || w == 0 {
        dp[n][w] = 0;
        return 0;
    }
    // If already calculated result, return it
    if dp[n][w] != -1 {
        return dp[n][w];
    }
    // If the nth element has higher weight than available capacity,
    // it can not be carried. So, return without including item
    if weights[n-1] > w {
        dp[n][w] = knapsack(w, weights, profits, n-1, dp);
        return dp[n][w];
    }
    // Else, we check by including and excluding the given item
    // And return max of it
    dp[n][w] = max(
        // If we exclude item, simply return function for n-1 items
        knapsack(w, weights, profits, n-1, dp),
        // If we include item, return profit of given item +
        // maximum value from given weight for remaining items
        profits[n-1] as i64 + knapsack(w-weights[n-1], weights, profits, n-1, dp),
    );
    return dp[n][w];
}
// Driver Code
use std::io;
fn take_vector() -> Vec<usize> {
    let mut input = String::new();
    io::stdin().read_line(&mut input).unwrap();
    let arr: Vec<usize> = input.trim().split_whitespace()
        .map(|x| x.parse().unwrap()).collect();
    return arr;
}
fn main() {
    let weights = take_vector();
    let profits = take_vector();
    let w = take_vector()[0];
    // Declare DP Matrix
    let mut dp = vec![vec![-1 as i64; w+1]; weights.len()+1];
    println!("{}", knapsack(w, &weights, &profits, weights.len(), &mut dp));
}

Input

1 4 5 7
3 6 9 11
11

Output

18

Time Complexity : O( n×w )
Space Complexity : O( n×w )
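
As a variation on the -1 marker, the "not computed yet" state can also be expressed with Rust's `Option` type, so that no value of the profit range has to be reserved. The following is a minimal sketch under that assumption; the names `knapsack_memo` and `knapsack_top_down` are made up for this illustration and are not part of the functions above.

```rust
/// Top-down knapsack where `None` means "not computed yet".
fn knapsack_memo(
    w: usize,
    weights: &[usize],
    profits: &[usize],
    n: usize,
    memo: &mut Vec<Vec<Option<usize>>>,
) -> usize {
    // No items left or no capacity left: no profit
    if n == 0 || w == 0 {
        return 0;
    }
    // Reuse a stored result if there is one
    if let Some(value) = memo[n][w] {
        return value;
    }
    // Profit when the nth item is excluded
    let skip = knapsack_memo(w, weights, profits, n - 1, memo);
    let best = if weights[n - 1] > w {
        skip
    } else {
        // Profit when the nth item is included
        let take = profits[n - 1] + knapsack_memo(w - weights[n - 1], weights, profits, n - 1, memo);
        skip.max(take)
    };
    memo[n][w] = Some(best);
    best
}

/// Allocates the memo table and starts the recursion with all items.
fn knapsack_top_down(w: usize, weights: &[usize], profits: &[usize]) -> usize {
    let n = weights.len();
    let mut memo = vec![vec![None; w + 1]; n + 1];
    knapsack_memo(w, weights, profits, n, &mut memo)
}
```

With the example input, `knapsack_top_down(11, &[1, 4, 5, 7], &[3, 6, 9, 11])` returns 18. A side effect of this choice is that the profit type can stay `usize`, since no negative marker is needed.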

Tabulation ( Bottom-up DP ) Method

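Although the time and space complexities of tabulation and memoization are the same, tabulation is usually faster in practice, because it replaces the recursive calls with simple loops and can not overflow the call stack. Memoization, on the other hand, only computes the sub-problems that are actually reached.

In the tabulation method, we create the matrix, fill the first row and first column from the base condition, and then fill every other row using only the row before it.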

Algorithm

  1. Initially, set the elements of the first row and the first column to 0, because there is no profit if n = 0 (no item left) or if w = 0 (the knapsack has no capacity left).
  2. For all capacities smaller than the weight of the nth item, copy the value from the previous row, because we can not include the item if the capacity of the knapsack is less than its weight.
  3. Then, fill the rest of the row with the maximum of the profit if we include the item and the profit if we exclude it.

Function

Here is the function using the above algorithm:

use std::cmp::max;
fn knapsack(w: usize, weights: &Vec<usize>, profits: &Vec<usize>, n: usize, dp: &mut Vec<Vec<i64>>) -> i64 {
    // Initially, set the first row and first column to 0
    // Because, if n = 0, no item left
    // If w = 0, no capacity, hence no more profit.
    for i in 0..n+1 { dp[i][0] = 0; }
    for i in 0..w+1 { dp[0][i] = 0; }
    // Now, we run a loop for all n from 1 to n
    for i in 1..n+1 {
        // For all capacities less than the weight of given item, we can not include it
        // So, copy from previous row. The bound is capped at w+1 so that an item
        // heavier than the knapsack does not index past the end of the row.
        for j in 1..weights[i-1].min(w+1) {
            dp[i][j] = dp[i-1][j];
        }
        // Now, for higher capacities, we take max of
        // first excluding, then including the item
        for j in weights[i-1]..w+1 {
            dp[i][j] = max(dp[i-1][j], profits[i-1] as i64 + dp[i-1][j-weights[i-1]]);
        }
    }
    // Finally, return the answer
    return dp[n][w];
}

Use the same driver code.

Input

1 4 5 7
3 6 9 11
11

Output

18

Time Complexity : O( n×w )
Space Complexity : O( n×w )
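
To inspect the rows that tabulation produces, here is a minimal sketch that allocates the matrix itself and returns it whole. The name `knapsack_table` is hypothetical. The sketch folds steps 2 and 3 into a single check per cell, so an item heavier than the capacity never indexes past the end of a row.

```rust
/// Builds the full (n+1) x (w+1) table; table[i][j] is the best profit
/// using only the first i items with capacity j.
fn knapsack_table(w: usize, weights: &[usize], profits: &[usize]) -> Vec<Vec<usize>> {
    let n = weights.len();
    // Row 0 and column 0 stay 0: no items or no capacity means no profit
    let mut dp = vec![vec![0usize; w + 1]; n + 1];
    for i in 1..=n {
        let (wt, p) = (weights[i - 1], profits[i - 1]);
        for j in 0..=w {
            dp[i][j] = if wt > j {
                // Item i does not fit in capacity j: copy from the previous row
                dp[i - 1][j]
            } else {
                // Best of excluding and including item i
                dp[i - 1][j].max(p + dp[i - 1][j - wt])
            };
        }
    }
    dp
}
```

For the example input, the last row of `knapsack_table(11, &[1, 4, 5, 7], &[3, 6, 9, 11])` is [0, 3, 3, 3, 6, 9, 12, 12, 14, 15, 18, 18], and its final entry 18 is the answer.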

Space Optimized Tabulation Method

If we observe the above tabulation method carefully, we find that for calculating the profit for a given n and w, we only need the previous row.

In the above algorithm, step 1 is the base case or initialization step, and does not require any other row.

Step 2 and Step 3 require only the previous row.

Hence, we can reduce the space complexity by storing only the previous row and the current row instead of the whole matrix.

Function

Here is the function using space optimization of tabulation method.

use std::cmp::max;
fn knapsack(w: usize, weights: &Vec<usize>, profits: &Vec<usize>, n: usize) -> i64 {
    // Initially, set the first row 0
    // Because, if n = 0, no item, hence no profit
    // Only O(W) space
    let mut previous = vec![0; w+1];
    let mut current = vec![0; w+1];
    for i in 1..n+1 {
        // For all capacities less than the weight of given item, we can not include it
        // So, copy from previous row. The bound is capped at w+1 so that an item
        // heavier than the knapsack does not index past the end of the row.
        for j in 1..weights[i-1].min(w+1) {
            current[j] = previous[j];
        }
        // Now, for higher capacities, we take max of
        // first excluding, then including the item
        for j in weights[i-1]..w+1 {
            current[j] = max(previous[j], profits[i-1] as i64 + previous[j-weights[i-1]]);
        }
        // Copy the elements of current to previous
        previous.copy_from_slice(&current);
    }
    // Finally, return the answer
    return current[w];
}

Use the same driver code, but remove the DP matrix declaration and the dp argument from the call.

Input

1 4 5 7
3 6 9 11
11

Output

18

Time Complexity : O( n×w )
Space Complexity : O( w )
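
A common further variation, which is not part of the method above, keeps a single row and walks the capacities from high to low, so that every cell read on the right-hand side still holds the value from the previous item. The sketch below assumes this variation; the name `knapsack_one_row` is made up for illustration.

```rust
/// 0/1 knapsack with a single row of w+1 entries.
fn knapsack_one_row(w: usize, weights: &[usize], profits: &[usize]) -> usize {
    let mut best = vec![0usize; w + 1];
    for (&wt, &p) in weights.iter().zip(profits.iter()) {
        // An item heavier than the whole knapsack can never be taken
        if wt > w {
            continue;
        }
        // Go from high to low capacity: best[j - wt] has not been updated
        // for this item yet, so the item is counted at most once
        for j in (wt..=w).rev() {
            best[j] = best[j].max(p + best[j - wt]);
        }
    }
    best[w]
}
```

With the example input, `knapsack_one_row(11, &[1, 4, 5, 7], &[3, 6, 9, 11])` returns 18. The asymptotic space stays O( w ), but only one vector is allocated and no copy step is needed.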

Conclusion

The 0/1 Knapsack Problem is a classic Dynamic Programming problem. In this problem, you have to maximize the profit you can carry in a knapsack of limited capacity.

In this article, we saw how to solve the 0/1 Knapsack problem in the Rust language, first using recursion, then using Dynamic Programming with the memoization and tabulation methods, and finally using the space optimized tabulation method.

Here is the optimized function for easy access:

use std::cmp::max;
fn knapsack(w: usize, weights: &Vec<usize>, profits: &Vec<usize>, n: usize) -> i64 {
    let mut prev = vec![0; w+1];
    let mut curr = vec![0; w+1];
    for i in 1..n+1 {
        // Capped at w+1 so an item heavier than the knapsack stays in bounds
        for j in 1..weights[i-1].min(w+1) { curr[j] = prev[j]; }
        for j in weights[i-1]..w+1 {
            curr[j] = max(prev[j], profits[i-1] as i64 + prev[j-weights[i-1]]);
        }
        prev.copy_from_slice(&curr);
    }
    return curr[w];
}

Thank You