Merge pull request #30677 from ankush/refactor/stock_balance
refactor: stock balance report
diff --git a/erpnext/stock/report/stock_balance/stock_balance.py b/erpnext/stock/report/stock_balance/stock_balance.py
index afbc6fe..6369f91 100644
--- a/erpnext/stock/report/stock_balance/stock_balance.py
+++ b/erpnext/stock/report/stock_balance/stock_balance.py
@@ -3,24 +3,41 @@
from operator import itemgetter
+from typing import Any, Dict, List, Optional, TypedDict
import frappe
from frappe import _
+from frappe.query_builder.functions import CombineDatetime
from frappe.utils import cint, date_diff, flt, getdate
+from frappe.utils.nestedset import get_descendants_of
+from pypika.terms import ExistsCriterion
import erpnext
from erpnext.stock.report.stock_ageing.stock_ageing import FIFOSlots, get_average_age
-from erpnext.stock.report.stock_ledger.stock_ledger import get_item_group_condition
from erpnext.stock.utils import add_additional_uom_columns, is_reposting_item_valuation_in_progress
-def execute(filters=None):
class StockBalanceFilter(TypedDict):
    """Filters read by the Stock Balance report functions via ``filters.get``."""

    company: Optional[str]
    from_date: str
    to_date: str
    item_group: Optional[str]
    # The report code and tests read ``item_code`` (not ``item``).
    item_code: Optional[str]
    # Read by get_items; used by the stock analytics report.
    brand: Optional[str]
    warehouse: Optional[str]
    warehouse_type: Optional[str]
    include_uom: Optional[str]  # include extra info in converted UOM
    show_stock_ageing_data: bool
    show_variant_attributes: bool


