Merge pull request #30677 from ankush/refactor/stock_balance

refactor: stock balance report
diff --git a/erpnext/stock/report/stock_balance/stock_balance.py b/erpnext/stock/report/stock_balance/stock_balance.py
index afbc6fe..6369f91 100644
--- a/erpnext/stock/report/stock_balance/stock_balance.py
+++ b/erpnext/stock/report/stock_balance/stock_balance.py
@@ -3,24 +3,43 @@
 
 
 from operator import itemgetter
+from typing import Any, Dict, List, Optional, TypedDict
 
 import frappe
 from frappe import _
+from frappe.query_builder.functions import CombineDatetime
 from frappe.utils import cint, date_diff, flt, getdate
+from frappe.utils.nestedset import get_descendants_of
+from pypika.terms import ExistsCriterion
 
 import erpnext
 from erpnext.stock.report.stock_ageing.stock_ageing import FIFOSlots, get_average_age
-from erpnext.stock.report.stock_ledger.stock_ledger import get_item_group_condition
 from erpnext.stock.utils import add_additional_uom_columns, is_reposting_item_valuation_in_progress
 
 
-def execute(filters=None):
+class StockBalanceFilter(TypedDict):
+	"""Filter values read by the stock balance report."""
+	company: Optional[str]
+	from_date: str
+	to_date: str
+	item_group: Optional[str]
+	item_code: Optional[str]
+	brand: Optional[str]  # read by get_items; used in stock analytics report
+	warehouse: Optional[str]
+	warehouse_type: Optional[str]
+	include_uom: Optional[str]  # include extra info in converted UOM
+	show_stock_ageing_data: bool
+	show_variant_attributes: bool
+
+
+SLEntry = Dict[str, Any]  # one Stock Ledger Entry row, as returned by query.run(as_dict=True)
+
+
+def execute(filters: Optional[StockBalanceFilter] = None):
 	is_reposting_item_valuation_in_progress()
 	if not filters:
 		filters = {}
 
-	to_date = filters.get("to_date")
-
 	if filters.get("company"):
 		company_currency = erpnext.get_company_currency(filters.get("company"))
 	else:
@@ -48,6 +65,7 @@
 
 	_func = itemgetter(1)
 
+	to_date = filters.get("to_date")
 	for (company, item, warehouse) in sorted(iwb_map):
 		if item_map.get(item):
 			qty_dict = iwb_map[(company, item, warehouse)]
@@ -92,7 +110,7 @@
 	return columns, data
 
 
