
Notes on the ast library in Python

Published: at 09:16 AM


getting to and from ASTs

To build an AST from code stored as a string, use ast.parse(). To turn the AST into executable code, pass it to compile(), then run the resulting code object with exec() (or eval()).

import ast

tree = ast.parse("print('hello world')")
tree
<ast.Module at 0x7fdcbffc54e0>
exec(compile(tree, filename="<ast>", mode="exec"))
hello world

modes

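Python code can be compiled in three modes: 'exec' for a module (any sequence of statements), 'eval' for a single expression, and 'single' for a single interactive statement. The root of the AST depends on the mode argument you pass to ast.parse() (Module, Expression or Interactive respectively), and it must match the mode argument you pass to compile().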
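A minimal sketch of the three modes: each snippet below (arbitrary examples) is parsed with one mode, its root node type is checked, and it is compiled with the same mode. The last part shows that compile() rejects a tree whose root does not match the requested mode.

```python
import ast

# (mode, source, expected root node type)
cases = [
    ("exec", "x = 1\nprint(x)", ast.Module),   # a whole module: any number of statements
    ("eval", "1 + 2", ast.Expression),         # a single expression
    ("single", "y = 3", ast.Interactive),      # one interactive statement
]

roots = {}
for mode, source, root_type in cases:
    tree = ast.parse(source, mode=mode)
    assert isinstance(tree, root_type)
    # compile() accepts the tree because its mode matches the one used for parsing.
    compile(tree, filename="<ast>", mode=mode)
    roots[mode] = type(tree).__name__

mismatch_raised = False
try:
    # An Expression root compiled in "exec" mode is rejected with a TypeError.
    compile(ast.parse("1 + 2", mode="eval"), filename="<ast>", mode="exec")
except TypeError:
    mismatch_raised = True

print(roots)           # {'exec': 'Module', 'eval': 'Expression', 'single': 'Interactive'}
print(mismatch_raised) # True
```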

fixing locations

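To compile an AST, every expression and statement node must have lineno and col_offset attributes. Nodes produced by parsing regular code already have these, but nodes you create programmatically don't. There are a few helper functions for this:

ast.fix_missing_locations(node): recursively fills in missing locations, copying them from the parent node (the top of the tree starts at line 1, column 0).

ast.copy_location(new_node, old_node): copies the location from one node to another, which is useful when you replace a node.

ast.increment_lineno(node, n=1): shifts the line numbers of a node and all its children by n.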
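A minimal sketch of building a tree by hand and repairing its locations with ast.fix_missing_locations() before compiling. The expression print(6 * 7) is an arbitrary example.

```python
import ast

# Build the AST for `print(6 * 7)` by hand; none of these nodes has lineno/col_offset.
call = ast.Call(
    func=ast.Name(id="print", ctx=ast.Load()),
    args=[ast.BinOp(left=ast.Constant(6), op=ast.Mult(), right=ast.Constant(7))],
    keywords=[],
)
tree = ast.Module(body=[ast.Expr(value=call)], type_ignores=[])

# Compiling at this point would fail because the location attributes are missing.
# fix_missing_locations() fills them in recursively and returns the same tree.
tree = ast.fix_missing_locations(tree)
print(tree.body[0].lineno, tree.body[0].col_offset)  # 1 0

exec(compile(tree, filename="<ast>", mode="exec"))   # prints 42
```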

going backwards
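Going from an AST back to source text: assuming Python 3.9 or later, ast.unparse() does this directly. A minimal sketch; comments and the original formatting are lost, because the AST does not store them.

```python
import ast

# ast.unparse() requires Python 3.9+.
tree = ast.parse("x = (1 + 2) * y")
source = ast.unparse(tree)
print(source)  # x = (1 + 2) * y

# Parsing the generated source gives back an equivalent tree.
print(ast.dump(ast.parse(source)) == ast.dump(tree))  # True

# Comments and spacing do not survive the round trip.
print(ast.unparse(ast.parse("x=1  # a comment")))  # x = 1
```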

meet the nodes

An AST represents each element in your code as an object. For example, the code a+1 is a BinOp, with a Name on the left, a Constant on the right (ast.Num in Python versions before 3.8), and an Add operator.
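To check the a+1 example, a small sketch that parses it in 'eval' mode (so the root's body is the expression itself) and prints the type of each part:

```python
import ast

# mode="eval" gives an Expression root, so .body is the BinOp itself.
node = ast.parse("a+1", mode="eval").body
print(type(node).__name__)                          # BinOp
print(type(node.left).__name__, node.left.id)       # Name a
print(type(node.right).__name__, node.right.value)  # Constant 1
print(type(node.op).__name__)                       # Add
```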

Literals

class Constant(value, kind): A constant. The value attr holds the Python object it represents. This can be a simple type such as a number, string or None, but also an immutable container (a tuple or frozenset) if all of its elements are constant. kind is 'u' for strings written with a u prefix and None otherwise.

others omitted.

Variables

class Name(id, ctx): A variable name. id holds the name as a string, and ctx is one of the following types: Load, Store or Del.
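A short sketch showing all three contexts: the code below (an arbitrary example) assigns to x, reads y and deletes z, and collects the ctx type of every Name node.

```python
import ast

tree = ast.parse("x = y\ndel z")
# Collect (name, context) for every Name node in the tree.
contexts = sorted(
    (node.id, type(node.ctx).__name__)
    for node in ast.walk(tree)
    if isinstance(node, ast.Name)
)
print(contexts)  # [('x', 'Store'), ('y', 'Load'), ('z', 'Del')]
```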

class Starred(value, ctx): A *var variable reference. value holds the var, typically a Name node. Note this isn’t used to define a function with *args.

Exprs

class Expr(value): When an expression appears as a statement by itself, with its return value not used or stored, it's wrapped in this container. value holds one of the other nodes in this section, or a literal, a Name, a Lambda, or a Yield or YieldFrom node.

class NamedExpr(target, value): used to bind an expression to a name using the := operator. target holds a Name which is the name the expr is bound to. Note that the ctx of the Name should be set to Store. value is any node valid as the value of Expr.

class UnaryOp(op, operand): A unary operation. op is the operator, and operand is any expression node. Operators: UAdd, USub, Not, Invert (the ~ operator).

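class BinOp(left, op, right): A binary operation such as addition or division. left and right are any expression nodes; op is the operator, e.g. Add, Sub, Mult or Div.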

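class Compare(left, ops, comparators): A comparison of two or more values. left is the first value, ops is a list of operators (e.g. Eq, NotEq, Lt), and comparators is a list of the values after the first, so a chained comparison such as a < b < c has two ops and two comparators.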
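A small sketch parsing an arbitrary chained comparison, 0 < x <= 10, to show that ops and comparators line up pairwise after left:

```python
import ast

cmp = ast.parse("0 < x <= 10", mode="eval").body
print(type(cmp).__name__)                              # Compare
print(type(cmp.left).__name__)                         # Constant (the 0)
print([type(op).__name__ for op in cmp.ops])           # ['Lt', 'LtE']
print([type(c).__name__ for c in cmp.comparators])     # ['Name', 'Constant']
```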

class Call(func, args, keywords): A function call. func is the function, which will often be a Name or Attribute object. args holds a list of the arguments passed by position; keywords holds a list of keyword objects representing arguments passed by keyword. The starargs and kwargs fields were removed in Python 3.5: *args now appears as a Starred node in args, and **kwargs as a keyword whose arg is None.

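class keyword(arg, value): A keyword argument to a function call or class definition. arg is a raw string of the parameter name (None for **kwargs), and value is the node to pass in.

class IfExp(test, body, orelse): A conditional expression such as a if b else c. Each field holds a single node.

class Attribute(value, attr, ctx): Attribute access, e.g. d.keys. value is a node, typically a Name; attr is a bare string giving the attribute name; ctx is Load, Store or Del.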
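A sketch tying Call, keyword, Starred and Attribute together on one arbitrary call expression. In Python 3.5+, *rest ends up as a Starred node inside args and **extra as a keyword whose arg is None.

```python
import ast

call = ast.parse("obj.method(1, *rest, key=2, **extra)", mode="eval").body

print(type(call.func).__name__, call.func.attr)     # Attribute method
print(call.func.value.id)                           # obj
print([type(a).__name__ for a in call.args])        # ['Constant', 'Starred']
# **extra becomes a keyword whose arg is None.
print([(kw.arg, type(kw.value).__name__) for kw in call.keywords])
# [('key', 'Constant'), (None, 'Name')]
```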

Subscripting

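class Subscript(value, slice, ctx): A subscript, such as l[1]. value is the subscripted object, often a Name. Since Python 3.9, slice holds the index directly: an expression node such as a Constant, a Slice(lower, upper, step), or a Tuple of these for multi-dimensional subscripts. Before 3.9 it was one of Index(value), Slice(lower, upper, step) or ExtSlice(dims), where dims holds a list of Slice and Index nodes.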
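A sketch of what ends up in Subscript.slice for three kinds of subscript, assuming Python 3.9 or later (older versions wrap these in Index/ExtSlice). slice_kind is a hypothetical helper written just for this illustration.

```python
import ast

# Assumes Python 3.9+, where the index is stored directly in Subscript.slice.
def slice_kind(source):
    """Return the node type stored in Subscript.slice for a subscript expression."""
    return type(ast.parse(source, mode="eval").body.slice).__name__

print(slice_kind("l[1]"))       # Constant
print(slice_kind("l[1:5:2]"))   # Slice
print(slice_kind("m[1:2, 3]"))  # Tuple

s = ast.parse("l[1:5:2]", mode="eval").body.slice
print(s.lower.value, s.upper.value, s.step.value)  # 1 5 2
```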

Comprehensions

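class ListComp(elt, generators): A list comprehension. elt is the element expression, and generators is a list of comprehension nodes, one per for clause. Each comprehension(target, iter, ifs, is_async) holds the loop target, the iterable, a list of if conditions, and an integer flag for async for.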
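A small sketch inspecting an arbitrary comprehension with one for clause and one if condition:

```python
import ast

comp = ast.parse("[x * 2 for x in data if x > 0]", mode="eval").body
print(type(comp.elt).__name__)        # BinOp: the element expression
gen = comp.generators[0]              # one comprehension node per `for` clause
print(type(gen).__name__)             # comprehension
print(gen.target.id, gen.iter.id)     # x data
print(len(gen.ifs), gen.is_async)     # 1 0
```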

statements

class Assign(targets, value, type_comment): An assignment. targets is a list of nodes, and value is a single node; type_comment is optional. With unpacking, as in a, b = c, the unpacked names are combined into a single Tuple (or List) target. In a multiple assignment such as a = b = 1, targets holds one node per target.
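A sketch contrasting the two cases; targets is a hypothetical helper for this illustration that returns the node types in Assign.targets.

```python
import ast

def targets(source):
    """Return the type names in Assign.targets for a one-line assignment."""
    assign = ast.parse(source).body[0]
    return [type(t).__name__ for t in assign.targets]

print(targets("a, b = 1, 2"))  # ['Tuple']: a single target, which is a Tuple
print(targets("a = b = 1"))    # ['Name', 'Name']: one node per target
```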

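class AugAssign(target, op, value): Augmented assignment such as a += 1. target is a single Name, Attribute or Subscript node, op is an operator such as Add, and value is a single node.

class Raise(exc, cause): A raise statement. exc is the exception to raise, normally a Call or Name, or None for a bare raise; cause is the optional part after from in raise x from y.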

Imports

class Import(names): names is a list of alias nodes.

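class ImportFrom(module, names, level): Represents from x import y. module is a raw string of the 'from' name, without any leading dots (None for from . import y); level is an integer holding the level of the relative import (0 means absolute).

class alias(name, asname): Both parameters are raw strings of the names; asname is None if no as clause is used.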

Control flow

class If(test, body, orelse): An if statement. test holds a single node, such as a Compare node. body and orelse each hold a list of nodes. elif clauses have no node of their own: each appears as another If node inside the orelse list of the previous one.
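A small sketch showing the nesting for an arbitrary if/elif/else chain:

```python
import ast

source = """
if x < 0:
    sign = -1
elif x == 0:
    sign = 0
else:
    sign = 1
"""
outer = ast.parse(source).body[0]
inner = outer.orelse[0]                # the elif is a nested If inside orelse
print(type(inner).__name__)            # If
print(len(outer.orelse))               # 1
print(type(inner.orelse[0]).__name__)  # Assign: the body of the final else
```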

class For(target, iter, body, orelse, type_comment)

class While(test, body, orelse)

class Break

class Continue

function and class definitions

class FunctionDef(name, args, body, decorator_list, returns, type_comment)

class Lambda(args, body)

class arguments(posonlyargs, args, vararg, kwonlyargs, kw_defaults, kwarg, defaults)

class arg(arg, annotation, type_comment)

class Return(value)

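class ClassDef(name, bases, keywords, body, decorator_list): bases is a list of nodes for the explicitly listed base classes, and keywords holds keyword nodes such as metaclass=. The starargs and kwargs fields were removed in Python 3.5.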
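A sketch of how the arguments node splits up a signature, assuming Python 3.8+ (for the / positional-only marker). The function f is an arbitrary example.

```python
import ast

func = ast.parse("def f(a, /, b, *args, c=1, **kw) -> int: pass").body[0]
sig = func.args  # an ast.arguments node

print([a.arg for a in sig.posonlyargs])    # ['a']
print([a.arg for a in sig.args])           # ['b']
print(sig.vararg.arg, sig.kwarg.arg)       # args kw
print([a.arg for a in sig.kwonlyargs])     # ['c']
print([d.value for d in sig.kw_defaults])  # [1]
print(func.returns.id)                     # int
```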

top level nodes

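class Module(body, type_ignores): The root produced in 'exec' mode; body is a list of statements.

class Interactive(body): The root produced in 'single' mode.

class Expression(body): The root produced in 'eval' mode; body is a single expression node.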

Working on the Tree

ast.NodeVisitor is the primary tool for ‘scanning’ the tree. For example, the following visitor will print the names of any functions defined in the given code, including methods and functions defined within other functions.

Note if you want child nodes to be visited, call self.generic_visit(node) in the methods you override.

import ast

class FuncLister(ast.NodeVisitor):
    def visit_FunctionDef(self, node):
        print(node.name)
        self.generic_visit(node)  # keep descending so nested functions are found too

raw_code = """
def foo():
    def bar():
        pass
    pass
print("hello")
"""
tree = ast.parse(raw_code)
FuncLister().visit(tree)
foo
bar

Alternatively, you can iterate over every node in the tree with ast.walk(). There are no guarantees about the order in which nodes will appear. The following example again prints the names of any functions defined:

for node in ast.walk(tree):
    if isinstance(node, ast.FunctionDef):
        print(node.name)
foo
bar
print(ast.dump(tree, indent=2))  # indent= needs Python 3.9+; the exact output varies between versions
Module(
  body=[
    FunctionDef(
      name='foo',
      args=arguments(
        posonlyargs=[],
        args=[],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[]),
      body=[
        FunctionDef(
          name='bar',
          args=arguments(
            posonlyargs=[],
            args=[],
            kwonlyargs=[],
            kw_defaults=[],
            defaults=[]),
          body=[
            Pass()],
          decorator_list=[]),
        Pass()],
      decorator_list=[]),
    Expr(
      value=Call(
        func=Name(id='print', ctx=Load()),
        args=[
          Constant(value='hello')],
        keywords=[]))],
  type_ignores=[])

You can also get the direct children of a node using ast.iter_child_nodes(). Remember that many nodes have children in several fields: for example, an If has a node in the test field, and a list of nodes in body and orelse. ast.iter_child_nodes() will go through all of these.
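A small sketch on an arbitrary if/else statement: the children come from test, body and orelse, in that field order.

```python
import ast

if_node = ast.parse("if ok:\n    a()\nelse:\n    b()").body[0]
children = [type(child).__name__ for child in ast.iter_child_nodes(if_node)]
print(children)  # ['Name', 'Expr', 'Expr']: test, then body, then orelse
```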

Finally, you can navigate directly, using the attrs of the nodes.
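For example, a sketch that reaches the nested function from the earlier foo/bar code by indexing into the body lists directly:

```python
import ast

tree = ast.parse("def foo():\n    def bar():\n        pass\n    pass")
foo = tree.body[0]                 # Module.body is a list of statements
bar = foo.body[0]                  # the nested FunctionDef
print(foo.name, bar.name)          # foo bar
print(type(bar.body[0]).__name__)  # Pass
```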

Inspecting nodes

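The ast module has a couple of functions for inspecting nodes:

ast.iter_fields(node): yields a (fieldname, value) pair for each field of the node.

ast.get_docstring(node): returns the docstring of a FunctionDef, AsyncFunctionDef, ClassDef or Module node, or None if it has none.

ast.dump(node): returns a formatted string of the tree, as used above.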
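A short sketch of the two inspection helpers on an arbitrary function. The field list printed by iter_fields differs between Python versions (for example, newer versions add type_params), so no exact output is shown for it.

```python
import ast

tree = ast.parse('def greet(name):\n    """Say hello."""\n    return "hi " + name')
func = tree.body[0]

# iter_fields yields (field_name, value) pairs for every field in node._fields.
fields = [name for name, _ in ast.iter_fields(func)]
print(fields)

# get_docstring works on FunctionDef, AsyncFunctionDef, ClassDef and Module nodes.
print(ast.get_docstring(func))  # Say hello.
print(ast.get_docstring(tree))  # None: the module itself has no docstring
```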

modifying the tree

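The key tool is ast.NodeTransformer. It walks the tree like NodeVisitor, but replaces each node with the return value of the matching visit_ method: returning the original node keeps it, returning a new node replaces it, and returning None removes it.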
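A minimal sketch of the removal case: StripPrints is a hypothetical transformer that deletes expression statements calling print(). It uses ast.unparse(), so it assumes Python 3.9+. Removing every statement from a body (e.g. a function containing only a print) would leave an empty body that no longer compiles, so a real transformer would need to handle that.

```python
import ast

class StripPrints(ast.NodeTransformer):
    """Remove expression statements that call print(...)."""

    def visit_Expr(self, node):
        call = node.value
        if (isinstance(call, ast.Call) and isinstance(call.func, ast.Name)
                and call.func.id == "print"):
            return None                  # returning None deletes the statement
        return self.generic_visit(node)  # keep it, and transform its children

tree = StripPrints().visit(ast.parse("x = 1\nprint(x)\ny = 2"))
print(ast.unparse(tree))  # x = 1, then y = 2 on the next line
```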

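class RewriteName(ast.NodeTransformer):
    # Rewrite every name lookup `x` into `data['x']`.
    def visit_Name(self, node):
        return ast.copy_location(ast.Subscript(
            value=ast.Name(id='data', ctx=ast.Load()),
            slice=ast.Constant(node.id),  # Python 3.9+; older versions wrap this in ast.Index
            ctx=node.ctx
        ), node)
tree = RewriteName().visit(tree)
tree = ast.fix_missing_locations(tree)  # the new inner Name and Constant still lack locations
ast.dump(tree)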
"Module(body=[FunctionDef(name='foo', args=arguments(posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]), body=[FunctionDef(name='bar', args=arguments(posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]), body=[Pass()], decorator_list=[]), Pass()], decorator_list=[]), Expr(value=Call(func=Subscript(value=Name(id='data', ctx=Load()), slice=Constant(value='print'), ctx=Load()), args=[Constant(value='hello')], keywords=[]))], type_ignores=[])"

examples of working with ASTs

Wrapping integers

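import ast

class IntegerWrapper(ast.NodeTransformer):
    # Since Python 3.8 numbers parse as ast.Constant; visit_Num and node.n are deprecated.
    def visit_Constant(self, node):
        # bool is a subclass of int, so exclude True/False explicitly.
        if isinstance(node.value, int) and not isinstance(node.value, bool):
            # `Integer` must be defined wherever the rewritten code is executed.
            return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()), args=[node], keywords=[])
        return node

tree = ast.parse("1/3")
tree = IntegerWrapper().visit(tree)
tree = ast.fix_missing_locations(tree)  # the new Call and Name nodes have no locations yet
ast.dump(tree)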
"Module(body=[Expr(value=BinOp(left=Call(func=Name(id='Integer', ctx=Load()), args=[Constant(value=1)], keywords=[]), op=Div(), right=Call(func=Name(id='Integer', ctx=Load()), args=[Constant(value=3)], keywords=[])))], type_ignores=[])"

Simple test framework

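import ast

class AssertCmpTransformer(ast.NodeTransformer):
    def visit_Assert(self, node):
        if isinstance(node.test, ast.Compare) and \
                len(node.test.ops) == 1 and \
                isinstance(node.test.ops[0], ast.Eq):
            call = ast.Call(func=ast.Name(id='assert_equal', ctx=ast.Load()),
                            args=[node.test.left, node.test.comparators[0]],
                            keywords=[])
            # Wrap the call in an Expr node, because the return value isn't used.
            newnode = ast.Expr(value=call)
            ast.copy_location(newnode, node)
            ast.fix_missing_locations(newnode)
            return newnode

        # Remember to return the original node if we don't want to change it.
        return node

# Usage sketch: `assert_equal` is a hypothetical helper that the rewritten asserts call.
def assert_equal(a, b):
    if a != b:
        raise AssertionError(f"{a!r} != {b!r}")

tree = AssertCmpTransformer().visit(ast.parse("assert 1 + 1 == 2"))
exec(compile(tree, filename="<ast>", mode="exec"), {"assert_equal": assert_equal})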
