Salary Slip - update tax calculation, deduct tax for unsubmitted exemption proof (#14428)

diff --git a/erpnext/hr/doctype/salary_slip/salary_slip.py b/erpnext/hr/doctype/salary_slip/salary_slip.py
index dbd73b5..b8fadfd 100644
--- a/erpnext/hr/doctype/salary_slip/salary_slip.py
+++ b/erpnext/hr/doctype/salary_slip/salary_slip.py
@@ -108,7 +108,9 @@
 				'salary_component' : struct_row.salary_component,
 				'abbr' : struct_row.abbr,
 				'do_not_include_in_total' : struct_row.do_not_include_in_total,
-				'is_flexible_benefit': struct_row.is_flexible_benefit
+				'is_tax_applicable': struct_row.is_tax_applicable,
+				'is_flexible_benefit': struct_row.is_flexible_benefit,
+				'variable_based_on_taxable_salary': struct_row.variable_based_on_taxable_salary
 			})
 		else:
 			component_row.amount = amount
@@ -499,25 +501,35 @@
 		return status
 
 	def calculate_variable_based_on_taxable_salary(self, tax_component):
-		# TODO case both checked - restrict to and make this mandatory on final period of payroll?
-		# case only deduct_tax_for_unsubmitted_tax_exemption_proof checked not handled, calculate_variable_tax called
 		payroll_period = get_payroll_period(self.start_date, self.end_date, self.company)
 		if not payroll_period:
 			frappe.msgprint(_("Start and end dates not in a valid Payroll Period, \
 			cannot calculate {0}.").format(tax_component))
 			return False, False
-		if self.deduct_tax_for_unclaimed_employee_benefits and not self.deduct_tax_for_unsubmitted_tax_exemption_proof:
-			total_taxable_benefit = self.calculate_unclaimed_benefit_amount(payroll_period)
-			total_taxable_benefit += self.get_taxable_earnings(only_flexi=True)
-			return self.calculate_variable_tax(tax_component, payroll_period, benefit_amount=total_taxable_benefit)
-		elif self.deduct_tax_for_unclaimed_employee_benefits and self.deduct_tax_for_unsubmitted_tax_exemption_proof:
-			return self.calculate_tax_for_payroll_period(tax_component, payroll_period)
-		else:
-			return self.calculate_variable_tax(tax_component, payroll_period)
+		if payroll_period.end_date <= getdate(self.end_date):
+			if not self.deduct_tax_for_unsubmitted_tax_exemption_proof \
+				or not self.deduct_tax_for_unclaimed_employee_benefits:
+				frappe.throw(_("You have to Deduct Tax for Unsubmitted Tax Exemption Proof \
+				and Unclaimed Employee Benefits in the last Salary Slip of Payroll Period"))
+			else:
+				return self.calculate_tax_for_payroll_period(tax_component, payroll_period)
 
-	def calculate_variable_tax(self, tax_component, payroll_period, benefit_amount=0):
+		benefit_amount_to_tax = 0
+		if self.deduct_tax_for_unclaimed_employee_benefits:
+			# get all untaxed benefits till date, pass amount to be taxed by later methods
+			benefit_amount_to_tax = self.calculate_unclaimed_taxable_benefit(payroll_period)
+			# flexi's excluded from monthly tax, add flexis in this slip to total_taxable_benefit
+			benefit_amount_to_tax += self.get_taxable_earnings(only_flexi=True)
+		if self.deduct_tax_for_unsubmitted_tax_exemption_proof:
+			# calc tax to be paid for the period till date considering prorata taxes paid and proofs submitted
+			return self.calculate_unclaimed_taxable_earning(payroll_period, tax_component, benefit_amount_to_tax)
+
+		# calc prorata tax to be applied
+		return self.calculate_variable_tax(tax_component, payroll_period, benefit_amount_to_tax=benefit_amount_to_tax)
+
+	def calculate_variable_tax(self, tax_component, payroll_period, benefit_amount_to_tax=0):
 		total_taxable_earning = self.get_taxable_earnings()