-def get_columns(filters):
+def get_columns(filters: StockBalanceFilter):
 	"""return columns"""
 	columns = [
 		{
@@ -215,66 +233,88 @@
 	return columns
 
 
-def get_conditions(filters):
-	conditions = ""
+def apply_conditions(query, filters: StockBalanceFilter):
+	"""Apply the report's date, company and warehouse filters to a Stock Ledger Entry query.
+
+	'From Date' is only validated here; it does not restrict the query. A warehouse
+	filter also matches that warehouse's descendants and takes precedence over the
+	warehouse type filter. Throws if either date is missing or the warehouse is unknown.
+	"""
+	sle = frappe.qb.DocType("Stock Ledger Entry")
+	warehouse_table = frappe.qb.DocType("Warehouse")
+
 	if not filters.get("from_date"):
 		frappe.throw(_("'From Date' is required"))
 
-	if filters.get("to_date"):
-		conditions += " and sle.posting_date <= %s" % frappe.db.escape(filters.get("to_date"))
+	if to_date := filters.get("to_date"):
+		query = query.where(sle.posting_date <= to_date)
 	else:
 		frappe.throw(_("'To Date' is required"))
 
-	if filters.get("company"):
-		conditions += " and sle.company = %s" % frappe.db.escape(filters.get("company"))
+	if company := filters.get("company"):
+		query = query.where(sle.company == company)
 
-	if filters.get("warehouse"):
-		warehouse_details = frappe.db.get_value(
-			"Warehouse", filters.get("warehouse"), ["lft", "rgt"], as_dict=1
-		)
-		if warehouse_details:
-			conditions += (
-				" and exists (select name from `tabWarehouse` wh \
-				where wh.lft >= %s and wh.rgt <= %s and sle.warehouse = wh.name)"
-				% (warehouse_details.lft, warehouse_details.rgt)
+	if warehouse := filters.get("warehouse"):
+		warehouse_bounds = frappe.db.get_value("Warehouse", warehouse, ["lft", "rgt"])
+		if not warehouse_bounds:
+			# No such warehouse: raise a clear error rather than a TypeError from unpacking.
+			frappe.throw(_("Warehouse {0} does not exist").format(warehouse))
+		lft, rgt = warehouse_bounds
+		# nested-set range: the selected warehouse plus all of its descendants
+		children_subquery = (
+			frappe.qb.from_(warehouse_table)
+			.select(warehouse_table.name)
+			.where(
+				(warehouse_table.lft >= lft)
+				& (warehouse_table.rgt <= rgt)
+				& (warehouse_table.name == sle.warehouse)
 			)
-
-	if filters.get("warehouse_type") and not filters.get("warehouse"):
-		conditions += (
-			" and exists (select name from `tabWarehouse` wh \
-			where wh.warehouse_type = '%s' and sle.warehouse = wh.name)"
-			% (filters.get("warehouse_type"))
+		)
+		query = query.where(ExistsCriterion(children_subquery))
+	elif warehouse_type := filters.get("warehouse_type"):
+		query = (
+			query.join(warehouse_table)
+			.on(warehouse_table.name == sle.warehouse)
+			.where(warehouse_table.warehouse_type == warehouse_type)
 		)
 
-	return conditions
+	return query
 
 
-def get_stock_ledger_entries(filters, items):
-	item_conditions_sql = ""
-	if items:
-		item_conditions_sql = " and sle.item_code in ({})".format(
-			", ".join(frappe.db.escape(i, percent=False) for i in items)
+def get_stock_ledger_entries(filters: StockBalanceFilter, items: List[str]) -> List[SLEntry]:
+	sle = frappe.qb.DocType("Stock Ledger Entry")
+
+	query = (
+		frappe.qb.from_(sle)
+		.select(
+			sle.item_code,
+			sle.warehouse,
+			sle.posting_date,
+			sle.actual_qty,
+			sle.valuation_rate,
+			sle.company,
+			sle.voucher_type,
+			sle.qty_after_transaction,
+			sle.stock_value_difference,
+			sle.item_code.as_("name"),  # item code duplicated under the "name" key
+			sle.voucher_no,
+			sle.stock_value,
+			sle.batch_no,
 		)
-
-	conditions = get_conditions(filters)
-
-	return frappe.db.sql(
-		"""
-		select
-			sle.item_code, warehouse, sle.posting_date, sle.actual_qty, sle.valuation_rate,
-			sle.company, sle.voucher_type, sle.qty_after_transaction, sle.stock_value_difference,
-			sle.item_code as name, sle.voucher_no, sle.stock_value, sle.batch_no
-		from
-			`tabStock Ledger Entry` sle
-		where sle.docstatus < 2 %s %s
-		and is_cancelled = 0
-		order by sle.posting_date, sle.posting_time, sle.creation, sle.actual_qty"""
-		% (item_conditions_sql, conditions),  # nosec
-		as_dict=1,
+		.where((sle.docstatus < 2) & (sle.is_cancelled == 0))  # skip cancelled entries
+		.orderby(CombineDatetime(sle.posting_date, sle.posting_time))  # chronological order
+		.orderby(sle.creation)  # tie-breakers for identical posting timestamps
+		.orderby(sle.actual_qty)
 	)
 
+	if items:  # an empty list means no item restriction
+		query = query.where(sle.item_code.isin(items))
 
-def get_item_warehouse_map(filters, sle):
+	query = apply_conditions(query, filters)  # required dates, company and warehouse filters
+	return query.run(as_dict=True)
+
+
+def get_item_warehouse_map(filters: StockBalanceFilter, sle: List[SLEntry]):
 	iwb_map = {}
 	from_date = getdate(filters.get("from_date"))
 	to_date = getdate(filters.get("to_date"))
@@ -332,7 +361,7 @@
 	return iwb_map
 
 
-def filter_items_with_no_transactions(iwb_map, float_precision):
+def filter_items_with_no_transactions(iwb_map, float_precision: float):
 	for (company, item, warehouse) in sorted(iwb_map):
 		qty_dict = iwb_map[(company, item, warehouse)]
 
@@ -349,26 +378,22 @@
 	return iwb_map
 
 
-def get_items(filters):
+def get_items(filters: StockBalanceFilter) -> List[str]:
 	"Get items based on item code, item group or brand."
-	conditions = []
-	if filters.get("item_code"):
-		conditions.append("item.name=%(item_code)s")
+	if item_code := filters.get("item_code"):
+		return [item_code]
 	else:
-		if filters.get("item_group"):
-			conditions.append(get_item_group_condition(filters.get("item_group")))
-		if filters.get("brand"):  # used in stock analytics report
-			conditions.append("item.brand=%(brand)s")
+		item_filters = {}  # stays empty -> return [], which callers treat as "all items"
+		if item_group := filters.get("item_group"):
+			children = get_descendants_of("Item Group", item_group, ignore_permissions=True)
+			item_filters["item_group"] = ("in", children + [item_group])
+		if brand := filters.get("brand"):  # used in stock analytics report
+			item_filters["brand"] = brand
 
-	items = []
-	if conditions:
-		items = frappe.db.sql_list(
-			"""select name from `tabItem` item where {}""".format(" and ".join(conditions)), filters
-		)
-	return items
+		return frappe.get_all("Item", filters=item_filters, pluck="name", order_by=None) if item_filters else []
 
 
-def get_item_details(items, sle, filters):
+def get_item_details(items: List[str], sle: List[SLEntry], filters: StockBalanceFilter):
 	item_details = {}
 	if not items:
 		items = list(set(d.item_code for d in sle))
@@ -376,33 +401,35 @@
 	if not items:
 		return item_details
 
-	cf_field = cf_join = ""
-	if filters.get("include_uom"):
-		cf_field = ", ucd.conversion_factor"
-		cf_join = (
-			"left join `tabUOM Conversion Detail` ucd on ucd.parent=item.name and ucd.uom=%s"
-			% frappe.db.escape(filters.get("include_uom"))
-		)
+	item_table = frappe.qb.DocType("Item")
 
-	res = frappe.db.sql(
-		"""
-		select
-			item.name, item.item_name, item.description, item.item_group, item.brand, item.stock_uom %s
-		from
-			`tabItem` item
-			%s
-		where
-			item.name in (%s)
-	"""
-		% (cf_field, cf_join, ",".join(["%s"] * len(items))),
-		items,
-		as_dict=1,
+	query = (
+		frappe.qb.from_(item_table)
+		.select(
+			item_table.name,
+			item_table.item_name,
+			item_table.description,
+			item_table.item_group,
+			item_table.brand,
+			item_table.stock_uom,
+		)
+		.where(item_table.name.isin(items))
 	)
 
-	for item in res:
-		item_details.setdefault(item.name, item)
+	if uom := filters.get("include_uom"):
+		uom_conv_detail = frappe.qb.DocType("UOM Conversion Detail")
+		query = (
+			query.left_join(uom_conv_detail)
+			.on((uom_conv_detail.parent == item_table.name) & (uom_conv_detail.uom == uom))
+			.select(uom_conv_detail.conversion_factor)
+		)
 
-	if filters.get("show_variant_attributes", 0) == 1:
+	result = query.run(as_dict=1)
+
+	for item_table in result:
+		item_details.setdefault(item_table.name, item_table)
+
+	if filters.get("show_variant_attributes"):
 		variant_values = get_variant_values_for(list(item_details))
 		item_details = {k: v.update(variant_values.get(k, {})) for k, v in item_details.items()}
 
@@ -413,36 +440,33 @@
 	item_reorder_details = frappe._dict()
 
 	if items:
-		item_reorder_details = frappe.db.sql(
-			"""
-			select parent, warehouse, warehouse_reorder_qty, warehouse_reorder_level
-			from `tabItem Reorder`
-			where parent in ({0})
-		""".format(
-				", ".join(frappe.db.escape(i, percent=False) for i in items)
-			),
-			as_dict=1,
+		item_reorder_details = frappe.get_all(
+			"Item Reorder",
+			["parent", "warehouse", "warehouse_reorder_qty", "warehouse_reorder_level"],
+			filters={"parent": ("in", items)},
 		)
 
 	return dict((d.parent + d.warehouse, d) for d in item_reorder_details)
 
 
-def get_variants_attributes():
+def get_variants_attributes() -> List[str]:
 	"""Return all item variant attributes."""
-	return [i.name for i in frappe.get_all("Item Attribute")]
+	return frappe.get_all("Item Attribute", pluck="name")  # pluck="name" yields plain name strings
 
 
 def get_variant_values_for(items):
 	"""Returns variant values for items."""
 	attribute_map = {}
-	for attr in frappe.db.sql(
-		"""select parent, attribute, attribute_value
-		from `tabItem Variant Attribute` where parent in (%s)
-		"""
-		% ", ".join(["%s"] * len(items)),
-		tuple(items),
-		as_dict=1,
-	):
+
+	attribute_info = frappe.get_all(
+		"Item Variant Attribute",
+		["parent", "attribute", "attribute_value"],
+		{
+			"parent": ("in", items),
+		},
+	)
+
+	for attr in attribute_info:
 		attribute_map.setdefault(attr["parent"], {})
 		attribute_map[attr["parent"]].update({attr["attribute"]: attr["attribute_value"]})
 
diff --git a/erpnext/stock/report/stock_balance/test_stock_balance.py b/erpnext/stock/report/stock_balance/test_stock_balance.py
new file mode 100644
index 0000000..e963de2
--- /dev/null
+++ b/erpnext/stock/report/stock_balance/test_stock_balance.py
@@ -0,0 +1,174 @@
+from typing import Any, Dict
+
+import frappe
+from frappe import _dict
+from frappe.tests.utils import FrappeTestCase
+from frappe.utils import today
+
+from erpnext.stock.doctype.item.test_item import make_item
+from erpnext.stock.doctype.stock_entry.stock_entry_utils import make_stock_entry
+from erpnext.stock.report.stock_balance.stock_balance import execute
+
+
+def stock_balance(filters):
+	"""Run the stock balance report and return only its data rows, each wrapped in a _dict."""
+	return [_dict(row) for row in execute(filters)[1]]  # execute returns (columns, data)
+
+
+class TestStockBalance(FrappeTestCase):
+	# ----------- utils
+
+	def setUp(self):
+		self.item = make_item()
+		self.filters = _dict(
+			{
+				"company": "_Test Company",
+				"item_code": self.item.name,
+				"from_date": "2020-01-01",
+				"to_date": str(today()),
+			}
+		)
+
+	def tearDown(self):
+		frappe.db.rollback()
+
+	def assertPartialDictEq(self, expected: Dict[str, Any], actual: Dict[str, Any]):
+		for k, v in expected.items():
+			self.assertEqual(v, actual[k], msg=f"{expected=}\n{actual=}")
+
+	def generate_stock_ledger(self, item_code: str, movements):
+		"""Make one stock entry per movement; to_warehouse defaults to _Test Warehouse - _TC."""
+		for movement in map(_dict, movements):
+			if "to_warehouse" not in movement:
+				movement.to_warehouse = "_Test Warehouse - _TC"
+			make_stock_entry(item_code=item_code, **movement)
+
+	def assertInvariants(self, rows):
+		last_balance = frappe.db.sql(
+			"""
+			WITH last_balances AS (
+				SELECT item_code, warehouse,
+					stock_value, qty_after_transaction,
+					ROW_NUMBER() OVER (PARTITION BY item_code, warehouse
+						ORDER BY timestamp(posting_date, posting_time) desc, creation desc)
+						AS rn
+					FROM `tabStock Ledger Entry`
+					where is_cancelled=0
+				)
+				SELECT * FROM last_balances WHERE rn = 1""",
+			as_dict=True,
+		)
+
+		item_wh_stock = _dict()
+
+		for line in last_balance:
+			item_wh_stock.setdefault((line.item_code, line.warehouse), line)
+
+		msg = f"Invariants not met for {rows=}"
+		for row in rows:
+			# qty invariant; msg goes by keyword (3rd positional arg of assertAlmostEqual is places)
+			self.assertAlmostEqual(row.bal_qty, row.opening_qty + row.in_qty - row.out_qty, msg=msg)
+
+			# value invariant
+			self.assertAlmostEqual(row.bal_val, row.opening_val + row.in_val - row.out_val, msg=msg)
+
+			# check against SLE
+			last_sle = item_wh_stock[(row.item_code, row.warehouse)]
+			self.assertAlmostEqual(row.bal_qty, last_sle.qty_after_transaction, 3, msg)
+			self.assertAlmostEqual(row.bal_val, last_sle.stock_value, 3, msg)
+
+			# valuation rate
+			if not row.bal_qty:
+				continue
+			self.assertAlmostEqual(row.val_rate, row.bal_val / row.bal_qty, 3, msg)
+
+	# ----------- tests
+
+	def test_basic_stock_balance(self):
+		"""Check very basic functionality and item info"""
+		rows = stock_balance(self.filters)
+		self.assertEqual(rows, [])
+
+		self.generate_stock_ledger(self.item.name, [_dict(qty=5, rate=10)])
+
+		# check item info
+		rows = stock_balance(self.filters)
+		self.assertPartialDictEq(
+			{
+				"item_code": self.item.name,
+				"item_name": self.item.item_name,
+				"item_group": self.item.item_group,
+				"stock_uom": self.item.stock_uom,
+				"in_qty": 5,
+				"in_val": 50,
+				"val_rate": 10,
+			},
+			rows[0],
+		)
+		self.assertInvariants(rows)
+
+	def test_opening_balance(self):
+		self.generate_stock_ledger(
+			self.item.name,
+			[
+				_dict(qty=1, rate=1, posting_date="2021-01-01"),
+				_dict(qty=2, rate=2, posting_date="2021-01-02"),
+				_dict(qty=3, rate=3, posting_date="2021-01-03"),
+			],
+		)
+		rows = stock_balance(self.filters)
+		self.assertInvariants(rows)
+
+		rows = stock_balance(self.filters.update({"from_date": "2021-01-02"}))
+		self.assertInvariants(rows)
+		self.assertPartialDictEq({"opening_qty": 1, "in_qty": 5}, rows[0])
+
+		rows = stock_balance(self.filters.update({"from_date": "2022-01-01"}))
+		self.assertInvariants(rows)
+		self.assertPartialDictEq({"opening_qty": 6, "in_qty": 0}, rows[0])
+
+	def test_uom_converted_info(self):
+		"""5 units with a Box conversion factor of 5 must show as 1 Box in bal_qty_alt."""
+		self.item.append("uoms", {"conversion_factor": 5, "uom": "Box"})
+		self.item.save()
+
+		self.generate_stock_ledger(self.item.name, [_dict(qty=5, rate=10)])
+
+		rows = stock_balance(self.filters.update({"include_uom": "Box"}))
+		self.assertEqual(rows[0].bal_qty_alt, 1)
+		self.assertInvariants(rows)
+
+	def test_item_group(self):
+		self.filters.pop("item_code", None)
+		rows = stock_balance(self.filters.update({"item_group": self.item.item_group}))
+		self.assertTrue(all(r.item_group == self.item.item_group for r in rows))
+
+	def test_child_warehouse_balances(self):
+		# This is default
+		self.generate_stock_ledger(self.item.name, [_dict(qty=5, rate=10, to_warehouse="Stores - _TC")])
+
+		self.filters.pop("item_code", None)
+		rows = stock_balance(self.filters.update({"warehouse": "All Warehouses - _TC"}))
+
+		self.assertTrue(
+			any(r.item_code == self.item.name and r.warehouse == "Stores - _TC" for r in rows),
+			msg=f"Expected child warehouse balances \n{rows}",
+		)
+
+	def test_show_item_attr(self):
+		from erpnext.controllers.item_variant import create_variant
+
+		self.item.has_variants = True
+		self.item.append("attributes", {"attribute": "Test Size"})
+		self.item.save()
+
+		attributes = {"Test Size": "Large"}
+		variant = create_variant(self.item.name, attributes)
+		variant.save()
+
+		self.generate_stock_ledger(variant.name, [_dict(qty=5, rate=10)])
+		rows = stock_balance(
+			self.filters.update({"show_variant_attributes": 1, "item_code": variant.name})
+		)
+		self.assertPartialDictEq(attributes, rows[0])
+		self.assertInvariants(rows)