How to write a simple math interpreter in Python

Not long ago, I had to write a feature for my project that parses a mathematical expression from plain text and evaluates it. It had to handle basic numerical expressions like 2 + 3, variables supplied through a context (apples + 2 * oranges), and parentheses, as in (2 + 3) - apples.

The obvious (and least appropriate) solution was to use eval. This built-in function takes a string and runs it as Python code. That is considered very unsafe: eval can execute arbitrary code, which makes it a security risk, especially if the input comes from untrusted sources. A malicious user could inject code that executes system commands or reads sensitive information.
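To make the risk concrete, here is a harmless illustration. The input string is a made-up example: it contains no math at all, yet eval imports the os module and calls a function from it. A malicious input could just as easily call something destructive.

```python
# An "expression" a user could type into a math field.
# Instead of arithmetic, it imports a module and calls into it.
user_input = "__import__('os').getcwd()"

# eval runs it as ordinary Python code
result = eval(user_input)
print(result)  # prints the current working directory, not a number
```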

In this article, I will walk you through my safe and extensible implementation using Python’s ast module.

Revisiting the spec

Before we continue, let me list the features our math interpreter must implement:

  • Evaluate literals: it must correctly parse numbers, so the input string 3.14 returns the number 3.14.
  • Evaluate variables: given a context dictionary like {"apples": 4} and the input string apples, it must return the number 4.
  • Support unary operators: the string -5 must evaluate to the negative number -5.
  • Support the binary operators +, -, *, and /: the string 4 + 2 must evaluate to the number 6.
  • Support correct operator precedence and parentheses: the string 6 - 2 * 2 must evaluate to 2, not 8.
  • Mix and match all of the above: must correctly parse and evaluate complex strings like apples + (1 / 3) * oranges - 5.00001.

Other implementations I considered

There are several different possible implementations that one can use:

  • eval: as I said before, this is very unsafe and not viable for any project that is used by other people.
  • Third-party packages like numexpr. While this is a viable and easy solution, I dislike it for two reasons: it adds maintenance overhead to your project (one more library to track, update, and replace if it becomes unmaintained), and such packages tend to be heavy, depend on other libraries (like numpy), and contain far more features than I needed.
  • The ast.literal_eval function from the standard library. It can evaluate only the simplest expressions, which was not enough to satisfy my requirements: for example, it does not support multiplication and division.
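A quick check shows where ast.literal_eval stops: it accepts plain number literals, including a leading minus sign, but rejects an ordinary product with a ValueError. The exact error text is not shown here because it can vary between Python versions.

```python
import ast

print(ast.literal_eval("3.14"))  # 3.14
print(ast.literal_eval("-5"))    # -5

try:
    ast.literal_eval("2 * 3")
except ValueError as error:
    # literal_eval only accepts literals, so the multiplication is rejected
    print("rejected:", error)
```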

Writing tests

I like to follow test-driven development (TDD) in all my projects, so this tutorial starts with tests too. I will use the unittest library to write them. First, let us write a helper method that accepts an expression, a variables dictionary, and the expected answer, and asserts that our interpreter returns the correct value:

# tests.py
import unittest
from interpreter import MathInterpreter


class MathInterpreterTestCase(unittest.TestCase):

    def perform_test(self, expression: str, variables: dict[str, int | float], expected: int | float):
        result = MathInterpreter.eval(expression, variables)
        self.assertEqual(result, expected, f'Test failed for expression: {expression}')


if __name__ == '__main__':
    unittest.main()

This helper will make writing many test cases easier. We will store the test data as a list of tuples in the form (expression, variables, expected):

class MathInterpreterTestCase(unittest.TestCase):
    
    test_data = [
        ("5", {}, 5),
        ("3.14", {}, 3.14)
    ]
    
    ...

For now, I have included only the most basic tests to get us started: literal parsing. Lastly, we write the test method:

class MathInterpreterTestCase(unittest.TestCase):
    
    ...

    def test_interpreter(self):
        for expression, variables, expected in self.test_data:
            # subTest records each failing case separately instead of stopping at the first one
            with self.subTest(f'{expression}={expected}'):
                self.perform_test(expression, variables, expected)

This method runs a separate subtest for each test case. Now, if we want to test more features, we just add cases to the test_data list!

If you try running these tests now, you will get an import error, as we have not created the interpreter module yet:

...
ModuleNotFoundError: No module named 'interpreter'

We will now move on to the implementation.

Implementing literals

Literals are the easiest case. Let’s first create the interpreter module to fix the import error when running the tests:

# interpreter.py

class MathInterpreter:

    @staticmethod
    def eval(expression: str, variables: dict[str, int | float]) -> int | float:
        return 0

This static method accepts a string expression and a context dictionary and produces a numerical result; for now it is a stub that always returns 0. If you run the tests now, you will see that the import error is gone and we get meaningful failures instead:

...
AssertionError: 0 != 5 : Test failed for expression: 5
...
AssertionError: 0 != 3.14 : Test failed for expression: 3.14

Now it is finally time to implement basic parsing. We will do that using Python’s built-in ast module.

AST stands for Abstract Syntax Tree. The ast module parses Python source code into a tree of nodes that can be inspected programmatically. It is widely used in metaprogramming applications such as static code analysis and custom domain-specific languages.

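The ast module exports one function that is interesting to us: parse. It accepts valid Python code and parses it into a tree of nodes, which is precisely what we need. Passing mode='eval' tells it to expect a single expression instead of a whole module, so the root of the tree is an ast.Expression node whose body holds the expression itself. Use this function to parse the expression variable: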

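# interpreter.py
import ast


class MathInterpreter:

    @staticmethod
    def eval(expression: str, variables: dict[str, int | float]) -> int | float:
        # mode='eval' parses a single expression into an ast.Expression node
        root_node = ast.parse(expression, mode='eval')
        return MathInterpreter.__walk(root_node, variables)

    @staticmethod
    def __walk(node: ast.AST, variables: dict[str, int | float]) -> int | float:
        match node:
            case ast.Expression():
                # Unwrap the root node and evaluate the expression it holds
                return MathInterpreter.__walk(node.body, variables)
            case ast.Constant(value=int() | float() as value) if not isinstance(value, bool):
                # Numeric literal (ast.Num is deprecated since Python 3.8)
                return value
            case _:
                raise TypeError(f'Unsupported node: {type(node).__name__}')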

In the eval method, we use ast.parse to get the root node of the syntax tree. We pass it to the recursive method __walk, which returns a value based on the type of each node it visits. This design makes it easy to extend our interpreter in the future.
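To see which node types __walk will eventually have to handle, it can help to parse a few inputs from the spec and look at the body of the resulting ast.Expression. This is only an exploration snippet built on ast.parse and ast.dump; the exact ast.dump formatting can differ slightly between Python versions.

```python
import ast

# A few inputs from the spec; only the first is supported at this point
for source in ["3.14", "-5", "apples", "(2 + 3) - apples"]:
    # mode='eval' wraps the parsed expression in an ast.Expression node
    tree = ast.parse(source, mode="eval")
    # tree.body is what __walk receives after unwrapping the Expression
    print(f"{source:<18} {type(tree.body).__name__:<10} {ast.dump(tree.body)}")
```

The bodies are Constant, UnaryOp, Name, and BinOp nodes, respectively, which is why each of the following sections adds one more case to __walk.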

If the ast library ever confuses you, I suggest consulting its official documentation.

The tests should now pass:

Ran 1 test in 0.004s

OK

And we can move on to implementing arithmetic operations.

Implementing arithmetic operations

As always, before writing the implementation we have to write the tests. Thankfully, due to the extensible way our tests are written, this is easy enough. Add some test cases to the test_data list:

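    test_data = [
        # literals
        ("5", {}, 5),
        ("3.14", {}, 3.14),

        # arithmetic
        ("2 + 2", {}, 4),
        ("2 + 3", {}, 5),
        ("0 - 4", {}, -4),
        ("1.0004 + 5.01", {}, 6.0104),
        ("5 / 2", {}, 2.5),
        ("-4", {}, -4),

        # multiplication, precedence, and parentheses
        ("3 * 4", {}, 12),
        ("6 - 2 * 2", {}, 2),
        ("(2 + 3) * 2", {}, 10),
    ]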

These tests should now fail with TypeError, which we raise on unsupported nodes (arithmetic operators, in this case). Let’s fix them:

    @staticmethod
    def __walk(node: ast.AST, variables: dict[str, int | float]) -> int | float:
        match node:
            case ast.Expression():
                return MathInterpreter.__walk(node.body, variables)
            case ast.Constant(value=int() | float() as value) if not isinstance(value, bool):
                return value
            case ast.BinOp():
                # Evaluate both operands recursively, then apply the operator
                left, right, op = node.left, node.right, node.op
                match op:
                    case ast.Add():
                        return MathInterpreter.__walk(left, variables) + MathInterpreter.__walk(right, variables)
                    case ast.Sub():
                        return MathInterpreter.__walk(left, variables) - MathInterpreter.__walk(right, variables)
                    case ast.Mult():
                        return MathInterpreter.__walk(left, variables) * MathInterpreter.__walk(right, variables)
                    case ast.Div():
                        return MathInterpreter.__walk(left, variables) / MathInterpreter.__walk(right, variables)
                    case _:
                        raise TypeError(f'Unsupported operator: {type(op).__name__}')
            case ast.UnaryOp():
                operand, op = node.operand, node.op
                match op:
                    case ast.USub():
                        return -MathInterpreter.__walk(operand, variables)
                    case _:
                        raise TypeError(f'Unsupported operator: {type(op).__name__}')
            case _:
                raise TypeError(f'Unsupported node: {type(node).__name__}')

You can see that we have added two more cases to our __walk method: ast.BinOp and ast.UnaryOp nodes, which represent binary and unary operations, respectively. I have written it in a very verbose way to make it clear, but let us refactor it to reduce repetition:

# interpreter.py
import ast
import operator


class MathInterpreter:
    # Maps AST operator node types to the functions that implement them
    __operators_map = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.USub: operator.neg,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
    }

    @staticmethod
    def eval(expression: str, variables: dict[str, int | float]) -> int | float:
        root_node = ast.parse(expression, mode='eval')
        return MathInterpreter.__walk(root_node, variables)

    @staticmethod
    def __get_operator(op_node: ast.AST):
        # Unsupported operators (such as ** or %) raise TypeError instead of KeyError
        try:
            return MathInterpreter.__operators_map[type(op_node)]
        except KeyError:
            raise TypeError(f'Unsupported operator: {type(op_node).__name__}') from None

    @staticmethod
    def __walk(node: ast.AST, variables: dict[str, int | float]) -> int | float:
        match node:
            case ast.Expression():
                return MathInterpreter.__walk(node.body, variables)
            case ast.Constant(value=int() | float() as value) if not isinstance(value, bool):
                return value
            case ast.BinOp():
                method = MathInterpreter.__get_operator(node.op)
                return method(MathInterpreter.__walk(node.left, variables),
                              MathInterpreter.__walk(node.right, variables))
            case ast.UnaryOp():
                method = MathInterpreter.__get_operator(node.op)
                return method(MathInterpreter.__walk(node.operand, variables))
            case _:
                raise TypeError(f'Unsupported node: {type(node).__name__}')

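You can see that we moved the operator mappings to the __operators_map dictionary to reduce repetition. Note that we never had to handle operator precedence or parentheses ourselves: ast.parse already encodes both in the shape of the tree, so 6 - 2 * 2 arrives as a subtraction whose right operand is the multiplication 2 * 2. If you run the tests now, everything should pass: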

Ran 1 test in 0.002s

OK
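A quick way to confirm that the parser, not our interpreter, resolves precedence and parentheses is to inspect the trees directly. In this sketch, top_operation is a throwaway helper made up for the demonstration: it reports which operator sits at the root of the parsed expression. Because __walk evaluates children before applying the parent's operator, it inherits the correct order for free.

```python
import ast


def top_operation(source: str) -> str:
    """Hypothetical helper: name of the operator at the root of the tree."""
    body = ast.parse(source, mode="eval").body
    return type(body.op).__name__


# Multiplication binds tighter, so the subtraction ends up at the root
print(top_operation("6 - 2 * 2"))    # Sub
# Parentheses change the tree shape, so the multiplication is at the root
print(top_operation("(6 - 2) * 2"))  # Mult

# The right child of the first tree is the nested multiplication
inner = ast.parse("6 - 2 * 2", mode="eval").body.right
print(type(inner.op).__name__)       # Mult
```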

We can move on to implementing variable context.

Implementing variables

We start with, you guessed it, tests! Add some more cases to the test_data list:

    test_data = [
        # ...
        # variables
        ("a", {"a": 5}, 5),
        ("pi", {"pi": 3.14}, 3.14),
        ("a + 5", {"a": 5}, 10),
        ("a / b", {"a": 5, "b": 2}, 2.5),
        # exact binary fractions: with 1.1 and 2.2, 5 + 1.1 - 2.2 is not exactly 3.9
        ("a + b - c", {"a": 5, "b": 1.5, "c": 2.25}, 4.25),
    ]

These should fail with TypeError, as ast.Name nodes are not recognized yet. Let us fix that:

    @staticmethod
    def __walk(node: ast.AST, variables: dict[str, int | float]) -> int | float:
        match node:
            # ...
            case ast.Name():
                # Look the variable up in the context dictionary
                return variables[node.id]
            case _:
                raise TypeError(f'Unsupported node: {type(node).__name__}')

This code should be pretty straightforward. Variable names are represented by ast.Name nodes, so we replace each one with the corresponding value from the dictionary. Now the tests should pass:

Ran 1 test in 0.003s

OK

This concludes our implementation of a math interpreter in Python.

Further improvements

While I will not cover them here (unless someone asks in the comments), here are some ways you could extend this interpreter:

  • Error handling. You can add proper error handling that raises clear exceptions when a variable has the wrong type, the expression is malformed, a variable is missing, and so on.
  • Function support. Using ast.Call, you can implement basic math functions such as round, sqrt, and so on.

Closing notes

I hope you found this article useful. Please let me know in the comments what you think about it, or tell me how your implementation differs from mine. Cheers!
