I’m working on refactoring a big codebase, and part of the refactor involves removing some checks. I couldn’t figure out a way to do this efficiently. For example, if the old code was something like:
```python
if A:
    print("branch 1")
else:
    print("branch 2")
```
I want to delete `A` and make it always `True`. So my new code would look something like:
```python
if True:
    print("branch 1")
else:
    print("branch 2")  # unreachable
```
Now I want to simplify the code into:
```python
print("branch 1")
```
A few different scenarios that complicate this:
Scenario 1:
```python
if A and B:
    print("branch 1")
else:
    print("branch 2")
```
would become:
```python
if B:
    print("branch 1")
else:
    print("branch 2")
```
Scenario 2:
```python
if A:
    print("branch 1")
    return
print("branch 2")
```
would become:
```python
print("branch 1")
return
```
I tried running a few different linters to see whether any of them would suggest the simplification, but had no luck. I’m also not good enough at codemods to write my own script 🙃
Pyrefact appears to do this sort of thing and much more. Before I remembered that this tool existed, however, I wrote some code that does only what you describe. Below is the answer I wrote at that time.
This can be achieved using Python’s `ast` module. The following code replaces every instance of the variable `A` with `True` (this can be modified/extended using the `VARS` variable) and also simplifies boolean expressions (e.g., `if True and X:` becomes `if X:`) and trivial `if` statements (removing branches that can never run). To apply it to a file called `file.py`, save this code in a file (which I will assume is called `if_simplifier.py` and is in the same directory) and run the following from the command line:

```
python if_simplifier.py file.py
```
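The boolean rewrites are safe inside an `if` test because a condition only looks at truthiness. The quick check below is plain Python, independent of the script, and confirms the identities the simplifier relies on. It also shows a limitation worth knowing: a trailing constant, as in `x and True`, is dropped even though `x and True` and `x` can have different values, so rewritten expressions used outside conditions (for example `y = x and A`) deserve a closer look.

```python
# Sample operands: truthy and falsy values of several types.
samples = [0, 1, 5, "", "text", None, [], [0]]

for x in samples:
    # A leading `True and` / `False or` returns the other operand itself,
    # so dropping the constant preserves the exact value.
    assert (True and x) == x
    assert (False or x) == x
    # A trailing `and True` / `or False` only preserves truthiness,
    # which is all that an `if` condition looks at.
    assert bool(x and True) == bool(x)
    assert bool(x or False) == bool(x)

# Outside a condition the value can differ:
x = 5
print(x and True)  # prints True, while the rewritten expression `x` is 5
```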
Two caveats:
- While I tested this code on several relevant examples, it’s possible I missed an important edge case, so check that the rewritten code looks correct and run your tests to make sure everything still works.
- `ast.unparse()` regenerates the source from the syntax tree. That guarantees valid indentation, but it discards all comments and much of the original formatting (blank lines, quote style, line wrapping), so the result is less readable. This code also doesn’t remove unreachable code (like the `print("branch 2")` in your Scenario 2). You can detect unreachable code manually or with a tool like Vulture, and you can reformat your code with a tool like Black (which is also available as an extension for VS Code) to tidy up the formatting, although Black cannot bring back deleted comments.
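If you also want to cover Scenario 2, one option is a second pass that truncates each block right after a `return`, `raise`, `break` or `continue`. The sketch below is a hypothetical helper (`DropAfterJump` is my own name, not part of `ast`) and only catches jumps that sit directly in the same block; it won’t notice that code after an `if` whose branches all return is dead. You could run it on the tree in `simplify_ifs` just before `ast.unparse`.

```python
import ast


class DropAfterJump(ast.NodeTransformer):
    """Hypothetical helper: drop statements that follow a jump in the same block."""

    JUMPS = (ast.Return, ast.Raise, ast.Break, ast.Continue)

    def generic_visit(self, node: ast.AST) -> ast.AST:
        # Transform children first, then trim this node's own statement lists.
        super().generic_visit(node)
        for field in ("body", "orelse", "finalbody"):
            stmts = getattr(node, field, None)
            if not isinstance(stmts, list):
                continue  # e.g. the `body` of a lambda is a single expression
            for i, stmt in enumerate(stmts):
                if isinstance(stmt, self.JUMPS):
                    # Everything after the jump in this block can never run.
                    setattr(node, field, stmts[: i + 1])
                    break
        return node


source = '''
def scenario_2():
    print("branch 1")
    return
    print("branch 2")
'''

tree = DropAfterJump().visit(ast.parse(source))
print(ast.unparse(tree))  # the print("branch 2") call is gone
```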
```python
from __future__ import annotations  # lets the `X | None` annotations work on Python 3.9

import ast
import sys
from typing import Any

# Names to replace, mapped to the constant value each one should become.
VARS = {"A": True}


class Visitor(ast.NodeTransformer):
    """Substitutes known constants, then folds boolean expressions and `if` statements."""

    def __init__(self, variables: dict[str, Any] | None = None):
        super().__init__()
        self.variables = variables or {}

    def visit_Name(self, node: ast.Name) -> ast.AST:
        self.generic_visit(node)
        # Only replace reads: rewriting an assignment target such as `A = 1`
        # would produce invalid code (`True = 1`).
        if isinstance(node.ctx, ast.Load) and node.id in self.variables:
            return ast.Constant(value=self.variables[node.id])
        return node

    def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt]:
        # Simplify the condition and both branches first (bottom-up).
        self.generic_visit(node)
        if isinstance(node.test, ast.Constant):
            # The condition is fixed, so splice in the branch that always runs.
            # Returning a list replaces this `if` with those statements; an
            # empty list (a false condition with no `else`) deletes it.
            return node.body if node.test.value else node.orelse
        return node

    def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:
        self.generic_visit(node)
        # `not <constant>` -> the negated constant.
        if isinstance(node.op, ast.Not) and isinstance(node.operand, ast.Constant):
            return ast.Constant(value=not node.operand.value)
        return node

    def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST:
        self.generic_visit(node)
        # In an `or`, a truthy constant decides the result and falsy constants
        # can be dropped; in an `and`, it is the other way round.
        decisive = isinstance(node.op, ast.Or)
        new_values = []
        for value in node.values:
            if isinstance(value, ast.Constant):
                if bool(value.value) == decisive:
                    return value
                # Any other constant has no effect here, so it is dropped.
            else:
                new_values.append(value)
        if not new_values:
            # Every operand was a neutral constant: `or` -> False, `and` -> True.
            return ast.Constant(value=not decisive)
        if len(new_values) == 1:
            return new_values[0]
        node.values = new_values
        return node


def simplify_ifs(code: str, variables: dict[str, Any] | None = None) -> str:
    tree = ast.parse(code)
    tree = Visitor(variables=variables).visit(tree)
    return ast.unparse(tree)


def main() -> None:
    path = sys.argv[1]
    with open(path, encoding="utf-8") as f:
        code = f.read()
    # Rewrites the file in place, so commit or back it up first.
    with open(path, "w", encoding="utf-8") as f:
        f.write(simplify_ifs(code, variables=VARS))


if __name__ == "__main__":
    main()
```
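Since the question is about a big codebase, here is a sketch of applying a transformation to every `.py` file under a directory. `rewrite_python_files` is a hypothetical helper, not part of any library; it takes the transformation as a plain `str -> str` function so it stays independent of the script. Because `ast.unparse` reformats everything it touches, it compares syntax trees rather than text and only rewrites files whose AST actually changed, so files without any `A` keep their comments.

```python
import ast
from pathlib import Path
from typing import Callable, Union


def rewrite_python_files(root: Union[str, Path], transform: Callable[[str], str]) -> list[Path]:
    """Apply `transform` to every .py file under `root` and return the files it rewrote.

    Hypothetical helper; `transform` would typically be something like
    `lambda src: simplify_ifs(src, variables=VARS)`.
    """
    changed = []
    for path in sorted(Path(root).rglob("*.py")):
        original = path.read_text(encoding="utf-8")
        updated = transform(original)
        # Compare syntax trees rather than text: files whose AST is unchanged
        # are left exactly as they were, comments included.
        if ast.dump(ast.parse(updated)) == ast.dump(ast.parse(original)):
            continue
        path.write_text(updated, encoding="utf-8")
        changed.append(path)
    return changed
```

With the script above, you might call `rewrite_python_files(sys.argv[1], lambda src: simplify_ifs(src, variables=VARS))` in `main()` so that the argument is a directory instead of a single file.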