Matrix exponentiation is one of the most widely used techniques in advanced competitive programming. In its most general form, it lets us compute the $n$th term of a linear recurrence relation in $O(\log n)$ time (for a recurrence of fixed order). It is especially useful for converting a linear $O(n)$ dynamic programming solution into an $O(\log n)$ one.
To understand this better, let us consider a very simple example: finding the $n$th Fibonacci number. This problem is easily solved using a linear recurrence. By definition of the Fibonacci sequence, $f_n = f_{n-1} + f_{n-2}$. The code below calculates the $n$th Fibonacci number, using the convention $f_1 = 0$ and $f_2 = 1$.
```python
def get_fibonacci(n):
    if n <= 0:
        raise ValueError("n must be a positive integer")
    if n == 1:
        return 0
    if n == 2:
        return 1
    dp = [0] * (n + 1)  # dp[i] holds the ith Fibonacci number (1-indexed)
    # Setting base cases
    dp[1] = 0
    dp[2] = 1
    for i in range(3, n + 1):
        dp[i] = dp[i - 1] + dp[i - 2]
    return dp[n]
```
This solution is an iterative DP that runs in linear time, i.e., $O(n)$. It also uses $O(n)$ space, but since each term depends only on the previous two, it can easily be reduced to $O(1)$ space.
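A minimal sketch of that $O(1)$-space variant, keeping only the last two terms in two variables. The name get_fibonacci_constant_space is our own, and it assumes the same convention $f_1 = 0$, $f_2 = 1$ as the code above.

```python
def get_fibonacci_constant_space(n):
    """Return the nth Fibonacci number (f_1 = 0, f_2 = 1) using O(1) extra space."""
    if n <= 0:
        raise ValueError("n must be a positive integer")
    prev, curr = 0, 1  # f_1 and f_2
    if n == 1:
        return prev
    # Each step slides the two-term window one position forward.
    for _ in range(n - 2):
        prev, curr = curr, prev + curr
    return curr


print([get_fibonacci_constant_space(n) for n in range(1, 9)])  # [0, 1, 1, 2, 3, 5, 8, 13]
```

The time complexity is still $O(n)$; only the memory usage improves.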
At first there does not seem to be any straightforward way to improve the time complexity of this solution. This is where matrix exponentiation comes into the picture. Our goal is to obtain a recurrence of the form $F_n = P \cdot F_{n-1}$, where $P$ is a constant square matrix and $F_n$, $F_{n-1}$ are column matrices holding the state at steps $n$ and $n-1$. Let us see what happens if we obtain such a relation.
$$
\begin{aligned}
F_2 &= P \cdot F_1 \\
F_3 &= P \cdot F_2 = P \cdot P \cdot F_1 = P^2 \cdot F_1 \\
F_4 &= P \cdot F_3 = P \cdot P^2 \cdot F_1 = P^3 \cdot F_1 \\
&\;\;\vdots \\
F_n &= P^{n-1} \cdot F_1
\end{aligned}
$$
This is a very helpful relation: we have expressed the $n$th term of the sequence in terms of the base matrix $F_1$. Note that the base matrix need not correspond to $n = 1$; if the base is $F_k$, the same argument gives $F_n = P^{n-k} \cdot F_k$.
You probably already know that $x^n$ can be calculated in $O(\log n)$ time using binary exponentiation, where $x$ and $n$ are integers. Refer to the code below if you need a refresher on how that is done. If you have never heard of binary exponentiation, go through this article before continuing.
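```python
def power(x, n):
    result = 1
    while n > 0:
        if n % 2 != 0:  # lowest remaining bit of n is set: multiply this power of x in
            result *= x
        n = n // 2      # move on to the next bit of n
        x = x * x       # x now holds the original x raised to the next power of two
    return result
```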
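To see the loop at work, here is a hand trace of power(x, 13); the exponent 13 is just an illustrative choice. Since $13 = 1101_2$, the result is assembled from the powers of two that correspond to set bits:

$$
13 = 1101_2 = 8 + 4 + 1 \quad\Longrightarrow\quad x^{13} = x^{8} \cdot x^{4} \cdot x^{1}
$$

$$
\begin{array}{c|c|c|c}
\text{iteration} & n \text{ (before)} & \texttt{result} \text{ (after)} & \texttt{x} \text{ (after)} \\ \hline
1 & 13 & x & x^{2} \\
2 & 6 & x & x^{4} \\
3 & 3 & x^{5} & x^{8} \\
4 & 1 & x^{13} & x^{16}
\end{array}
$$

Four iterations suffice, one per binary digit of 13, which is where the $O(\log n)$ bound comes from.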
A function very similar to power can be implemented to calculate $P^n$ in $O(m^3 \log n)$ time, where $P$ is a square matrix, $n$ is a non-negative integer and $m$ is the dimension of $P$ (i.e., $P$ is an $m \times m$ matrix).
```python
def matrix_power(P, n):
    m = len(P)
    # Initializing m x m identity matrix (plays the role of result = 1)
    R = [[1 if i == j else 0 for i in range(m)] for j in range(m)]
    while n > 0:
        if n % 2 != 0:
            # Powers of P commute, so the order of this product does not matter
            R = matrix_multiply(P, R)
        n = n // 2
        P = matrix_multiply(P, P)
    return R


def matrix_multiply(A, B):
    n = len(A)     # rows of A
    m = len(A[0])  # columns of A
    q = len(B)     # rows of B
    r = len(B[0])  # columns of B
    if m != q:
        raise ValueError("Incompatible dimensions for matrix multiplication")
    # Initializing n x r zero matrix
    R = [[0 for j in range(r)] for i in range(n)]
    for i in range(n):
        for j in range(r):
            for k in range(m):
                R[i][j] += A[i][k] * B[k][j]
    return R
```
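For two $m \times m$ matrices, 'matrix_multiply' performs $m^3$ multiply-add steps, so it runs in $O(m^3)$. Since the loop in 'matrix_power' runs $O(\log n)$ times and does at most two matrix multiplications per iteration, the overall time complexity of calculating $P^n$ is $O(m^3 \log n)$.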
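One way to make this bound tangible is to count the matrix products. The sketch below is self-contained, so it re-declares a compact matrix_multiply; matrix_power_counted, naive_matrix_power_counted and identity are our own illustrative helpers, not part of the code above. It compares binary exponentiation against multiplying $P$ in $n$ times.

```python
def matrix_multiply(A, B):
    """Plain triple-loop product of an n x m matrix A and an m x r matrix B."""
    if len(A[0]) != len(B):
        raise ValueError("Incompatible dimensions for matrix multiplication")
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))]
            for i in range(len(A))]


def identity(m):
    return [[1 if i == j else 0 for j in range(m)] for i in range(m)]


def matrix_power_counted(P, n):
    """Binary exponentiation that also returns how many products it performed."""
    R, count = identity(len(P)), 0
    while n > 0:
        if n % 2 != 0:
            R, count = matrix_multiply(P, R), count + 1
        n //= 2
        P, count = matrix_multiply(P, P), count + 1
    return R, count


def naive_matrix_power_counted(P, n):
    """Multiply P into the identity n times: n products in total."""
    R, count = identity(len(P)), 0
    for _ in range(n):
        R, count = matrix_multiply(P, R), count + 1
    return R, count


P = [[1, 1], [1, 0]]
fast, fast_count = matrix_power_counted(P, 13)
slow, slow_count = naive_matrix_power_counted(P, 13)
print(fast == slow, fast_count, slow_count)  # True 7 13
```

For $n = 13$ both methods agree, but binary exponentiation uses 7 products (4 squarings plus 3 for the set bits) instead of 13; for $n = 1000$ it needs only 16. Note that the final squaring in each run is never used, which is harmless.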
Now let us get back to the original question: calculating the $n$th Fibonacci number. What we need to do is find the matrices $F_n$ and $P$.
We take $F_n$ to be:

$$
F_n = \begin{pmatrix} f_n \\ f_{n-1} \end{pmatrix}
$$

where $f_n$ is the $n$th Fibonacci number and $f_{n-1}$ is the $(n-1)$th Fibonacci number.
Now we need to find $P$ such that $F_n = P \cdot F_{n-1}$. Using $f_n = f_{n-1} + f_{n-2}$ together with the trivial identity $f_{n-1} = f_{n-1}$, basic linear algebra gives:

$$
P = \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}
$$
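Expanding the product row by row shows why this $P$ works: the first row encodes the recurrence, and the second row simply copies $f_{n-1}$ down into the lower slot.

$$
P \cdot F_{n-1} =
\begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}
\begin{pmatrix} f_{n-1} \\ f_{n-2} \end{pmatrix}
=
\begin{pmatrix} 1 \cdot f_{n-1} + 1 \cdot f_{n-2} \\ 1 \cdot f_{n-1} + 0 \cdot f_{n-2} \end{pmatrix}
=
\begin{pmatrix} f_n \\ f_{n-1} \end{pmatrix}
= F_n
$$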
We take the base matrix to be $F_2$, i.e., $F_n$ with $n = 2$:

$$
F_2 = \begin{pmatrix} f_2 \\ f_1 \end{pmatrix} = \begin{pmatrix} 1 \\ 0 \end{pmatrix}
$$
Now we can see that $F_n = P^{n-2} \cdot F_2$. As shown earlier, $P^{n-2}$ can be calculated in $O(m^3 \log n)$ time. Here $m = 2$, so the time complexity of this solution is $O(2^3 \log n)$, which is $O(\log n)$. The code for this is given below.
```python
def get_fibonacci_matrix_exp(n):
    if n <= 0:
        raise ValueError("n must be a positive integer")
    if n == 1:
        return 0
    F2 = [[1],
          [0]]
    P = [[1, 1],
         [1, 0]]
    Pn_2 = matrix_power(P, n - 2)   # Calculating P^(n-2)
    Fn = matrix_multiply(Pn_2, F2)  # Fn = P^(n-2) * F2
    return Fn[0][0]                 # the top entry of Fn is f_n
```
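For a concrete check, take $n = 5$ (an arbitrary small value): the code computes $P^{3}$ and multiplies it by $F_2$. The top entry is $f_5 = 3$ and the bottom entry is $f_4 = 2$, matching the sequence $0, 1, 1, 2, 3$ produced by the linear DP.

$$
P^{3} = \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}^{3} = \begin{pmatrix} 3 & 2 \\ 2 & 1 \end{pmatrix},
\qquad
F_5 = P^{3} \cdot F_2 = \begin{pmatrix} 3 & 2 \\ 2 & 1 \end{pmatrix} \begin{pmatrix} 1 \\ 0 \end{pmatrix} = \begin{pmatrix} 3 \\ 2 \end{pmatrix} = \begin{pmatrix} f_5 \\ f_4 \end{pmatrix}
$$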
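The matrix-exponentiation version is a much more efficient solution and scales far better to very large values of $n$, since it needs only $O(\log n)$ matrix multiplications instead of $O(n)$ additions.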
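A quick way to gain confidence in the matrix version is to compare it against the linear DP on many inputs. The sketch below is self-contained, so it re-declares compact versions of matrix_multiply and matrix_power (the dimension check is omitted for brevity); fib_linear and fib_matrix are our own names, and both assume the convention $f_1 = 0$, $f_2 = 1$.

```python
def matrix_multiply(A, B):
    """Product of an n x m matrix A and an m x r matrix B (lists of lists)."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))]
            for i in range(len(A))]


def matrix_power(P, n):
    """P^n by binary exponentiation, starting from the identity matrix."""
    R = [[1 if i == j else 0 for j in range(len(P))] for i in range(len(P))]
    while n > 0:
        if n % 2 != 0:
            R = matrix_multiply(P, R)
        n //= 2
        P = matrix_multiply(P, P)
    return R


def fib_linear(n):
    """O(n) reference: f_1 = 0, f_2 = 1, f_n = f_{n-1} + f_{n-2}."""
    prev, curr = 0, 1
    if n == 1:
        return prev
    for _ in range(n - 2):
        prev, curr = curr, prev + curr
    return curr


def fib_matrix(n):
    """O(log n) matrix version: F_n = P^(n-2) * F_2, whose top entry is f_n."""
    if n == 1:
        return 0
    P = [[1, 1], [1, 0]]
    F2 = [[1], [0]]
    return matrix_multiply(matrix_power(P, n - 2), F2)[0][0]


# Cross-check the two implementations on a range of inputs.
for n in range(1, 200):
    assert fib_linear(n) == fib_matrix(n), n
print(fib_matrix(50))  # 7778742049
```

Python integers have arbitrary precision, so the values stay exact even for large $n$; the individual multiplications do get more expensive as the numbers grow.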
Now using this, try to solve the following problem in logarithmic time.
Q. Given a 3 x N rectangle, determine in how many ways we can tile it using 1 x 3 and 3 x 1 tiles.
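Like before, we first come up with an $O(n)$ solution and obtain a linear recurrence relation. In this problem $dp_n = dp_{n-1} + dp_{n-3}$: the leftmost column is covered either by a single vertical 3 x 1 tile, leaving a 3 x (n-1) rectangle, or by three horizontal 1 x 3 tiles stacked on top of each other, leaving a 3 x (n-3) rectangle. If this is still unclear, I suggest you read this article. We need to express this relation as $DP_n = P \cdot DP_{n-1}$, where $P$, $DP_n$ and $DP_{n-1}$ are all matrices. Here we take $DP_n$, $P$ and the base matrix $DP_3$ (i.e., $n = 3$) as: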
$$
DP_n = \begin{pmatrix} dp_n \\ dp_{n-1} \\ dp_{n-2} \end{pmatrix}
$$

where $dp_n$ is the answer for $N = n$.
$$
P = \begin{pmatrix} 1 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{pmatrix}
$$
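$$
DP_3 = \begin{pmatrix} dp_3 \\ dp_2 \\ dp_1 \end{pmatrix} = \begin{pmatrix} 2 \\ 1 \\ 1 \end{pmatrix}
$$

Here $dp_1 = dp_2 = 1$, because a 3 x 1 or 3 x 2 rectangle can only be filled with vertical tiles, and $dp_3 = 2$ (three vertical tiles or three horizontal tiles).
The first row of $P$ encodes $dp_n = dp_{n-1} + dp_{n-3}$, while the other two rows simply shift $dp_{n-1}$ and $dp_{n-2}$ down, so $DP_n = P \cdot DP_{n-1}$ and therefore $DP_n = P^{n-3} \cdot DP_3$. As shown before, this can be computed in $O(m^3 \log n)$ time, where $m$ is now 3. The code for this is given below.

```python
def get_tiling_count(n):
    if n <= 0:
        raise ValueError("n must be a positive integer")
    if n <= 2:
        return 1  # only vertical 3 x 1 tiles fit in a 3 x 1 or 3 x 2 rectangle
    DP3 = [[2],
           [1],
           [1]]
    P = [[1, 0, 1],
         [1, 0, 0],
         [0, 1, 0]]
    Pn_3 = matrix_power(P, n - 3)     # Calculating P^(n-3)
    DPn = matrix_multiply(Pn_3, DP3)  # DPn = P^(n-3) * DP3
    return DPn[0][0]
```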
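The $O(n)$ DP that this problem starts from is not spelled out above, so here is a minimal sketch of it, together with an exhaustive tiler to check the recurrence and its base values. Both function names are our own; the brute force is only practical for small N, and the DP assumes $dp_0 = 1$ (the empty tiling) so that the recurrence can start at $i = 1$.

```python
def count_tilings_linear(n):
    """O(n) DP: dp[i] = dp[i - 1] + dp[i - 3], with dp[0] = 1 (the empty tiling)."""
    dp = [0] * (n + 1)
    dp[0] = 1
    for i in range(1, n + 1):
        dp[i] = dp[i - 1] + (dp[i - 3] if i >= 3 else 0)
    return dp[n]


def count_tilings_brute_force(n):
    """Exhaustively count tilings of a 3 x n grid (only practical for small n)."""
    filled = [[False] * n for _ in range(3)]

    def first_empty():
        # Scan column by column, top to bottom.
        for c in range(n):
            for r in range(3):
                if not filled[r][c]:
                    return r, c
        return None

    def place(cells, value):
        for r, c in cells:
            filled[r][c] = value

    def solve():
        cell = first_empty()
        if cell is None:
            return 1  # every cell is covered: one complete tiling
        r, c = cell
        total = 0
        horizontal = [(r, c + d) for d in range(3)]  # 1 x 3 tile in row r
        vertical = [(r + d, c) for d in range(3)]    # 3 x 1 tile in column c
        for cells in (horizontal, vertical):
            if all(0 <= rr < 3 and cc < n and not filled[rr][cc] for rr, cc in cells):
                place(cells, True)
                total += solve()
                place(cells, False)  # backtrack
        return total

    return solve()


for n in range(0, 13):
    assert count_tilings_brute_force(n) == count_tilings_linear(n), n
print([count_tilings_linear(n) for n in range(1, 11)])  # [1, 1, 2, 3, 4, 6, 9, 13, 19, 28]
```

Because the first empty cell in column-major order can only be the top-left corner of whichever tile covers it, the brute force counts each tiling exactly once.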
Here are some other problems that you can try solving to practice this concept.