refactor: rewrite stock balance query to QB
diff --git a/erpnext/stock/report/stock_balance/stock_balance.py b/erpnext/stock/report/stock_balance/stock_balance.py
index a5136a3..7fa417a 100644
--- a/erpnext/stock/report/stock_balance/stock_balance.py
+++ b/erpnext/stock/report/stock_balance/stock_balance.py
@@ -7,11 +7,13 @@
import frappe
from frappe import _
+from frappe.query_builder.functions import CombineDatetime
from frappe.utils import cint, date_diff, flt, getdate
+from frappe.utils.nestedset import get_descendants_of
+from pypika.terms import ExistsCriterion
import erpnext
from erpnext.stock.report.stock_ageing.stock_ageing import FIFOSlots, get_average_age
-from erpnext.stock.report.stock_ledger.stock_ledger import get_item_group_condition
from erpnext.stock.utils import add_additional_uom_columns, is_reposting_item_valuation_in_progress
@@ -33,8 +35,6 @@
if not filters:
filters = {}
- to_date = filters.get("to_date")
-
if filters.get("company"):
company_currency = erpnext.get_company_currency(filters.get("company"))
else:
@@ -62,6 +62,7 @@
_func = itemgetter(1)
+ to_date = filters.get("to_date")
for (company, item, warehouse) in sorted(iwb_map):
if item_map.get(item):
qty_dict = iwb_map[(company, item, warehouse)]
@@ -229,64 +230,75 @@
return columns
-def get_conditions(filters: StockBalanceFilter):
- conditions = ""
+def apply_conditions(query, filters):
+ sle = frappe.qb.DocType("Stock Ledger Entry")
+ warehouse_table = frappe.qb.DocType("Warehouse")
+
if not filters.get("from_date"):
frappe.throw(_("'From Date' is required"))
- if filters.get("to_date"):
- conditions += " and sle.posting_date <= %s" % frappe.db.escape(filters.get("to_date"))
+ if to_date := filters.get("to_date"):
+ query = query.where(sle.posting_date <= to_date)
else:
frappe.throw(_("'To Date' is required"))
- if filters.get("company"):
- conditions += " and sle.company = %s" % frappe.db.escape(filters.get("company"))
+ if company := filters.get("company"):
+ query = query.where(sle.company == company)
- if filters.get("warehouse"):
- warehouse_details = frappe.db.get_value(
- "Warehouse", filters.get("warehouse"), ["lft", "rgt"], as_dict=1
- )
- if warehouse_details:
- conditions += (
- " and exists (select name from `tabWarehouse` wh \
- where wh.lft >= %s and wh.rgt <= %s and sle.warehouse = wh.name)"
- % (warehouse_details.lft, warehouse_details.rgt)
+ if warehouse := filters.get("warehouse"):
+ lft, rgt = frappe.db.get_value("Warehouse", warehouse, ["lft", "rgt"])
+ chilren_subquery = (
+ frappe.qb.from_(warehouse_table)
+ .select(warehouse_table.name)
+ .where(
+ (warehouse_table.lft >= lft)
+ & (warehouse_table.rgt <= rgt)
+ & (warehouse_table.name == sle.warehouse)
)
-
- if filters.get("warehouse_type") and not filters.get("warehouse"):
- conditions += (
- " and exists (select name from `tabWarehouse` wh \
- where wh.warehouse_type = '%s' and sle.warehouse = wh.name)"
- % (filters.get("warehouse_type"))
+ )
+ query = query.where(ExistsCriterion(chilren_subquery))
+ elif warehouse_type := filters.get("warehouse_type"):
+ query = (
+ query.join(warehouse_table)
+ .on(warehouse_table.name == sle.warehouse)
+ .where(warehouse_table.warehouse_type == warehouse_type)
)
- return conditions
+ return query
def get_stock_ledger_entries(filters: StockBalanceFilter, items):
- item_conditions_sql = ""
- if items:
- item_conditions_sql = " and sle.item_code in ({})".format(
- ", ".join(frappe.db.escape(i, percent=False) for i in items)
+ sle = frappe.qb.DocType("Stock Ledger Entry")
+
+ query = (
+ frappe.qb.from_(sle)
+ .select(
+ sle.item_code,
+ sle.warehouse,
+ sle.posting_date,
+ sle.actual_qty,
+ sle.valuation_rate,
+ sle.company,
+ sle.voucher_type,
+ sle.qty_after_transaction,
+ sle.stock_value_difference,
+ sle.item_code.as_("name"),
+ sle.voucher_no,
+ sle.stock_value,
+ sle.batch_no,
)
-
- conditions = get_conditions(filters)
-
- return frappe.db.sql(
- """
- select
- sle.item_code, warehouse, sle.posting_date, sle.actual_qty, sle.valuation_rate,
- sle.company, sle.voucher_type, sle.qty_after_transaction, sle.stock_value_difference,
- sle.item_code as name, sle.voucher_no, sle.stock_value, sle.batch_no
- from
- `tabStock Ledger Entry` sle
- where sle.docstatus < 2 %s %s
- and is_cancelled = 0
- order by sle.posting_date, sle.posting_time, sle.creation, sle.actual_qty"""
- % (item_conditions_sql, conditions), # nosec
- as_dict=1,
+ .where((sle.docstatus < 2) & (sle.is_cancelled == 0))
+ .orderby(CombineDatetime(sle.posting_date, sle.posting_time))
+ .orderby(sle.creation)
+ .orderby(sle.actual_qty)
)
+ if items:
+ query = query.where(sle.item_code.isin(items))
+
+ query = apply_conditions(query, filters)
+ return query.run(as_dict=True)
+
def get_item_warehouse_map(filters: StockBalanceFilter, sle):
iwb_map = {}
@@ -365,21 +377,17 @@
def get_items(filters: StockBalanceFilter):
"Get items based on item code, item group or brand."
- conditions = []
- if filters.get("item_code"):
- conditions.append("item.name=%(item_code)s")
+ if item_code := filters.get("item_code"):
+ return [item_code]
else:
- if filters.get("item_group"):
- conditions.append(get_item_group_condition(filters.get("item_group")))
- if filters.get("brand"): # used in stock analytics report
- conditions.append("item.brand=%(brand)s")
+ item_filters = {}
+ if item_group := filters.get("item_group"):
+ children = get_descendants_of("Item Group", item_group, ignore_permissions=True)
+ item_filters["item_group"] = ("in", children + [item_group])
+ if brand := filters.get("brand"):
+ item_filters["brand"] = brand
- items = []
- if conditions:
- items = frappe.db.sql_list(
- """select name from `tabItem` item where {}""".format(" and ".join(conditions)), filters
- )
- return items
+ return frappe.get_all("Item", filters=item_filters, pluck="name", order_by=None, debug=1)
def get_item_details(items, sle, filters: StockBalanceFilter):
@@ -416,7 +424,7 @@
for item in res:
item_details.setdefault(item.name, item)
- if filters.get("show_variant_attributes", 0) == 1:
+ if filters.get("show_variant_attributes"):
variant_values = get_variant_values_for(list(item_details))
item_details = {k: v.update(variant_values.get(k, {})) for k, v in item_details.items()}
@@ -443,7 +451,7 @@
def get_variants_attributes():
"""Return all item variant attributes."""
- return [i.name for i in frappe.get_all("Item Attribute")]
+ return frappe.get_all("Item Attribute", pluck="name")
def get_variant_values_for(items):
diff --git a/erpnext/stock/report/stock_balance/test_stock_balance.py b/erpnext/stock/report/stock_balance/test_stock_balance.py
index 2783d27..111bdc9 100644
--- a/erpnext/stock/report/stock_balance/test_stock_balance.py
+++ b/erpnext/stock/report/stock_balance/test_stock_balance.py
@@ -38,11 +38,9 @@
def generate_stock_ledger(self, item_code: str, movements):
for movement in map(_dict, movements):
- make_stock_entry(
- item_code=item_code,
- **movement,
- to_warehouse=movement.to_warehouse or "_Test Warehouse - _TC",
- )
+ if "to_warehouse" not in movement:
+ movement.to_warehouse = "_Test Warehouse - _TC"
+ make_stock_entry(item_code=item_code, **movement)
def assertInvariants(self, rows):
last_balance = frappe.db.sql(
@@ -135,3 +133,20 @@
rows = stock_balance(self.filters.update({"include_uom": "Box"}))
self.assertEqual(rows[0].bal_qty_alt, 1)
+
+ def test_item_group(self):
+ self.filters.pop("item_code", None)
+ rows = stock_balance(self.filters.update({"item_group": self.item.item_group}))
+ self.assertTrue(all(r.item_group == self.item.item_group for r in rows))
+
+ def test_child_warehouse_balances(self):
+ # This is default
+ self.generate_stock_ledger(self.item.name, [_dict(qty=5, rate=10, to_warehouse="Stores - _TC")])
+
+ self.filters.pop("item_code", None)
+ rows = stock_balance(self.filters.update({"warehouse": "All Warehouses - _TC"}))
+
+ self.assertTrue(
+ any(r.item_code == self.item.name and r.warehouse == "Stores - _TC" for r in rows),
+ msg=f"Expected child warehouse balances \n{rows}",
+ )