Symbolic matrix differentiation with Sympy !

After a few days spent computing strange Jacobians for my Ph.D. thesis, I figured out that my computer could actually do most of the computations for me: all I needed was an automatic matrix differentiator, i.e. an algorithm that would tell me how a function of matrices F(X,Y,...) varies when the matrices (X,Y,...) are changed into (X + \partial X, Y + \partial Y,...), where the (\partial X,\partial Y,...) are small. The very useful Matrix Cookbook gives, for instance,

\partial (X^T) = (\partial X)^T

\partial (XY) = (\partial X)Y + X (\partial Y)

\partial(X^{-1}) = - X^{-1}(\partial X) X^{-1}

… And so on. We can see that it is not completely unlike classic differentiation, but the non-commutativity slightly complicates things, and the computed formulae can be monstrous: imagine differentiating by hand the hat matrix H = X \left(X^\top \Sigma^{-1} X\right)^{-1} X^\top \Sigma^{-1}, where \Sigma is a constant 😯
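As a sanity check, the three Cookbook rules quoted above can be tested on explicit 2×2 matrices with symbolic entries, using Sympy's ordinary Matrix class rather than anything built in this post. The sketch below moves X to X + εE (and Y to Y + εG), where the arbitrary constant matrices E and G stand for ∂X and ∂Y, and takes the derivative at ε = 0. The helper names first_order and is_zero, and the entries of E and G, are illustrative choices only.

```python
from sympy import Matrix, symbols, simplify, zeros

eps = symbols('eps')
x1, x2, x3, x4, y1, y2, y3, y4 = symbols('x1:5 y1:5')
X = Matrix([[x1, x2], [x3, x4]])
Y = Matrix([[y1, y2], [y3, y4]])
E = Matrix([[1, 2], [3, 4]])    # arbitrary direction standing for dX
G = Matrix([[0, 1], [-1, 2]])   # arbitrary direction standing for dY

def first_order(F):
    """Derivative of eps -> F(eps) at eps = 0 (the differential along E, G)."""
    return F(eps).diff(eps).subs(eps, 0)

def is_zero(M):
    """True if every entry of M simplifies to 0."""
    return M.applyfunc(simplify) == zeros(*M.shape)

# d(X^T) = (dX)^T
assert is_zero(first_order(lambda h: (X + h*E).T) - E.T)
# d(XY) = (dX)Y + X(dY)
assert is_zero(first_order(lambda h: (X + h*E) * (Y + h*G)) - (E*Y + X*G))
# d(X^-1) = -X^-1 (dX) X^-1
Xi = X.inv()
assert is_zero(first_order(lambda h: (X + h*E).inv()) - (-Xi * E * Xi))
print("the three rules hold for these matrices")
```

The same trick shows why order matters: E*X - X*E does not vanish for a generic X.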

Maybe I didn’t look properly, but I didn’t find any program that automates the rules described above, so I went on building something around Sympy, a young yet sophisticated symbolic mathematics module for the Python programming language, whose authors claim that it is easily extendable. I put that claim to the test and indeed found it pretty easy to work my way through.

In what follows I explain what I did, in a tutorial-like fashion, in the hope that it helps others understand how to tweak Sympy to meet their own needs. But I am no Sympy expert, and I welcome any tips and comments in the comments section below! If you are in a hurry, you’ll find the complete source code of the solution at the bottom of this post.

Ready? Code!

First, a matrix is a symbol, which is lucky because sympy already has a Symbol class. You only have to specify that these symbols are not commutative:


from sympy import *

def matrices(names):
    ''' Call with  A,B,C = matrices('A B C') '''
    return symbols(names,commutative=False)
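To see what commutative=False buys us, one can compare an ordinary pair of Sympy symbols with a pair created by matrices(): the difference-of-squares identity only holds in the commutative case. The snippet repeats the function so that it runs on its own; the symbol names are arbitrary.

```python
from sympy import symbols, expand

def matrices(names):
    ''' Call with A,B,C = matrices('A B C') '''
    return symbols(names, commutative=False)

A, B = matrices('A B')
a, b = symbols('a b')    # ordinary (commutative) symbols, for comparison

print(expand((a + b)*(a - b)))   # a**2 - b**2
print(expand((A + B)*(A - B)))   # A*B and B*A stay as separate terms
print(A*B == B*A)                # False
```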


Matrix-specific transformations

All the regular operations like addition, subtraction, multiplication… are already implemented in the Symbol class, so we can focus on matrix-specific transformations, like the operator \partial and the inverse function X \mapsto X^{-1}:

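# Function subclasses rather than Function("d",commutative=False), so that
# the custom printers below can dispatch on _print_d and _print_inv
# (recent Sympy versions skip that lookup for undefined functions)
class d(Function):
    is_commutative = False

class inv(Function):
    is_commutative = False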

That was tough! Now, for the transpose function X \mapsto X^T, we want the general rules

(X+Y...)^T = X^T + Y^T ...
(XY...)^T = ...Y^TX^T

to apply automatically when transposing an expression. Here is the code I wrote for that. To understand it, you need to know that when I call some Sympy Function F over a bunch of arguments X,Y,Z, the result is an object of class “F” whose arguments (.args) are (X,Y,Z). Moreover, successive additions or multiplications are flattened into one single object: for instance X^2 + 2*X*Y + 3*Z + 5 is an object of class “Add” with arguments (X^2,2*X*Y,3*Z,5), while X^2*Y*X*Z is an object of class “Mul” with arguments (X^2,Y,X,Z).
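This class-and-arguments structure is easy to inspect directly. The sketch below uses throwaway non-commutative symbols and its own small inv class (a Function subclass, an assumption of this snippet) to show the class and .args of a function call, a sum and a product. Sympy stores the terms of a sum in its own canonical order, so only the set of terms is compared; the factors of a non-commutative product keep the order in which they were written.

```python
from sympy import symbols, Function, Add, Mul, Integer

X, Y, Z = symbols('X Y Z', commutative=False)

class inv(Function):    # stand-in for the inverse function of this post
    is_commutative = False

f = inv(X)
s = X**2 + 2*X*Y + 3*Z + 5
p = X**2*Y*X*Z

print(type(f).__name__, f.args)   # inv (X,)
print(type(s).__name__, s.args)   # Add, terms in Sympy's own order
print(type(p).__name__, p.args)   # Mul (X**2, Y, X, Z)
```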

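class t(Function):
    ''' The transposition, with special rules
        t(A+B) = t(A) + t(B) and t(AB) = t(B)t(A) '''
    is_commutative = False
    def __new__(cls, arg, **options):
        arg = sympify(arg)         # also accept plain Python numbers like 0
        if arg.is_commutative:     # scalars are their own transpose
            return arg
        if arg.is_Add:
            return Add(*[t(A) for A in arg.args])
        elif arg.is_Mul:
            # transpose each factor and reverse their order
            return Mul(*[t(A) for A in reversed(arg.args)])
        else:
            return Function.__new__(cls, arg, **options)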
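A quick check of the two transposition rules. This self-contained copy of the class passes its argument through sympify and returns commutative arguments (scalars, including the plain 0 that differentiation may produce) unchanged, since a scalar is its own transpose; those two guards are assumptions of this sketch.

```python
from sympy import symbols, sympify, Function, Add, Mul

class t(Function):
    ''' Transposition with t(A+B) = t(A) + t(B) and t(AB) = t(B)t(A) '''
    is_commutative = False
    def __new__(cls, arg, **options):
        arg = sympify(arg)
        if arg.is_commutative:          # scalars, including 0
            return arg
        if arg.is_Add:
            return Add(*[t(A) for A in arg.args])
        elif arg.is_Mul:                # reverse the order of the factors
            return Mul(*[t(A) for A in reversed(arg.args)])
        else:
            return Function.__new__(cls, arg, **options)

A, B, C = symbols('A B C', commutative=False)
print(t(A*B + C))
print(t(2*A*B*C))
```

The numeric coefficient 2 comes out in front because Sympy collects the commutative part of a product separately from the matrix factors.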

That’s enough transformations for now; let us explain to Sympy how to differentiate a matrix expression!

Matrix differentiation

We write just enough of the Matrix Cookbook’s differentiation rules into one dictionary:

MATRIX_DIFF_RULES = {
    # e = expression, s = a list of symbols with respect to which
    # we want to differentiate

    Symbol : lambda e,s : d(e) if (e in s) else 0,
    Add :    lambda e,s : Add(*[matDiff(arg,s) for arg in e.args]),
    # product rule: d(AB...) = (dA)(B...) + A d(B...)
    Mul :    lambda e,s : Mul(matDiff(e.args[0],s),Mul(*e.args[1:]))
                        + Mul(e.args[0],matDiff(Mul(*e.args[1:]),s)),
    t :      lambda e,s : t( matDiff(e.args[0],s) ),
    # d(A^-1) = - A^-1 (dA) A^-1
    inv :    lambda e,s : - e * matDiff(e.args[0],s) * e
}

and apply them recursively to the expression we want to treat:

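def matDiff(expr,symbols):
    if expr.__class__ in MATRIX_DIFF_RULES:
        return MATRIX_DIFF_RULES[expr.__class__](expr,symbols)
    elif not expr.has(*symbols):
        return 0    # constant expressions (numbers...) have a zero differential
    else:
        # e.g. X*X, which Sympy stores as the power X**2 (class Pow)
        raise NotImplementedError("no rule for " + expr.__class__.__name__)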
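One gap worth knowing about: Sympy stores a product such as X*X as the power X**2 (class Pow), for which the table above has no entry. A possible extra rule, sketched below as a hypothetical addition (diff_pow is not part of the original code), uses \partial(A^n) = \sum_{k=0}^{n-1} A^k (\partial A) A^{n-1-k} for positive integer n. The snippet carries a trimmed copy of the differentiator (without the transpose) so that it runs on its own.

```python
from sympy import symbols, Function, Symbol, Add, Mul, Pow

class d(Function):
    is_commutative = False

class inv(Function):
    is_commutative = False

def diff_pow(e, s):
    ''' Hypothetical rule: d(A**n) = sum_k A**k dA A**(n-1-k), integer n > 0 '''
    base, n = e.args
    if not (n.is_Integer and n > 0):
        raise NotImplementedError("only positive integer powers are handled")
    dA = matDiff(base, s)
    return Add(*[base**k * dA * base**(n - 1 - k) for k in range(n)])

MATRIX_DIFF_RULES = {
    Symbol: lambda e, s: d(e) if (e in s) else 0,
    Add:    lambda e, s: Add(*[matDiff(arg, s) for arg in e.args]),
    Mul:    lambda e, s: Mul(matDiff(e.args[0], s), Mul(*e.args[1:]))
                         + Mul(e.args[0], matDiff(Mul(*e.args[1:]), s)),
    inv:    lambda e, s: - e * matDiff(e.args[0], s) * e,
    Pow:    diff_pow,
}

def matDiff(expr, symbols):
    rule = MATRIX_DIFF_RULES.get(expr.__class__)
    return rule(expr, symbols) if rule else 0   # no rule: treated as constant

A, B = symbols('A B', commutative=False)
print(matDiff(A*A, [A]))   # A*A is stored as A**2
```

Raising NotImplementedError for other exponents (negative or symbolic powers) is a deliberate choice in this sketch: returning 0 there would silently give a wrong derivative.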

Easy as pie, and we are done! Let’s play around with our new toy:

X,S = matrices("X S")
H = X*inv(t(X)*inv(S)*X)*t(X)*inv(S)
print(matDiff(H,[X]))
>>> X*(inv(t(X)*inv(S)*X)*t(d(X))*inv(S) - inv(t(X)*inv(S)*X)*(t(X)*inv(S)*d(X) + t(d(X))*inv(S)*X)*inv(t(X)*inv(S)*X)*t(X)*inv(S)) + d(X)*inv(t(X)*inv(S)*X)*t(X)*inv(S)

It works! But we must concede that it is barely readable… so let us put some style in these expressions!

Cosmetics

You don’t need LaTeX to have nice-looking formulae. Here is what we want to say to Sympy:

  • Don’t write inv(X), write X¯¹
  • Don’t write t(X), write X’
  • Don’t write d(X), write ∂X
  • In general, don’t put parentheses if the transformation applies to one symbol only.

The programmers of Sympy have made it easy to customize the default printing method: you just write a printer class with extra methods for our own functions. First, we need to include this line at the top of our source file to declare that it contains non-ASCII characters (this is needed in Python 2; Python 3 reads source files as UTF-8 by default):

# -*- coding: utf-8 -*-

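Then we write our own printer class, which inherits from the default one but has a few more methods. Sympy’s printers look for a method named _print_ followed by the class name of the expression, so _print_inv handles inv(...), _print_t handles t(...) and _print_d handles d(...):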

class matStrPrinter(StrPrinter):
    ''' Nice printing for console mode : X¯¹, X', ∂X '''

    def _print_inv(self, expr):
        if expr.args[0].is_Symbol:
            return self._print(expr.args[0]) + '¯¹'
        else:
            return '(' + self._print(expr.args[0]) + ')¯¹'

    def _print_t(self, expr):
        return self._print(expr.args[0]) + "'"

    def _print_d(self, expr):
        if expr.args[0].is_Symbol:
            return '∂' + self._print(expr.args[0])
        else:
            return '∂(' + self._print(expr.args[0]) + ')'
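To try the printer on its own, the sketch below defines minimal stand-ins for d, inv and t as Function subclasses (t without its transposition rules, which do not matter for printing; these stand-ins are assumptions of the snippet) and calls doprint directly on a few expressions.

```python
from sympy import symbols, Function
from sympy.printing.str import StrPrinter

class d(Function):
    is_commutative = False

class inv(Function):
    is_commutative = False

class t(Function):   # plain stand-in for the transpose class
    is_commutative = False

class matStrPrinter(StrPrinter):
    ''' Nice printing for console mode : X¯¹, X', ∂X '''
    def _print_inv(self, expr):
        if expr.args[0].is_Symbol:
            return self._print(expr.args[0]) + '¯¹'
        return '(' + self._print(expr.args[0]) + ')¯¹'

    def _print_t(self, expr):
        return self._print(expr.args[0]) + "'"

    def _print_d(self, expr):
        if expr.args[0].is_Symbol:
            return '∂' + self._print(expr.args[0])
        return '∂(' + self._print(expr.args[0]) + ')'

A, B = symbols('A B', commutative=False)
p = matStrPrinter()
print(p.doprint(inv(A)))      # A¯¹
print(p.doprint(inv(A*B)))    # (A*B)¯¹
print(p.doprint(d(inv(A)) * t(B)))
```

Because the printer calls self._print on sub-expressions, the custom methods apply at every level of a nested expression.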

At this point the Sympy documentation suggests replacing the default printer directly with this command:

Basic.__str__ = lambda self: matStrPrinter().doprint(self)

It does seem to be the only way to ensure that every expression printed with the standard “print” statement goes through our printer, but it is a little too definitive for me (what if I want to use several customized printers in the same script?). Moreover, I don’t like the ‘*’ between the matrices, and I didn’t find any smart way to get rid of them. For all these reasons I wrote my own matrix printing command:

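def matPrint(m):
    ''' Print m with matStrPrinter, without the '*' between factors '''
    # The printer calls our _print_* methods at every level of the
    # expression, so there is no need to patch Basic.__str__ here.
    s = matStrPrinter().doprint(m)
    # keep powers readable (X**2 -> X^2) before removing the product signs
    print(s.replace('**', '^').replace('*', ''))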

Now let’s try again:

matPrint(matDiff(H,[X]))
>>> X((X'S¯¹X)¯¹∂X'S¯¹ - (X'S¯¹X)¯¹(X'S¯¹∂X + ∂X'S¯¹X)(X'S¯¹X)¯¹X'S¯¹) + ∂X(X'S¯¹X)¯¹X'S¯¹

Way better! Notice, however, that the first X multiplies a sum of two terms, the second of which is itself a product involving a sum. We can break these nested expressions into a sum of products (which is simpler to analyze) by calling Sympy’s ‘expand’ function twice:

matPrint(expand(expand(matDiff(H,[X]))))
>>> X(X'S¯¹X)¯¹∂X'S¯¹ + ∂X(X'S¯¹X)¯¹X'S¯¹ - X(X'S¯¹X)¯¹X'S¯¹∂X(X'S¯¹X)¯¹X'S¯¹ - X(X'S¯¹X)¯¹∂X'S¯¹X(X'S¯¹X)¯¹X'S¯¹
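Since the formula is only as good as the rules behind it, one can spot-check the expanded result on concrete numbers with Sympy’s explicit Matrix class, independently of matDiff. The sketch takes a 3×2 X, an invertible 3×3 S and a direction E standing for ∂X, differentiates H(X + εE) with respect to ε at ε = 0, and compares with the four terms printed above. The numbers are arbitrary, chosen so that X has full column rank and S is invertible.

```python
from sympy import Matrix, symbols, simplify, zeros

eps = symbols('eps')
X0 = Matrix([[1, 2], [0, 1], [3, 1]])             # full column rank
S0 = Matrix([[2, 0, 0], [0, 3, 1], [0, 1, 2]])    # invertible (det = 10)
E = Matrix([[0, 1], [1, 0], [2, -1]])             # plays the role of dX
Si = S0.inv()

def hat(X):
    ''' H = X (X' S^-1 X)^-1 X' S^-1 '''
    return X * (X.T * Si * X).inv() * X.T * Si

# left-hand side: derivative of H(X0 + eps*E) at eps = 0
lhs = hat(X0 + eps * E).diff(eps).subs(eps, 0)

# right-hand side: the four expanded terms, with dX -> E
M = (X0.T * Si * X0).inv()
rhs = (X0 * M * E.T * Si + E * M * X0.T * Si
       - X0 * M * X0.T * Si * E * M * X0.T * Si
       - X0 * M * E.T * Si * X0 * M * X0.T * Si)

print((lhs - rhs).applyfunc(simplify) == zeros(3, 3))   # expected: True
```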

That’s nice enough for the console! Now what if I want to report this fundamental result to the scientific community? Let’s see how to generate the LaTeX code to embed the formula in a document.

Automatic LaTeX code generation

One of the nice features of Sympy is its ability to generate LaTeX code (and even display compiled formulae). For matrix computations we would like Sympy to follow these rules:

  • inv(X) is written X^{-1}
  • t(X) is written X^T
  • d(X) is written \partial X
  • In general, don’t put parentheses if the transformation applies to one symbol only.

As in the previous section, we create a new ‘printer’ class with these features, and then write a function inspired by Sympy’s LaTeX method (which is simply called latex()):

class matLatPrinter(LatexPrinter):
    r''' Printing instructions for latex : X^{-1},  X^T, \partial X '''

    def _print_inv(self, expr):
        if expr.args[0].is_Symbol:
            return self._print(expr.args[0]) + '^{-1}'
        else:
            return '(' + self._print(expr.args[0]) + ')^{-1}'

    def _print_t(self, expr):
        return self._print(expr.args[0]) + '^T'

    def _print_d(self, expr):
        # raw strings, so that '\p' is not read as an escape sequence
        if expr.args[0].is_Symbol:
            return r'\partial ' + self._print(expr.args[0])
        else:
            return r'\partial (' + self._print(expr.args[0]) + ')'

def matLatex(expr, profile=None, **kargs):
    if profile is not None:
        profile.update(kargs)
    else:
        profile = kargs
    return matLatPrinter(profile).doprint(expr)
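A standalone check of the LaTeX printer on single-symbol arguments, again with minimal Function-subclass stand-ins for d, inv and t (assumptions of this snippet). In recent Sympy versions the LaTeX printer’s default mode is 'plain', with no dollar signs, so mode='inline' is passed explicitly here to get the $...$ form.

```python
from sympy import symbols, Function
from sympy.printing.latex import LatexPrinter

class d(Function):
    is_commutative = False

class inv(Function):
    is_commutative = False

class t(Function):   # plain stand-in for the transpose class
    is_commutative = False

class matLatPrinter(LatexPrinter):
    r''' Printing instructions for latex : X^{-1}, X^T, \partial X '''
    def _print_inv(self, expr):
        if expr.args[0].is_Symbol:
            return self._print(expr.args[0]) + '^{-1}'
        return '(' + self._print(expr.args[0]) + ')^{-1}'

    def _print_t(self, expr):
        return self._print(expr.args[0]) + '^T'

    def _print_d(self, expr):
        if expr.args[0].is_Symbol:
            return r'\partial ' + self._print(expr.args[0])
        return r'\partial (' + self._print(expr.args[0]) + ')'

def matLatex(expr, profile=None, **kargs):
    if profile is not None:
        profile.update(kargs)
    else:
        profile = kargs
    return matLatPrinter(profile).doprint(expr)

A = symbols('A', commutative=False)
print(matLatex(inv(A)))                   # A^{-1}
print(matLatex(inv(A), mode='inline'))    # $A^{-1}$
```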

We now try it on the derivative of H:

print(matLatex(matDiff(H,[X])))
>>> $X \left((X^T S^{-1} X)^{-1} \partial X^T S^{-1} - (X^T S^{-1} X)^{-1} \left(X^T S^{-1} \partial X + \partial X^T S^{-1} X\right) (X^T S^{-1} X)^{-1} X^T S^{-1}\right) + \partial X (X^T S^{-1} X)^{-1} X^T S^{-1}$

which once compiled yields
X \left((X^T S^{-1} X)^{-1} \partial X^T S^{-1} - (X^T S^{-1} X)^{-1} \left(X^T S^{-1} \partial X + \partial X^T S^{-1} X\right) (X^T S^{-1} X)^{-1} X^T S^{-1}\right) + \partial X (X^T S^{-1} X)^{-1} X^T S^{-1}

Yeah!

Source Code

I hope this helped. If you just came here to get a matrix differentiator, here is the code with the example. Have fun 🙂


# -*- coding: utf-8 -*-

# Declaration

#----------------------------------------------------------------------
#
# FUNCTIONS FOR THE AUTOMATIC DIFFERENTIATION  OF MATRICES WITH SYMPY
# 
#----------------------------------------------------------------------

from sympy import *
from sympy.printing.str import StrPrinter
from sympy.printing.latex import LatexPrinter



#####  M  E  T  H  O  D  S



def matrices(names):
    ''' Call with  A,B,C = matrices('A B C') '''
    return symbols(names,commutative=False)


# Transformations

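# Function subclasses rather than Function("d",commutative=False), so that
# the custom printers can dispatch on _print_d and _print_inv
class d(Function):
    is_commutative = False

class inv(Function):
    is_commutative = False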

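class t(Function):
    ''' The transposition, with special rules
        t(A+B) = t(A) + t(B) and t(AB) = t(B)t(A) '''
    is_commutative = False
    def __new__(cls, arg, **options):
        arg = sympify(arg)         # also accept plain Python numbers like 0
        if arg.is_commutative:     # scalars are their own transpose
            return arg
        if arg.is_Add:
            return Add(*[t(A) for A in arg.args])
        elif arg.is_Mul:
            # transpose each factor and reverse their order
            return Mul(*[t(A) for A in reversed(arg.args)])
        else:
            return Function.__new__(cls, arg, **options)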


# Differentiation

MATRIX_DIFF_RULES = {

    # e = expression, s = a list of symbols with respect to which
    # we want to differentiate

    Symbol : lambda e,s : d(e) if (e in s) else 0,
    Add :    lambda e,s : Add(*[matDiff(arg,s) for arg in e.args]),
    # product rule: d(AB...) = (dA)(B...) + A d(B...)
    Mul :    lambda e,s : Mul(matDiff(e.args[0],s),Mul(*e.args[1:]))
                        + Mul(e.args[0],matDiff(Mul(*e.args[1:]),s)),
    t :      lambda e,s : t( matDiff(e.args[0],s) ),
    # d(A^-1) = - A^-1 (dA) A^-1
    inv :    lambda e,s : - e * matDiff(e.args[0],s) * e
}

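def matDiff(expr,symbols):
    if expr.__class__ in MATRIX_DIFF_RULES:
        return MATRIX_DIFF_RULES[expr.__class__](expr,symbols)
    elif not expr.has(*symbols):
        return 0    # constant expressions (numbers...) have a zero differential
    else:
        # e.g. X*X, which Sympy stores as the power X**2 (class Pow)
        raise NotImplementedError("no rule for " + expr.__class__.__name__)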



#####  C  O  S  M  E  T  I  C  S


# Console mode

class matStrPrinter(StrPrinter):
    ''' Nice printing for console mode : X¯¹, X', ∂X '''

    def _print_inv(self, expr):
        if expr.args[0].is_Symbol:
            return self._print(expr.args[0]) + '¯¹'
        else:
            return '(' + self._print(expr.args[0]) + ')¯¹'

    def _print_t(self, expr):
        return self._print(expr.args[0]) + "'"

    def _print_d(self, expr):
        if expr.args[0].is_Symbol:
            return '∂' + self._print(expr.args[0])
        else:
            return '∂(' + self._print(expr.args[0]) + ')'

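def matPrint(m):
    ''' Print m with matStrPrinter, without the '*' between factors '''
    # The printer calls our _print_* methods at every level of the
    # expression, so there is no need to patch Basic.__str__ here.
    s = matStrPrinter().doprint(m)
    # keep powers readable (X**2 -> X^2) before removing the product signs
    print(s.replace('**', '^').replace('*', ''))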


# Latex mode

class matLatPrinter(LatexPrinter):
    r''' Printing instructions for latex : X^{-1},  X^T, \partial X '''

    def _print_inv(self, expr):
        if expr.args[0].is_Symbol:
            return self._print(expr.args[0]) + '^{-1}'
        else:
            return '(' + self._print(expr.args[0]) + ')^{-1}'

    def _print_t(self, expr):
        return self._print(expr.args[0]) + '^T'

    def _print_d(self, expr):
        # raw strings, so that '\p' is not read as an escape sequence
        if expr.args[0].is_Symbol:
            return r'\partial ' + self._print(expr.args[0])
        else:
            return r'\partial (' + self._print(expr.args[0]) + ')'

def matLatex(expr, profile=None, **kargs):
    if profile is not None:
        profile.update(kargs)
    else:
        profile = kargs
    return matLatPrinter(profile).doprint(expr)



#####    T  E  S  T  S


X,S = matrices("X S")
H = X*inv(t(X)*inv(S)*X)*t(X)*inv(S)

matPrint(expand(expand(matDiff(H,[X]))))

print(matLatex(matDiff(H,[X])))

5 comments on “Symbolic matrix differentiation with Sympy !”

  1. Juanlu001 says:

    Just wow! This is a great article on how to extend the possibilities of SymPy too. I read Matthew Rocklin (http://sympystats.wordpress.com/2011/07/19/matrix-expressions/) was interested on symbolic manipulation with SymPy too, and was working on it.

    Good article! 😀

  2. Saullo says:

    Great article!

  3. Aaron Meurer says:

    This sounds like the sort of thing that we want implemented in SymPy itself. See https://code.google.com/p/sympy/issues/detail?id=2759. Pull requests welcome!

    • Valentin says:

      Sure. There has already been an attempt to merge these rules but apparently there have been difficulties and I don’t really get the final decision:

      see https://github.com/sympy/sympy/pull/1275

      Do I reopen the issue on Github ?

      • Aaron Meurer says:

        I don’t remember the details, but I believe that pull request was closed because it wasn’t done correctly (mathematically). If your code produces correct results, as per the matrix cookbook, then it’s definitely something we want to include. It may need restructuring to use the existing MatrixSymbol framework.

        But I think the biggest problem was that none of us really had a good understanding of the mathematics, or the time to implement it if we did. It sounds like you use this sort of math, so you probably understand how it should work, so if you have the time to implement it, that would be awesome.
