In the previous article we discussed what refactoring is, why it matters and when to do it, and explained the importance of testing. Starting with this article, we introduce specific refactoring techniques, beginning with some general-purpose ones.
1 Extract function
Extract Function is probably the most common refactoring technique. None of us wants to cram a lot of logic into one function and make it do many things; such a function is exhausting to read and to modify (at ByteDance I once came across a single function more than 600 lines long).
Ideally, each function is dedicated to one specific task and does only one thing. Following this idea, we can split a function into pieces according to what each code segment does, and have the original function call the extracted ones. How fine-grained the split should be is up to you; in Refactoring, the author says that functions longer than six lines already start to smell.
Before refactoring:
```python
import datetime


class Invoice:
    def __init__(self, orders, customer):
        self.orders = orders
        self.customer = customer
        self.dueDate = ""


def printOwing(invoice):
    outstanding = 0

    print("************************")
    print("**** Customer Owes ****")
    print("************************")

    # calculate outstanding
    for o in invoice.orders:
        outstanding += o["amount"]

    # record due date
    now = datetime.datetime.now()
    invoice.dueDate = now + datetime.timedelta(days=30)

    # print details
    print(f'name: {invoice.customer}')
    print(f'amount: {outstanding}')
    print(f'due: {invoice.dueDate.strftime("%Y-%m-%d %H:%M:%S")}')


invoice = Invoice(
    [{"amount": 1}, {"amount": 2}, {"amount": 3}], "zhangsan"
)
printOwing(invoice)
```
After refactoring:
```python
import datetime


class Invoice:
    def __init__(self, orders, customer):
        self.orders = orders
        self.customer = customer
        self.dueDate = ""


def printBanner():
    print("************************")
    print("**** Customer Owes ****")
    print("************************")


def printOwing(invoice):
    outstanding = 0

    printBanner()

    # calculate outstanding
    for o in invoice.orders:
        outstanding += o["amount"]

    # record due date
    now = datetime.datetime.now()
    invoice.dueDate = now + datetime.timedelta(days=30)

    # print details
    print(f'name: {invoice.customer}')
    print(f'amount: {outstanding}')
    print(f'due: {invoice.dueDate.strftime("%Y-%m-%d %H:%M:%S")}')


invoice = Invoice(
    [{"amount": 1}, {"amount": 2}, {"amount": 3}], "zhangsan"
)
printOwing(invoice)
```
We have extracted the static banner-printing statements. The details-printing part can be extracted the same way, but it depends on values computed earlier in the function; those values are passed to the extracted function as arguments.
Every time you make a refactoring change, you must retest it.
```python
import datetime


class Invoice:
    def __init__(self, orders, customer):
        self.orders = orders
        self.customer = customer
        self.dueDate = ""


def printBanner():
    print("************************")
    print("**** Customer Owes ****")
    print("************************")


def printDetails(invoice, outstanding):
    print(f'name: {invoice.customer}')
    print(f'amount: {outstanding}')
    print(f'due: {invoice.dueDate.strftime("%Y-%m-%d %H:%M:%S")}')


def printOwing(invoice):
    outstanding = 0

    printBanner()

    # calculate outstanding
    for o in invoice.orders:
        outstanding += o["amount"]

    # record due date
    now = datetime.datetime.now()
    invoice.dueDate = now + datetime.timedelta(days=30)

    # print details
    printDetails(invoice, outstanding)


invoice = Invoice(
    [{"amount": 1}, {"amount": 2}, {"amount": 3}], "zhangsan"
)
printOwing(invoice)
```
The due-date recording can be extracted in the same way:
```python
import datetime


class Invoice:
    def __init__(self, orders, customer):
        self.orders = orders
        self.customer = customer
        self.dueDate = ""


def printBanner():
    print("************************")
    print("**** Customer Owes ****")
    print("************************")


def printDetails(invoice, outstanding):
    print(f'name: {invoice.customer}')
    print(f'amount: {outstanding}')
    print(f'due: {invoice.dueDate.strftime("%Y-%m-%d %H:%M:%S")}')


def recordDueDate(invoice):
    now = datetime.datetime.now()
    invoice.dueDate = now + datetime.timedelta(days=30)


def printOwing(invoice):
    outstanding = 0

    printBanner()

    # calculate outstanding
    for o in invoice.orders:
        outstanding += o["amount"]

    # record due date
    recordDueDate(invoice)

    # print details
    printDetails(invoice, outstanding)


invoice = Invoice(
    [{"amount": 1}, {"amount": 2}, {"amount": 3}], "zhangsan"
)
printOwing(invoice)
```
The calculate outstanding section in the middle assigns to the local variable outstanding. How do we extract it?
It is simple: move the declaration of outstanding next to the loop that updates it, then extract both into a function that initializes the variable, accumulates into it and returns it. printOwing then just assigns the returned value.
```python
import datetime


class Invoice:
    def __init__(self, orders, customer):
        self.orders = orders
        self.customer = customer
        self.dueDate = ""


def printBanner():
    print("************************")
    print("**** Customer Owes ****")
    print("************************")


def printDetails(invoice, outstanding):
    print(f'name: {invoice.customer}')
    print(f'amount: {outstanding}')
    print(f'due: {invoice.dueDate.strftime("%Y-%m-%d %H:%M:%S")}')


def recordDueDate(invoice):
    now = datetime.datetime.now()
    invoice.dueDate = now + datetime.timedelta(days=30)


def calculateOutstanding(invoice):
    outstanding = 0
    for o in invoice.orders:
        outstanding += o["amount"]
    return outstanding


def printOwing(invoice):
    printBanner()

    # calculate outstanding
    outstanding = calculateOutstanding(invoice)

    # record due date
    recordDueDate(invoice)

    # print details
    printDetails(invoice, outstanding)


invoice = Invoice(
    [{"amount": 1}, {"amount": 2}, {"amount": 3}], "zhangsan"
)
printOwing(invoice)
```
At this point, the original printOwing has been split into five functions: four helpers that each do one specific job (printBanner, calculateOutstanding, recordDueDate and printDetails), and printOwing itself, which now simply calls them in order. Its logic can be read at a glance.
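Since every step above has to be retested, one lightweight check, sketched here under the assumption that a function's observable behavior is what it prints, is to capture its output before and after a change and compare the two. The function names below (`print_banner_inline`, `print_banner_extracted`, `captured_output`) are hypothetical; only `contextlib.redirect_stdout` and `io.StringIO` come from the standard library.

```python
import contextlib
import io


def print_banner_inline():
    # "before" version: banner printed directly inside the function
    print("************************")
    print("**** Customer Owes ****")
    print("************************")
    print("amount: 6")


def print_banner():
    print("************************")
    print("**** Customer Owes ****")
    print("************************")


def print_banner_extracted():
    # "after" version: banner moved into its own function
    print_banner()
    print("amount: 6")


def captured_output(func, *args):
    """Run func and return everything it wrote to stdout."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        func(*args)
    return buffer.getvalue()


before = captured_output(print_banner_inline)
after = captured_output(print_banner_extracted)
assert before == after, "refactoring changed the printed output"
print("outputs match")
```

Output that depends on the current time, such as the `due:` line printed by printOwing, differs between runs, so it would have to be excluded or the clock controlled before a comparison like this is meaningful.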
2 Inline function
Inline Function is the reverse of Extract Function. Sometimes a function's body is just as clear as its name, or the body has been refactored until it is as clear as the name. In that case the extra level of indirection adds nothing: remove the function and use its code directly at the call site.
Before refactoring:
```python
def report_lines(a_customer):
    lines = []
    gather_customer_data(lines, a_customer)
    return lines


def gather_customer_data(out, a_customer):
    out.append({"name": a_customer["name"]})
    out.append({"location": a_customer["location"]})


print(report_lines({"name": "zhangsan", "location": "GuangDong Province"}))
```
After refactoring:
```python
def report_lines(a_customer):
    lines = []
    lines.append({"name": a_customer["name"]})
    lines.append({"location": a_customer["location"]})
    return lines


print(report_lines({"name": "zhangsan", "location": "GuangDong Province"}))
```
3 Extract variable
When an expression is complex and hard to read, we can introduce local variables that name its parts. You might object that an extra variable is redundant and costs memory, but that cost is negligible. A long compound expression may look concise, yet it is hard to read, and such code still smells bad.
Before refactoring:
```python
def price(order):
    # price is base price - quantity discount + shipping
    return order["quantity"] * order["item_price"] - \
        max(0, order["quantity"] - 500) * order["item_price"] * 0.05 + \
        min(order["quantity"] * order["item_price"] * 0.1, 100)


print(price({"quantity": 20, "item_price": 3.5}))
```
It is mentioned in the book “Refactoring” that if you feel you need to add comments, you may wish to refactor first. If the logic is clear after refactoring, readers can clarify the logic through the code structure and function names, and no comments are needed.
That is the case here: without the comment, it is hard to see why the bloated expression is calculated the way it is.
After refactoring:
```python
def price(order):
    # price is base price - quantity discount + shipping
    base_price = order["quantity"] * order["item_price"]
    quantity_discount = max(0, order["quantity"] - 500) * order["item_price"] * 0.05
    shipping = min(order["quantity"] * order["item_price"] * 0.1, 100)
    return base_price - quantity_discount + shipping


print(price({"quantity": 20, "item_price": 3.5}))
```
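To see each named piece doing its job, here is the same calculation applied to a hypothetical order large enough to trigger the quantity discount: 600 items at 1.0 each. The `price_breakdown` helper is an illustration, not part of the original example; it returns the components instead of only the total. For this order the base price is 600, the discount is 5 (0.05 per item on the 100 items above 500), shipping is 60 (10% of the base price, under the cap of 100), and the total is 655.

```python
def price_breakdown(order):
    # Same three named components as the refactored price() function.
    base_price = order["quantity"] * order["item_price"]
    quantity_discount = max(0, order["quantity"] - 500) * order["item_price"] * 0.05
    shipping = min(order["quantity"] * order["item_price"] * 0.1, 100)
    return {
        "base_price": base_price,
        "quantity_discount": quantity_discount,
        "shipping": shipping,
        "total": base_price - quantity_discount + shipping,
    }


print(price_breakdown({"quantity": 600, "item_price": 1.0}))
```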
Inside a class, we can go one step further and extract base_price, quantity_discount and shipping into methods.
Before refactoring:
```python
class Order:
    def __init__(self, a_record):
        self._data = a_record

    def quantity(self):
        return self._data["quantity"]

    def item_price(self):
        return self._data["item_price"]

    def price(self):
        return self.quantity() * self.item_price() - \
            max(0, self.quantity() - 500) * self.item_price() * 0.05 + \
            min(self.quantity() * self.item_price() * 0.1, 100)


order = Order({"quantity": 20, "item_price": 3.5})
print(order.price())
```
After refactoring:
```python
class Order:
    def __init__(self, a_record):
        self._data = a_record

    def quantity(self):
        return self._data["quantity"]

    def item_price(self):
        return self._data["item_price"]

    def base_price(self):
        return self.quantity() * self.item_price()

    def quantity_discount(self):
        return max(0, self.quantity() - 500) * self.item_price() * 0.05

    def shipping(self):
        return min(self.quantity() * self.item_price() * 0.1, 100)

    def price(self):
        return self.base_price() - self.quantity_discount() + self.shipping()


order = Order({"quantity": 20, "item_price": 3.5})
print(order.price())
```
4 Inline variable
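This is the reverse of Extract Variable: sometimes a variable's name says no more than the expression it holds, and the variable just gets in the way. In that case we inline it.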
Before refactoring:
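```python
def is_expensive(a_order):  # hypothetical function name; only the body comes from the example
    base_price = a_order["base_price"]
    return base_price > 1000
```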
After refactoring:
```python
def is_expensive(a_order):  # hypothetical function name
    return a_order["base_price"] > 1000
```
5 Change function declaration
A good function name tells the reader at a glance what the function does. In practice, though, we often inherit poorly named, uncommented functions from earlier developers, and then we need to rename them.
The simple approach is to rename the function and update every call site at the same time.
Alternatively, we can migrate gradually:
Before refactoring:
```python
import math


def circum(radius):
    return 2 * math.pi * radius
```
After refactoring:
```python
import math


def circum(radius):
    # old name, kept for now; it only delegates to the new function
    return circumference(radius)


def circumference(radius):
    return 2 * math.pi * radius
```
Then change every caller of circum to call circumference instead. Once the tests pass, delete the old function.
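In Python, one common way to keep the old name working during such a migration while making leftover callers visible is to have it emit a `DeprecationWarning` through the standard `warnings` module. A minimal sketch; the warning text is illustrative.

```python
import math
import warnings


def circumference(radius):
    return 2 * math.pi * radius


def circum(radius):
    # Old name, kept only until every caller has migrated.
    warnings.warn(
        "circum() is deprecated; use circumference() instead",
        DeprecationWarning,
        stacklevel=2,  # report the warning at the caller's line, not this one
    )
    return circumference(radius)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is not filtered out
    value = circum(1.0)
print(value, caught[0].category.__name__)
```

Running the test suite with warnings turned into errors (for example `python -W error::DeprecationWarning`) then points at any caller still using the old name.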
A special case is when the function also needs a new parameter.
Before refactoring:
```python
_reservations = []


def add_reservation(customer):
    zz_add_reservation(customer)


def zz_add_reservation(customer):
    _reservations.append(customer)
```
After refactoring:
```python
_reservations = []


def add_reservation(customer):
    zz_add_reservation(customer, False)


def zz_add_reservation(customer, is_priority):
    # fail fast if a caller passes anything other than True or False
    assert isinstance(is_priority, bool)
    _reservations.append(customer)
```
Before changing the callers, it is a good habit to add an assertion that checks the new parameter, so that any caller passing it incorrectly fails immediately.
6 Encapsulate variable
Moving functions around is relatively easy; data is much more troublesome. If you move data, you also have to change every piece of code that references it. The wider the scope of the data, the harder it is to refactor, which is why global data causes so much trouble.
The best remedy is to route all access to the data through functions.
Before refactoring:
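```python
from types import SimpleNamespace

space_ship = SimpleNamespace()  # stand-in object so the example runs

default_owner = {"first_name": "Martin", "last_name": "Fowler"}
space_ship.owner = default_owner
# update data
default_owner = {"first_name": "Rebecca", "last_name": "Parsons"}
```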
After refactoring:
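```python
from types import SimpleNamespace

space_ship = SimpleNamespace()  # stand-in object so the example runs

default_owner = {"first_name": "Martin", "last_name": "Fowler"}


def get_default_owner():
    return default_owner


def set_default_owner(arg):
    # without `global`, the assignment would only create a local variable
    global default_owner
    default_owner = arg


space_ship.owner = get_default_owner()
# update data
set_default_owner({"first_name": "Rebecca", "last_name": "Parsons"})
```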
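Encapsulating the variable controls reassignment, but a caller can still mutate the dictionary it gets back and thereby change the shared data. If that matters, the accessors can hand out and store copies. A sketch using `copy.deepcopy`; this is one option among several (an immutable record type would be another), and the leading underscore on `_default_owner` is the Python convention for "module-private".

```python
import copy

_default_owner = {"first_name": "Martin", "last_name": "Fowler"}


def get_default_owner():
    # Return a copy so callers cannot change the shared data in place.
    return copy.deepcopy(_default_owner)


def set_default_owner(arg):
    global _default_owner
    # Store a copy so later changes to arg do not leak in.
    _default_owner = copy.deepcopy(arg)


owner = get_default_owner()
owner["last_name"] = "Parsons"           # only changes the caller's copy
print(get_default_owner()["last_name"])  # Fowler
```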
7 Rename variable
Good naming is at the heart of clean code. To improve readability, poor variable names left behind by earlier developers should be renamed.
If the variable is widely used, consider applying Encapsulate Variable first; then find all the code that uses the variable and change it one place at a time.
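A sketch of those steps, using a hypothetical module-level variable `tp_hd` that we want to call `title`. Once every read and write goes through the accessor functions, renaming the underlying variable only touches the accessors; the callers (here the hypothetical `render_heading`) do not change.

```python
# Step 1 (before): a cryptic global read and written directly, e.g.
#     tp_hd = "untitled"
# Step 2: route every read and write through accessor functions.
# Step 3: rename the variable behind the accessors.

_title = "untitled"  # formerly tp_hd


def get_title():
    return _title


def set_title(value):
    global _title
    _title = value


def render_heading():
    # Callers use the accessors, never the variable itself.
    return f"<h1>{get_title()}</h1>"


set_title("Refactoring")
print(render_heading())  # <h1>Refactoring</h1>
```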
8 Introduce parameter object
Some data items keep appearing together, passed from one function to the next. Grouping them into a data structure makes the relationship between them explicit and shortens the functions' parameter lists.
After refactoring, all uses of this data structure will access its elements through the same name, thereby improving code consistency.
Before refactoring:
```python
station = {
    "name": "ZB1",
    "readings": [
        {"temp": 47, "time": "2016-11-10 09:10"},
        {"temp": 53, "time": "2016-11-10 09:20"},
        {"temp": 58, "time": "2016-11-10 09:30"},
        {"temp": 53, "time": "2016-11-10 09:40"},
        {"temp": 51, "time": "2016-11-10 09:50"},
    ],
}
operating_plan = {
    "temperature_floor": 50,
    "temperature_ceiling": 54,
}


def reading_outside_range(station, min, max):
    res = []
    for info in station["readings"]:
        if info["temp"] < min or info["temp"] > max:
            res.append(info["temp"])
    return res


alerts = reading_outside_range(
    station,
    operating_plan["temperature_floor"],
    operating_plan["temperature_ceiling"],
)
print(alerts)
```
A simple way to restructure min and max is to bundle them into a class, NumberRange. The class also gives us a natural home for behavior: a contains method that reading_outside_range can use for its range check.
After refactoring:
```python
station = {
    "name": "ZB1",
    "readings": [
        {"temp": 47, "time": "2016-11-10 09:10"},
        {"temp": 53, "time": "2016-11-10 09:20"},
        {"temp": 58, "time": "2016-11-10 09:30"},
        {"temp": 53, "time": "2016-11-10 09:40"},
        {"temp": 51, "time": "2016-11-10 09:50"},
    ],
}
operating_plan = {
    "temperature_floor": 50,
    "temperature_ceiling": 54,
}


class NumberRange:
    def __init__(self, min, max):
        self._data = {"min": min, "max": max}

    def get_min(self):
        return self._data["min"]

    def get_max(self):
        return self._data["max"]

    def contains(self, value):
        # True when value lies inside the range (bounds inclusive)
        return self._data["min"] <= value <= self._data["max"]


def reading_outside_range(station, temperature_range):
    res = []
    for info in station["readings"]:
        if not temperature_range.contains(info["temp"]):
            res.append(info["temp"])
    return res


operating_range = NumberRange(
    operating_plan["temperature_floor"],
    operating_plan["temperature_ceiling"],
)
alerts = reading_outside_range(station, operating_range)
print(alerts)
```
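In modern Python, a frozen dataclass is a compact way to write such a parameter object: `@dataclass` generates `__init__`, `__eq__` and `__repr__`, and `frozen=True` makes instances immutable, which suits a value like a range. A sketch; the field names `low`/`high` are our choice (to avoid shadowing the built-ins `min`/`max`), and `__contains__` is an optional extra that enables the `in` operator.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class NumberRange:
    low: float
    high: float

    def contains(self, value):
        # bounds are inclusive
        return self.low <= value <= self.high

    def __contains__(self, value):
        # lets callers write `value in number_range`
        return self.contains(value)


def reading_outside_range(station, number_range):
    return [r["temp"] for r in station["readings"] if r["temp"] not in number_range]


station = {
    "name": "ZB1",
    "readings": [
        {"temp": 47, "time": "2016-11-10 09:10"},
        {"temp": 53, "time": "2016-11-10 09:20"},
        {"temp": 58, "time": "2016-11-10 09:30"},
        {"temp": 53, "time": "2016-11-10 09:40"},
        {"temp": 51, "time": "2016-11-10 09:50"},
    ],
}
print(reading_outside_range(station, NumberRange(50, 54)))  # [47, 58]
```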
9 Combine functions into class
Sometimes a group of functions operates on the same piece of data, which is usually passed to each of them as a parameter. We can then combine these functions into a class. The class makes the shared environment of these functions explicit, lets them work on the data without passing the same arguments around, and gives the rest of the system a single object that is easy to pass along.
Before refactoring:
```python
reading = {
    "customer": "ivan",
    "quantity": 10,
    "month": 5,
    "year": 2017,
}


def acquire_reading():
    return reading


def base_rate(month, year):
    return year / month


def tax_threshold(year):
    return year / 2


def calculate_base_charge(a_reading):
    return base_rate(a_reading["month"], a_reading["year"]) * a_reading["quantity"]


# client 1
a_reading = acquire_reading()
base_charge = base_rate(a_reading["month"], a_reading["year"]) * a_reading["quantity"]

# client 2
a_reading = acquire_reading()
base = base_rate(a_reading["month"], a_reading["year"]) * a_reading["quantity"]
taxable_charge = max(0, base - tax_threshold(a_reading["year"]))

# client 3
a_reading = acquire_reading()
basic_charge_amount = calculate_base_charge(a_reading)

print(base_charge)
print(taxable_charge)
print(basic_charge_amount)
```
After refactoring:
```python
reading = {
    "customer": "ivan",
    "quantity": 10,
    "month": 5,
    "year": 2017,
}


class Reading:
    def __init__(self, data):
        self._customer = data["customer"]
        self._quantity = data["quantity"]
        self._month = data["month"]
        self._year = data["year"]

    def get_customer(self):
        return self._customer

    def get_quantity(self):
        return self._quantity

    def get_month(self):
        return self._month

    def get_year(self):
        return self._year

    def base_rate(self):
        return self._year / self._month

    def get_calculate_base_charge(self):
        return self.base_rate() * self._quantity

    def tax_threshold(self):
        return self._year / 2


def acquire_reading():
    return reading


raw_reading = acquire_reading()
a_reading = Reading(raw_reading)

# client 1
base_charge = a_reading.get_calculate_base_charge()

# client 2
base = a_reading.get_calculate_base_charge()
taxable_charge = max(0, base - a_reading.tax_threshold())

# client 3
basic_charge_amount = a_reading.get_calculate_base_charge()

print(base_charge)
print(taxable_charge)
print(basic_charge_amount)
```
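A Python variation worth considering: the `@property` decorator makes derived values read like plain data, and moving the taxable-charge calculation into the class as well removes the last piece of duplicated logic from the clients. A sketch using a frozen dataclass for the stored fields; the rate formulas are the same placeholders as above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Reading:
    customer: str
    quantity: int
    month: int
    year: int

    @property
    def base_rate(self):
        # placeholder formula from the example above
        return self.year / self.month

    @property
    def base_charge(self):
        return self.base_rate * self.quantity

    @property
    def tax_threshold(self):
        return self.year / 2

    @property
    def taxable_charge(self):
        return max(0, self.base_charge - self.tax_threshold)


raw = {"customer": "ivan", "quantity": 10, "month": 5, "year": 2017}
a_reading = Reading(**raw)
print(a_reading.base_charge, a_reading.taxable_charge)
```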
10 Combine functions into transform
Combining functions into a transform is an alternative solution to the problem in section 9. Instead of assembling the functions into a class, we gather them into a single transform function that takes the raw data and returns the record enriched with all the derived values.
Before refactoring:
```python
reading = {
    "customer": "ivan",
    "quantity": 10,
    "month": 5,
    "year": 2017,
}


def acquire_reading():
    return reading


def base_rate(month, year):
    return year / month


def tax_threshold(year):
    return year / 2


def calculate_base_charge(a_reading):
    return base_rate(a_reading["month"], a_reading["year"]) * a_reading["quantity"]


# client 1
a_reading = acquire_reading()
base_charge = base_rate(a_reading["month"], a_reading["year"]) * a_reading["quantity"]

# client 2
a_reading = acquire_reading()
base = base_rate(a_reading["month"], a_reading["year"]) * a_reading["quantity"]
taxable_charge = max(0, base - tax_threshold(a_reading["year"]))

# client 3
a_reading = acquire_reading()
basic_charge_amount = calculate_base_charge(a_reading)

print(base_charge)
print(taxable_charge)
print(basic_charge_amount)
```
After refactoring:
```python
import copy

reading = {
    "customer": "ivan",
    "quantity": 10,
    "month": 5,
    "year": 2017,
}


def acquire_reading():
    return reading


def base_rate(month, year):
    return year / month


def tax_threshold(year):
    return year / 2


def calculate_base_charge(a_reading):
    return base_rate(a_reading["month"], a_reading["year"]) * a_reading["quantity"]


def enrich_reading(original):
    # Work on a deep copy so the caller's source record is left unchanged.
    result = copy.deepcopy(original)
    result["base_charge"] = calculate_base_charge(result)
    result["taxable_charge"] = max(0, result["base_charge"] - tax_threshold(result["year"]))
    return result


raw_reading = acquire_reading()
a_reading = enrich_reading(raw_reading)

base_charge = a_reading["base_charge"]
taxable_charge = a_reading["taxable_charge"]
basic_charge_amount = a_reading["base_charge"]

print(base_charge)
print(taxable_charge)
print(basic_charge_amount)
```
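The choice between this technique and combining functions into a class matters when the data changes after enrichment. A transform stores derived values as snapshots, so if client code later updates a source field, the derived fields go stale; a class that computes them on demand does not have this problem. A minimal sketch of the pitfall, reusing the placeholder formula from above in a hypothetical helper `base_charge`:

```python
import copy


def base_charge(record):
    # same placeholder formula as above: (year / month) * quantity
    return record["year"] / record["month"] * record["quantity"]


def enrich_reading(original):
    result = copy.deepcopy(original)
    result["base_charge"] = base_charge(result)
    return result


raw = {"customer": "ivan", "quantity": 10, "month": 5, "year": 2017}
enriched = enrich_reading(raw)

enriched["quantity"] = 20  # source field updated after enrichment
print(enriched["base_charge"] == base_charge(enriched))  # False: stale snapshot
print("base_charge" in raw)  # False: the input record was not modified
```

So a transform fits best when the enriched data is treated as read-only afterwards.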
11 Split phase
When a piece of code does several different things at once, we split it into separate phases, so that when a change is needed we can deal with each concern on its own.
Before refactoring:
```python
def price_order(product, quantity, shipping_method):
    base_price = product["base_price"] * quantity
    discount = (max(quantity - product["discount_threshold"], 0)
                * product["base_price"] * product["discount_rate"])
    shipping_per_case = (shipping_method["discounted_fee"]
                         if base_price > shipping_method["discount_threshold"]
                         else shipping_method["fee_per_case"])
    shipping_cost = quantity * shipping_per_case
    price = base_price - discount + shipping_cost
    return price
```
The first two lines of this function calculate the product-related part of the price from the product information, and the next two calculate the shipping cost from the shipping method. These are two separate phases, so we can split them, starting by moving the shipping logic into its own function:
```python
def price_order(product, quantity, shipping_method):
    base_price = product["base_price"] * quantity
    discount = (max(quantity - product["discount_threshold"], 0)
                * product["base_price"] * product["discount_rate"])
    price = apply_shipping(base_price, shipping_method, quantity, discount)
    return price


def apply_shipping(base_price, shipping_method, quantity, discount):
    shipping_per_case = (shipping_method["discounted_fee"]
                         if base_price > shipping_method["discount_threshold"]
                         else shipping_method["fee_per_case"])
    shipping_cost = quantity * shipping_per_case
    price = base_price - discount + shipping_cost
    return price
```
Next, we bundle the values that the second phase needs from the first into an intermediate data structure, price_data, and pass it instead of separate parameters:
```python
def price_order(product, quantity, shipping_method):
    base_price = product["base_price"] * quantity
    discount = (max(quantity - product["discount_threshold"], 0)
                * product["base_price"] * product["discount_rate"])
    price_data = {"base_price": base_price, "quantity": quantity, "discount": discount}
    price = apply_shipping(price_data, shipping_method)
    return price


def apply_shipping(price_data, shipping_method):
    shipping_per_case = (shipping_method["discounted_fee"]
                         if price_data["base_price"] > shipping_method["discount_threshold"]
                         else shipping_method["fee_per_case"])
    shipping_cost = price_data["quantity"] * shipping_per_case
    price = price_data["base_price"] - price_data["discount"] + shipping_cost
    return price
```
Finally, we extract the first-phase code into its own function, calculate_pricing_data:
```python
def price_order(product, quantity, shipping_method):
    price_data = calculate_pricing_data(product, quantity)
    return apply_shipping(price_data, shipping_method)


def calculate_pricing_data(product, quantity):
    base_price = product["base_price"] * quantity
    discount = (max(quantity - product["discount_threshold"], 0)
                * product["base_price"] * product["discount_rate"])
    return {"base_price": base_price, "quantity": quantity, "discount": discount}


def apply_shipping(price_data, shipping_method):
    shipping_per_case = (shipping_method["discounted_fee"]
                         if price_data["base_price"] > shipping_method["discount_threshold"]
                         else shipping_method["fee_per_case"])
    shipping_cost = price_data["quantity"] * shipping_per_case
    return price_data["base_price"] - price_data["discount"] + shipping_cost
```
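As with every refactoring, the split should not change results. A quick check, sketched with made-up sample data (the `product` and `shipping_method` values below are hypothetical), runs the original single-phase function and the final two-phase version side by side. With these values, 3 items cost 36 (no discount, full shipping fee of 2 per case) and 8 items cost 85 (base 80, discount 3, discounted shipping of 1 per case).

```python
def price_order_original(product, quantity, shipping_method):
    # the single-phase version from the start of this section
    base_price = product["base_price"] * quantity
    discount = (max(quantity - product["discount_threshold"], 0)
                * product["base_price"] * product["discount_rate"])
    shipping_per_case = (shipping_method["discounted_fee"]
                         if base_price > shipping_method["discount_threshold"]
                         else shipping_method["fee_per_case"])
    shipping_cost = quantity * shipping_per_case
    return base_price - discount + shipping_cost


def calculate_pricing_data(product, quantity):
    # phase 1: product-related pricing
    base_price = product["base_price"] * quantity
    discount = (max(quantity - product["discount_threshold"], 0)
                * product["base_price"] * product["discount_rate"])
    return {"base_price": base_price, "quantity": quantity, "discount": discount}


def apply_shipping(price_data, shipping_method):
    # phase 2: shipping, using only the intermediate price_data
    shipping_per_case = (shipping_method["discounted_fee"]
                         if price_data["base_price"] > shipping_method["discount_threshold"]
                         else shipping_method["fee_per_case"])
    shipping_cost = price_data["quantity"] * shipping_per_case
    return price_data["base_price"] - price_data["discount"] + shipping_cost


def price_order(product, quantity, shipping_method):
    return apply_shipping(calculate_pricing_data(product, quantity), shipping_method)


# Hypothetical sample data, made up for this check.
product = {"base_price": 10, "discount_threshold": 5, "discount_rate": 0.1}
shipping_method = {"discount_threshold": 50, "discounted_fee": 1, "fee_per_case": 2}

for quantity in (3, 8):
    old = price_order_original(product, quantity, shipping_method)
    new = price_order(product, quantity, shipping_method)
    assert old == new, (quantity, old, new)
    print(quantity, new)
```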