-		period_factor = self.get_period_factor(payroll_period.start_date, payroll_period.end_date)
+		period_factor = self.get_period_factor(payroll_period.start_date, payroll_period.end_date, self.start_date, self.end_date)
 		annual_earning = total_taxable_earning * period_factor
 
 		# Calculate total exemption declaration
@@ -529,18 +541,7 @@
 				"total_exemption_amount")
 		annual_taxable_earning = annual_earning - exemption_amount
 
-		# Get tax calc by period
-		annual_tax = self.calculate_tax(payroll_period.name, annual_taxable_earning)
-
-		# Calc prorata tax
-		pro_rata_tax = annual_tax / period_factor
-		struct_row = self.get_salary_slip_row(tax_component)
-
-		# find the annual tax diff caused by benefit, add to pro_rata_tax
-		if benefit_amount > 0:
-			annual_tax_with_benefit = self.calculate_tax(payroll_period.name, annual_taxable_earning + benefit_amount)
-			pro_rata_tax += annual_tax_with_benefit - annual_tax
-		return struct_row, pro_rata_tax
+		return self.calculate_tax(payroll_period, tax_component, annual_taxable_earning, period_factor, 0, benefit_amount_to_tax)
 
 	def calculate_tax_for_payroll_period(self, tax_component, payroll_period):
 		# get total taxable income, total tax paid in payroll period
@@ -560,18 +561,15 @@
 			if sum_benefit_claim and sum_benefit_claim[0][0]:
 				total_benefit_claim = sum_benefit_claim[0][0]
 		total_taxable_earning = taxable_income - total_tax_exemption_proof - total_benefit_claim
+
 		# add taxable earnings of current salary_slip, include flexi
 		total_taxable_earning += self.get_taxable_earnings(include_flexi=1)
-		# calc annual tax by tax slab
-		annual_tax = self.calculate_tax(payroll_period.name, total_taxable_earning)
-		# get balance amount to tax, even if -ve add to deduction
-		pay_slip_tax = annual_tax - tax_paid
-		struct_row = self.get_salary_slip_row(tax_component)
-		return struct_row, pay_slip_tax
+		return self.calculate_tax(payroll_period, tax_component, total_taxable_earning, 1, tax_paid, 0)
 
-	def calculate_unclaimed_benefit_amount(self, payroll_period):
+	def calculate_unclaimed_taxable_benefit(self, payroll_period):
 		total_benefit = 0
 		start_date = payroll_period.start_date
+
 		# if tax for unclaimed benefit deducted earlier set the start date
 		last_deducted =	frappe.db.sql("""select end_date from `tabSalary Slip` where
 				deduct_tax_for_unclaimed_employee_benefits=1 and docstatus=1 and
@@ -580,6 +578,8 @@
 				self.employee, payroll_period.start_date, payroll_period.end_date))
 		if last_deducted and last_deducted[0][0]:
 			start_date = getdate(last_deducted[0][0])
+
+		# get total sum of benefits paid
 		sum_benefit = frappe.db.sql("""select sum(sd.amount) from `tabSalary Detail` sd join
 					`tabSalary Slip` ss on sd.parent=ss.name where sd.parentfield='earnings'
 					and sd.is_tax_applicable=1 and is_flexible_benefit=1 and ss.docstatus=1
@@ -588,6 +588,8 @@
 					start_date, payroll_period.end_date))
 		if sum_benefit and sum_benefit[0][0]:
 			total_benefit = sum_benefit[0][0]
