refactor: move reposting logic to common controller
diff --git a/erpnext/accounts/doctype/sales_invoice/sales_invoice.py b/erpnext/accounts/doctype/sales_invoice/sales_invoice.py
index 7bdb2b4..87f4bd0 100644
--- a/erpnext/accounts/doctype/sales_invoice/sales_invoice.py
+++ b/erpnext/accounts/doctype/sales_invoice/sales_invoice.py
@@ -11,9 +11,6 @@
import erpnext
from erpnext.accounts.deferred_revenue import validate_service_stop_date
-from erpnext.accounts.doctype.accounting_dimension.accounting_dimension import (
- get_accounting_dimensions,
-)
from erpnext.accounts.doctype.loyalty_program.loyalty_program import (
get_loyalty_program_details_with_points,
validate_loyalty_points,
@@ -517,79 +514,34 @@
def on_update_after_submit(self):
if hasattr(self, "repost_required"):
- needs_repost = 0
-
- # Check if any field affecting accounting entry is altered
- doc_before_update = self.get_doc_before_save()
- accounting_dimensions = get_accounting_dimensions() + ["cost_center", "project"]
-
- # Check if opening entry check updated
- if doc_before_update.get("is_opening") != self.is_opening:
- needs_repost = 1
-
- if not needs_repost:
- # Parent Level Accounts excluding party account
- for field in (
- "additional_discount_account",
- "cash_bank_account",
- "account_for_change_amount",
- "write_off_account",
- "loyalty_redemption_account",
- "unrealized_profit_loss_account",
- ):
- if doc_before_update.get(field) != self.get(field):
- needs_repost = 1
- break
-
- # Check for parent accounting dimensions
- for dimension in accounting_dimensions:
- if doc_before_update.get(dimension) != self.get(dimension):
- needs_repost = 1
- break
-
- # Check for child tables
- if self.check_if_child_table_updated(
- "items",
- doc_before_update,
- ("income_account", "expense_account", "discount_account"),
- accounting_dimensions,
- ):
- needs_repost = 1
-
- if self.check_if_child_table_updated(
- "taxes", doc_before_update, ("account_head",), accounting_dimensions
- ):
- needs_repost = 1
-
+ fields_to_check = [
+ "additional_discount_account",
+ "cash_bank_account",
+ "account_for_change_amount",
+ "write_off_account",
+ "loyalty_redemption_account",
+ "unrealized_profit_loss_account",
+ ]
+ child_tables = {
+ "items": ("income_account", "expense_account", "discount_account"),
+ "taxes": ("account_head",),
+ }
+ self.needs_repost = self.check_if_fields_updated(fields_to_check, child_tables)
+ self.validate_deferred_accounting_before_repost()
self.validate_accounts()
+ self.db_set("repost_required", self.needs_repost)
- # validate if deferred revenue is enabled for any item
- # Don't allow to update the invoice if deferred revenue is enabled
- if needs_repost:
- for item in self.get("items"):
- if item.enable_deferred_revenue:
- frappe.throw(
- _(
- "Deferred Revenue is enabled for item {0}. You cannot update the invoice after submission."
- ).format(item.item_code)
- )
-
- self.db_set("repost_required", needs_repost)
-
- def check_if_child_table_updated(
- self, child_table, doc_before_update, fields_to_check, accounting_dimensions
- ):
- # Check if any field affecting accounting entry is altered
- for index, item in enumerate(self.get(child_table)):
- for field in fields_to_check:
- if doc_before_update.get(child_table)[index].get(field) != item.get(field):
- return True
-
- for dimension in accounting_dimensions:
- if doc_before_update.get(child_table)[index].get(dimension) != item.get(dimension):
- return True
-
- return False
+ def validate_deferred_accounting_before_repost(self):
+ # validate if deferred revenue is enabled for any item
+ # Don't allow to update the invoice if deferred revenue is enabled
+ if self.needs_repost:
+ for item in self.get("items"):
+ if item.enable_deferred_revenue:
+ frappe.throw(
+ _(
+ "Deferred Revenue is enabled for item {0}. You cannot update the invoice after submission."
+ ).format(item.item_code)
+ )
@frappe.whitelist()
def repost_accounting_entries(self):
diff --git a/erpnext/controllers/accounts_controller.py b/erpnext/controllers/accounts_controller.py
index e635aa7..0897864 100644
--- a/erpnext/controllers/accounts_controller.py
+++ b/erpnext/controllers/accounts_controller.py
@@ -2186,6 +2186,44 @@
_("Select finance book for the item {0} at row {1}").format(item.item_code, item.idx)
)
+ def check_if_fields_updated(self, fields_to_check, child_tables):
+ # Check if any field affecting accounting entry is altered
+ doc_before_update = self.get_doc_before_save()
+ accounting_dimensions = get_accounting_dimensions() + ["cost_center", "project"]
+
+ # Check if opening entry check updated
+ needs_repost = doc_before_update.get("is_opening") != self.is_opening
+
+ if not needs_repost:
+ # Parent Level Accounts excluding party account
+ fields_to_check += accounting_dimensions
+ for field in fields_to_check:
+ if doc_before_update.get(field) != self.get(field):
+ needs_repost = 1
+ break
+
+ if not needs_repost:
+ # Check for child tables
+ for table in child_tables:
+ needs_repost = check_if_child_table_updated(
+ doc_before_update.get(table), self.get(table), child_tables[table]
+ )
+ if needs_repost:
+ break
+
+ return needs_repost
+
+ @frappe.whitelist()
+ def repost_accounting_entries(self):
+ if self.repost_required:
+ self.docstatus = 2
+ self.make_gl_entries_on_cancel()
+ self.docstatus = 1
+ self.make_gl_entries()
+ self.db_set("repost_required", 0)
+ else:
+ frappe.throw(_("No updates pending for reposting"))
+
@frappe.whitelist()
def get_tax_rate(account_head):
@@ -3191,6 +3229,23 @@
parent.create_stock_reservation_entries()
+def check_if_child_table_updated(
+ child_table_before_update, child_table_after_update, fields_to_check
+):
+ accounting_dimensions = get_accounting_dimensions() + ["cost_center", "project"]
+ # Check if any field affecting accounting entry is altered
+ for index, item in enumerate(child_table_after_update):
+ for field in fields_to_check:
+ if child_table_before_update[index].get(field) != item.get(field):
+ return True
+
+ for dimension in accounting_dimensions:
+ if child_table_before_update[index].get(dimension) != item.get(dimension):
+ return True
+
+ return False
+
+
@erpnext.allow_regional
def validate_regional(doc):
pass