# One Stock Ledger Entry row as returned by ``query.run(as_dict=True)``.
SLEntry = Dict[str, Any]
+
+
+def execute(filters: Optional[StockBalanceFilter] = None):
is_reposting_item_valuation_in_progress()
if not filters:
filters = {}
- to_date = filters.get("to_date")
-
if filters.get("company"):
company_currency = erpnext.get_company_currency(filters.get("company"))
else:
@@ -48,6 +65,7 @@
_func = itemgetter(1)
+ to_date = filters.get("to_date")
for (company, item, warehouse) in sorted(iwb_map):
if item_map.get(item):
qty_dict = iwb_map[(company, item, warehouse)]
@@ -92,7 +110,7 @@
return columns, data
-def get_columns(filters):
+def get_columns(filters: StockBalanceFilter):
"""return columns"""
columns = [
{
@@ -215,66 +233,77 @@
return columns
def apply_conditions(query, filters: StockBalanceFilter):
    """Narrow a Stock Ledger Entry query according to the report filters.

    ``from_date`` is only validated here, not applied: the report still needs
    earlier entries to compute opening balances. ``to_date`` caps the posting
    date and ``company`` restricts the company. Warehouses are restricted by
    ``warehouse`` (that warehouse and every warehouse nested under it) or,
    only when no warehouse is given, by ``warehouse_type``.

    Args:
        query: frappe query-builder query selecting from `tabStock Ledger Entry`.
        filters: the report filters.

    Returns:
        The query with the additional ``where``/``join`` clauses applied.

    Raises:
        Via ``frappe.throw``: when ``from_date`` or ``to_date`` is missing, or
        when the selected warehouse does not exist.
    """
    sle = frappe.qb.DocType("Stock Ledger Entry")
    warehouse_table = frappe.qb.DocType("Warehouse")

    if not filters.get("from_date"):
        frappe.throw(_("'From Date' is required"))

    if to_date := filters.get("to_date"):
        query = query.where(sle.posting_date <= to_date)
    else:
        frappe.throw(_("'To Date' is required"))

    if company := filters.get("company"):
        query = query.where(sle.company == company)

    if warehouse := filters.get("warehouse"):
        # get_value returns None for an unknown warehouse; unpacking that
        # directly would fail with an opaque TypeError.
        bounds = frappe.db.get_value("Warehouse", warehouse, ["lft", "rgt"])
        if not bounds:
            frappe.throw(_("Warehouse {0} does not exist").format(warehouse))
        lft, rgt = bounds

        # Nested-set bounds: any warehouse whose (lft, rgt) range lies inside
        # the selected one's is that warehouse itself or one of its descendants.
        children_subquery = (
            frappe.qb.from_(warehouse_table)
            .select(warehouse_table.name)
            .where(
                (warehouse_table.lft >= lft)
                & (warehouse_table.rgt <= rgt)
                & (warehouse_table.name == sle.warehouse)
            )
        )
        query = query.where(ExistsCriterion(children_subquery))
    elif warehouse_type := filters.get("warehouse_type"):
        query = (
            query.join(warehouse_table)
            .on(warehouse_table.name == sle.warehouse)
            .where(warehouse_table.warehouse_type == warehouse_type)
        )

    return query
def get_stock_ledger_entries(filters: StockBalanceFilter, items: List[str]) -> List[SLEntry]:
    """Fetch the non-cancelled stock ledger entries the report aggregates.

    Rows are ordered by posting date+time, then creation, then actual_qty.
    A non-empty ``items`` list restricts the result to those item codes; every
    other report filter is applied by ``apply_conditions``.
    """
    ledger = frappe.qb.DocType("Stock Ledger Entry")

    selected_fields = (
        ledger.item_code,
        ledger.warehouse,
        ledger.posting_date,
        ledger.actual_qty,
        ledger.valuation_rate,
        ledger.company,
        ledger.voucher_type,
        ledger.qty_after_transaction,
        ledger.stock_value_difference,
        ledger.item_code.as_("name"),
        ledger.voucher_no,
        ledger.stock_value,
        ledger.batch_no,
    )
    not_cancelled = (ledger.docstatus < 2) & (ledger.is_cancelled == 0)

    ledger_query = (
        frappe.qb.from_(ledger)
        .select(*selected_fields)
        .where(not_cancelled)
        .orderby(CombineDatetime(ledger.posting_date, ledger.posting_time))
        .orderby(ledger.creation)
        .orderby(ledger.actual_qty)
    )

    if items:
        ledger_query = ledger_query.where(ledger.item_code.isin(items))

    return apply_conditions(ledger_query, filters).run(as_dict=True)
+
+
+def get_item_warehouse_map(filters: StockBalanceFilter, sle: List[SLEntry]):
iwb_map = {}
from_date = getdate(filters.get("from_date"))
to_date = getdate(filters.get("to_date"))
@@ -332,7 +361,7 @@
return iwb_map
-def filter_items_with_no_transactions(iwb_map, float_precision):
+def filter_items_with_no_transactions(iwb_map, float_precision: float):
for (company, item, warehouse) in sorted(iwb_map):
qty_dict = iwb_map[(company, item, warehouse)]
@@ -349,26 +378,22 @@
return iwb_map
def get_items(filters: StockBalanceFilter) -> List[str]:
    """Resolve the item-level filters into an explicit list of item codes.

    An explicit ``item_code`` wins outright. Otherwise ``item_group`` (the
    group plus all of its descendant groups) and ``brand`` (used by the stock
    analytics report) are combined into an Item query.

    Returns an empty list when none of these filters is set. Callers treat an
    empty list as "no item restriction": get_stock_ledger_entries only adds an
    ``item_code IN (...)`` clause for a non-empty list. An unfiltered
    ``get_all`` here would instead scan the whole Item table and push an IN
    clause listing the entire catalogue into every downstream query.
    """
    if item_code := filters.get("item_code"):
        return [item_code]

    item_filters = {}
    if item_group := filters.get("item_group"):
        children = get_descendants_of("Item Group", item_group, ignore_permissions=True)
        item_filters["item_group"] = ("in", children + [item_group])
    if brand := filters.get("brand"):
        item_filters["brand"] = brand

    if not item_filters:
        # No item-level filter: let callers fall back to the items that
        # actually appear in the ledger entries.
        return []

    # NOTE(review): when the filters match no Item this also returns [], and
    # callers then apply no item restriction at all -- confirm that is intended.
    return frappe.get_all("Item", filters=item_filters, pluck="name", order_by=None)
-def get_item_details(items, sle, filters):
+def get_item_details(items: List[str], sle: List[SLEntry], filters: StockBalanceFilter):
item_details = {}
if not items:
items = list(set(d.item_code for d in sle))
@@ -376,33 +401,35 @@
if not items:
return item_details
- cf_field = cf_join = ""
- if filters.get("include_uom"):
- cf_field = ", ucd.conversion_factor"
- cf_join = (
- "left join `tabUOM Conversion Detail` ucd on ucd.parent=item.name and ucd.uom=%s"
- % frappe.db.escape(filters.get("include_uom"))
- )
+ item_table = frappe.qb.DocType("Item")
- res = frappe.db.sql(
- """
- select
- item.name, item.item_name, item.description, item.item_group, item.brand, item.stock_uom %s
- from
- `tabItem` item
- %s
- where
- item.name in (%s)
- """
- % (cf_field, cf_join, ",".join(["%s"] * len(items))),
- items,
- as_dict=1,
+ query = (
+ frappe.qb.from_(item_table)
+ .select(
+ item_table.name,
+ item_table.item_name,
+ item_table.description,
+ item_table.item_group,
+ item_table.brand,
+ item_table.stock_uom,
+ )
+ .where(item_table.name.isin(items))
)
- for item in res:
- item_details.setdefault(item.name, item)
+ if uom := filters.get("include_uom"):
+ uom_conv_detail = frappe.qb.DocType("UOM Conversion Detail")
+ query = (
+ query.left_join(uom_conv_detail)
+ .on((uom_conv_detail.parent == item_table.name) & (uom_conv_detail.uom == uom))
+ .select(uom_conv_detail.conversion_factor)
+ )
- if filters.get("show_variant_attributes", 0) == 1:
+ result = query.run(as_dict=1)
+
+ for item_table in result:
+ item_details.setdefault(item_table.name, item_table)
+
+ if filters.get("show_variant_attributes"):
variant_values = get_variant_values_for(list(item_details))
item_details = {k: v.update(variant_values.get(k, {})) for k, v in item_details.items()}
@@ -413,36 +440,33 @@
item_reorder_details = frappe._dict()
if items:
- item_reorder_details = frappe.db.sql(
- """
- select parent, warehouse, warehouse_reorder_qty, warehouse_reorder_level
- from `tabItem Reorder`
- where parent in ({0})
- """.format(
- ", ".join(frappe.db.escape(i, percent=False) for i in items)
- ),
- as_dict=1,
+ item_reorder_details = frappe.get_all(
+ "Item Reorder",
+ ["parent", "warehouse", "warehouse_reorder_qty", "warehouse_reorder_level"],
+ filters={"parent": ("in", items)},
)
return dict((d.parent + d.warehouse, d) for d in item_reorder_details)
def get_variants_attributes() -> List[str]:
    """List the names of every Item Attribute record."""
    attribute_names = frappe.get_all("Item Attribute", pluck="name")
    return attribute_names
def get_variant_values_for(items):
"""Returns variant values for items."""
attribute_map = {}
- for attr in frappe.db.sql(
- """select parent, attribute, attribute_value
- from `tabItem Variant Attribute` where parent in (%s)
- """
- % ", ".join(["%s"] * len(items)),
- tuple(items),
- as_dict=1,
- ):
+
+ attribute_info = frappe.get_all(
+ "Item Variant Attribute",
+ ["parent", "attribute", "attribute_value"],
+ {
+ "parent": ("in", items),
+ },
+ )
+
+ for attr in attribute_info:
attribute_map.setdefault(attr["parent"], {})
attribute_map[attr["parent"]].update({attr["attribute"]: attr["attribute_value"]})
diff --git a/erpnext/stock/report/stock_balance/test_stock_balance.py b/erpnext/stock/report/stock_balance/test_stock_balance.py
new file mode 100644
index 0000000..e963de2
--- /dev/null
+++ b/erpnext/stock/report/stock_balance/test_stock_balance.py
@@ -0,0 +1,174 @@
+from typing import Any, Dict
+
+import frappe
+from frappe import _dict
+from frappe.tests.utils import FrappeTestCase
+from frappe.utils import today
+
+from erpnext.stock.doctype.item.test_item import make_item
+from erpnext.stock.doctype.stock_entry.stock_entry_utils import make_stock_entry
+from erpnext.stock.report.stock_balance.stock_balance import execute
+
+
def stock_balance(filters):
    """Run the Stock Balance report and wrap each data row in a ``_dict``."""
    report_rows = execute(filters)[1]
    return list(map(_dict, report_rows))
+
+
class TestStockBalance(FrappeTestCase):
    """Integration tests for the Stock Balance report.

    Each test creates a fresh item, posts stock entries for it and checks the
    report rows directly and through ``assertInvariants``. Database changes
    are rolled back in ``tearDown``.
    """

    # ----------- utils

    def setUp(self):
        self.item = make_item()
        self.filters = _dict(
            {
                "company": "_Test Company",
                "item_code": self.item.name,
                "from_date": "2020-01-01",
                "to_date": str(today()),
            }
        )

    def tearDown(self):
        frappe.db.rollback()

    def assertPartialDictEq(self, expected: Dict[str, Any], actual: Dict[str, Any]):
        """Assert every key of ``expected`` maps to the same value in ``actual``."""
        for k, v in expected.items():
            self.assertEqual(v, actual[k], msg=f"{expected=}\n{actual=}")

    def generate_stock_ledger(self, item_code: str, movements):
        """Post one stock entry per movement, defaulting ``to_warehouse``."""
        for movement in map(_dict, movements):
            if "to_warehouse" not in movement:
                movement.to_warehouse = "_Test Warehouse - _TC"
            make_stock_entry(item_code=item_code, **movement)

    def assertInvariants(self, rows):
        """Check each report row is self-consistent and matches the ledger.

        A row must satisfy balance = opening + in - out (for qty and value).
        It must also agree with the latest non-cancelled Stock Ledger Entry
        for its (item, warehouse), and its valuation rate must equal
        value / qty when qty is non-zero.
        """
        last_balance = frappe.db.sql(
            """
            WITH last_balances AS (
                SELECT item_code, warehouse,
                    stock_value, qty_after_transaction,
                    ROW_NUMBER() OVER (PARTITION BY item_code, warehouse
                        ORDER BY timestamp(posting_date, posting_time) desc, creation desc)
                        AS rn
                FROM `tabStock Ledger Entry`
                where is_cancelled=0
            )
            SELECT * FROM last_balances WHERE rn = 1""",
            as_dict=True,
        )

        item_wh_stock = _dict()
        for line in last_balance:
            item_wh_stock.setdefault((line.item_code, line.warehouse), line)

        for row in rows:
            msg = f"Invariants not met for {row=}"
            # msg must be passed by keyword: the third positional parameter of
            # assertAlmostEqual is ``places``, and a string there turns a
            # failing check into a TypeError from round().

            # qty invariant
            self.assertAlmostEqual(row.bal_qty, row.opening_qty + row.in_qty - row.out_qty, msg=msg)

            # value invariant
            self.assertAlmostEqual(row.bal_val, row.opening_val + row.in_val - row.out_val, msg=msg)

            # check against SLE
            last_sle = item_wh_stock[(row.item_code, row.warehouse)]
            self.assertAlmostEqual(row.bal_qty, last_sle.qty_after_transaction, 3, msg=msg)
            self.assertAlmostEqual(row.bal_val, last_sle.stock_value, 3, msg=msg)

            # valuation rate
            if not row.bal_qty:
                continue
            self.assertAlmostEqual(row.val_rate, row.bal_val / row.bal_qty, 3, msg=msg)

    # ----------- tests

    def test_basic_stock_balance(self):
        """Check very basic functionality and item info"""
        rows = stock_balance(self.filters)
        self.assertEqual(rows, [])

        self.generate_stock_ledger(self.item.name, [_dict(qty=5, rate=10)])

        # check item info
        rows = stock_balance(self.filters)
        self.assertPartialDictEq(
            {
                "item_code": self.item.name,
                "item_name": self.item.item_name,
                "item_group": self.item.item_group,
                "stock_uom": self.item.stock_uom,
                "in_qty": 5,
                "in_val": 50,
                "val_rate": 10,
            },
            rows[0],
        )
        self.assertInvariants(rows)

    def test_opening_balance(self):
        self.generate_stock_ledger(
            self.item.name,
            [
                _dict(qty=1, rate=1, posting_date="2021-01-01"),
                _dict(qty=2, rate=2, posting_date="2021-01-02"),
                _dict(qty=3, rate=3, posting_date="2021-01-03"),
            ],
        )
        rows = stock_balance(self.filters)
        self.assertInvariants(rows)

        rows = stock_balance(self.filters.update({"from_date": "2021-01-02"}))
        self.assertInvariants(rows)
        self.assertPartialDictEq({"opening_qty": 1, "in_qty": 5}, rows[0])

        rows = stock_balance(self.filters.update({"from_date": "2022-01-01"}))
        self.assertInvariants(rows)
        self.assertPartialDictEq({"opening_qty": 6, "in_qty": 0}, rows[0])

    def test_uom_converted_info(self):
        self.item.append("uoms", {"conversion_factor": 5, "uom": "Box"})
        self.item.save()

        self.generate_stock_ledger(self.item.name, [_dict(qty=5, rate=10)])

        rows = stock_balance(self.filters.update({"include_uom": "Box"}))
        self.assertEqual(rows[0].bal_qty_alt, 1)
        self.assertInvariants(rows)

    def test_item_group(self):
        # Post stock first; otherwise the report can return no rows and the
        # all(...) check below passes vacuously.
        self.generate_stock_ledger(self.item.name, [_dict(qty=5, rate=10)])

        self.filters.pop("item_code", None)
        rows = stock_balance(self.filters.update({"item_group": self.item.item_group}))
        self.assertTrue(any(r.item_code == self.item.name for r in rows), msg=f"{rows=}")
        self.assertTrue(all(r.item_group == self.item.item_group for r in rows))

    def test_child_warehouse_balances(self):
        # "Stores - _TC" is expected to sit under "All Warehouses - _TC" in the
        # test company's warehouse tree.
        self.generate_stock_ledger(self.item.name, [_dict(qty=5, rate=10, to_warehouse="Stores - _TC")])

        self.filters.pop("item_code", None)
        rows = stock_balance(self.filters.update({"warehouse": "All Warehouses - _TC"}))

        self.assertTrue(
            any(r.item_code == self.item.name and r.warehouse == "Stores - _TC" for r in rows),
            msg=f"Expected child warehouse balances \n{rows}",
        )

    def test_show_item_attr(self):
        from erpnext.controllers.item_variant import create_variant

        self.item.has_variants = True
        self.item.append("attributes", {"attribute": "Test Size"})
        self.item.save()

        attributes = {"Test Size": "Large"}
        variant = create_variant(self.item.name, attributes)
        variant.save()

        self.generate_stock_ledger(variant.name, [_dict(qty=5, rate=10)])
        rows = stock_balance(
            self.filters.update({"show_variant_attributes": 1, "item_code": variant.name})
        )
        self.assertPartialDictEq(attributes, rows[0])
        self.assertInvariants(rows)