+
+		# get total benefits claimed
 		total_benefit_claim = 0
 		sum_benefit_claim = frappe.db.sql("""select sum(claimed_amount) from
 		`tabEmployee Benefit Claim` where docstatus=1 and employee='{0}' and claim_date
@@ -596,28 +598,103 @@
 			total_benefit_claim = sum_benefit_claim[0][0]
 		return total_benefit - total_benefit_claim
 
+	def calculate_unclaimed_taxable_earning(self, payroll_period, tax_component, benefit_amount_to_tax):
+		total_taxable_earning, total_tax_paid = 0, 0
+		start_date = payroll_period.start_date
+
+		# if tax deducted earlier set the start date
+		last_deducted =	frappe.db.sql("""select end_date from `tabSalary Slip` where
+				deduct_tax_for_unsubmitted_tax_exemption_proof=1 and docstatus=1 and
+				employee='{0}' and start_date between '{1}' and '{2}' and end_date
+				between '{1}' and '{2}' order by end_date desc limit 1""".format(
+				self.employee, payroll_period.start_date, self.start_date))
+		if last_deducted and last_deducted[0][0]:
+			start_date = getdate(last_deducted[0][0])
+
+		# calc total taxable amount in period
+		sum_taxable_earning = frappe.db.sql("""select sum(sd.amount) from `tabSalary Detail` sd join
+					`tabSalary Slip` ss on sd.parent=ss.name where sd.parentfield='earnings'
+					and sd.is_tax_applicable=1 and is_flexible_benefit=0 and ss.docstatus=1
+					and ss.employee='{0}' and ss.start_date between '{1}' and '{2}' and
+					ss.end_date between '{1}' and '{2}'""".format(self.employee,
+					start_date, self.start_date))
+		if sum_taxable_earning and sum_taxable_earning[0][0]:
+			total_taxable_earning = sum_taxable_earning[0][0]
+
+		# add taxable earning in this salary slip
+		total_taxable_earning += self.get_taxable_earnings()
+
+		# find total_tax_paid from salary slip where benefit is not taxed
+		sum_tax_paid = frappe.db.sql("""select sum(sd.amount) from `tabSalary Detail` sd join
+					`tabSalary Slip` ss on sd.parent=ss.name where sd.parentfield='deductions'
+					and sd.salary_component='{3}' and sd.variable_based_on_taxable_salary=1 and ss.docstatus=1
+					and ss.employee='{0}' and ss.deduct_tax_for_unclaimed_employee_benefits=0
+					and ss.start_date between '{1}' and '{2}' and ss.end_date between '{1}' and
+					'{2}'""".format(self.employee, start_date, self.start_date, tax_component))
+		if sum_tax_paid and sum_tax_paid[0][0]:
+			total_tax_paid = sum_tax_paid[0][0]
+
+		# get benefit taxed salary slips
+		benefit_taxed_ss = frappe.db.sql("""select name from `tabSalary Slip` where
+					deduct_tax_for_unsubmitted_tax_exemption_proof=0 and
+					deduct_tax_for_unclaimed_employee_benefits=1 and docstatus=1 and employee='{0}'
+					and start_date between '{1}' and '{2}' and end_date between '{1}'
+					and '{2}'""".format(self.employee, start_date, self.start_date))
+		# add pro_rata_tax of all salary slips where benefit tax added up
+		if benefit_taxed_ss and benefit_taxed_ss[0]:
+			for salary_slip in benefit_taxed_ss[0]:
+				ss_obj = frappe.get_doc("Salary Slip", salary_slip)
+				struct_row, pro_rata_tax = ss_obj.calculate_variable_tax(tax_component, payroll_period)
+				if pro_rata_tax:
+					total_tax_paid += pro_rata_tax
+		total_exemption_amount = 0
+
+		# add up total Proof Submission
+		sum_exemption = frappe.db.sql("""select sum(total_amount) from
+		`tabEmployee Tax Exemption Proof Submission` where docstatus=1 and employee='{0}' and
+		payroll_period='{1}' and processed_in_payroll=0""".format(self.employee, payroll_period.name))
+		if sum_exemption and sum_exemption[0][0]:
+			total_exemption_amount = sum_exemption[0][0]
+		total_taxable_earning -= total_exemption_amount
+
+		# recalc annual tax slab by start date and end date
+		period_factor = self.get_period_factor(payroll_period.start_date, payroll_period.end_date, start_date, self.end_date)
+		annual_taxable_earning = total_taxable_earning * period_factor
+		return self.calculate_tax(payroll_period, tax_component, annual_taxable_earning, period_factor, total_tax_paid, benefit_amount_to_tax)
+
 	def get_taxable_earnings(self, include_flexi=0, only_flexi=0):
-		# TODO remove this, iterate in self.earnings. map_doc fails to copy field values from Salary Structure to Slary Slip
-		tax_applicable_components = []
-		for earning in self._salary_structure_doc.earnings:
+		taxable_earning = 0
+		for earning in self.earnings:
 			if only_flexi:
 				if earning.is_tax_applicable and earning.is_flexible_benefit:
-					tax_applicable_components.append(earning.salary_component)
+					taxable_earning += earning.amount
 				continue
 			if include_flexi:
 				if earning.is_tax_applicable or (earning.is_tax_applicable and earning.is_flexible_benefit):
-					tax_applicable_components.append(earning.salary_component)
+					taxable_earning += earning.amount
 			else:
 				if earning.is_tax_applicable and not earning.is_flexible_benefit:
-					tax_applicable_components.append(earning.salary_component)
-
-		taxable_earning = 0
-		for earning in self.earnings:
-			if earning.salary_component in tax_applicable_components:
-				taxable_earning += earning.amount
+					taxable_earning += earning.amount
 		return taxable_earning
 
-	def calculate_tax(self, payroll_period, annual_earning):
+	def calculate_tax(self, payroll_period, tax_component, annual_taxable_earning, period_factor, tax_paid=0, benefit_amount_to_tax=0):
+		# Get tax calc by period
+		annual_tax = self.calculate_tax_by_tax_slab(payroll_period.name, annual_taxable_earning)
+
+		# Calc prorata tax
+		tax_amount = annual_tax / period_factor
+		if tax_paid:
+			tax_amount -= tax_paid
+
+		# find the annual tax diff caused by benefit_amount_to_tax, add to tax_amount
+		if benefit_amount_to_tax > 0:
+			annual_tax_with_benefit_amt = self.calculate_tax_by_tax_slab(payroll_period.name, annual_taxable_earning + benefit_amount_to_tax)
+			tax_amount += annual_tax_with_benefit_amt - annual_tax
+		struct_row = self.get_salary_slip_row(tax_component)
+		return struct_row, tax_amount
+
+	def calculate_tax_by_tax_slab(self, payroll_period, annual_earning):
+		# TODO consider condition in tax slab
 		payroll_period_obj = frappe.get_doc("Payroll Period", payroll_period)
 		taxable_amount = 0
 		for slab in payroll_period_obj.taxable_salary_slabs:
@@ -627,13 +704,17 @@
 				taxable_amount += (slab.to_amount - slab.from_amount) * slab.percent_deduction * .01
 		return taxable_amount
 
-	def get_period_factor(self, start_date, end_date):
-		# period length is hard coded to keep tax calc consistent
+	def get_period_factor(self, period_start, period_end, start_date=None, end_date=None):
+		# TODO make this configurable? - use hard coded period length to keep tax calc consistent
 		frequency_days = {"Daily": 1, "Weekly": 7, "Fortnightly": 15, "Monthly": 30, "Bimonthly": 60}
-		payroll_days = date_diff(end_date, start_date) + 1
+		payroll_days = date_diff(period_end, period_start) + 1
+		if start_date and end_date:
+			salary_days = date_diff(end_date, start_date) +1
+			return flt(payroll_days)/flt(salary_days)
 		return flt(payroll_days)/frequency_days[self.payroll_frequency]
 
 	def get_tax_detail_till_date(self, payroll_period, tax_component):
+		# find total taxable income, total tax paid by employee in payroll period
 		total_taxable_income = 0
 		total_tax_paid = 0
 		sum_income = frappe.db.sql("""select sum(sd.amount) from `tabSalary Detail` sd join
@@ -662,6 +743,9 @@
 		struct_row['salary_component'] = component.name
 		struct_row['abbr'] = component.salary_component_abbr
 		struct_row['do_not_include_in_total'] = component.do_not_include_in_total
+		struct_row['is_tax_applicable'] = component.is_tax_applicable
+		struct_row['is_flexible_benefit'] = component.is_flexible_benefit
+		struct_row['variable_based_on_taxable_salary'] = component.variable_based_on_taxable_salary
 		return struct_row
 
 def unlink_ref_doc_from_salary_slip(ref_no):