refactor:
- current_invoice _start and end should be determined by trial period or billing period
- adds new functions to get billing period data
diff --git a/erpnext/accounts/doctype/subscriptions/subscriptions.py b/erpnext/accounts/doctype/subscriptions/subscriptions.py
index 6cd8934..0cccaeb 100644
--- a/erpnext/accounts/doctype/subscriptions/subscriptions.py
+++ b/erpnext/accounts/doctype/subscriptions/subscriptions.py
@@ -5,7 +5,7 @@
from __future__ import unicode_literals
import frappe
from frappe.model.document import Document
-from frappe.utils.data import now, nowdate, getdate, cint, add_days, date_diff, get_last_day, get_first_day
+from frappe.utils.data import now, nowdate, getdate, cint, add_days, date_diff, get_last_day, get_first_day, add_to_date
from frappe import _
@@ -25,17 +25,58 @@
if self.trial_period_start and self.is_trialling():
self.current_invoice_start = self.trial_period_start
elif not date:
- current_invoice = self.get_current_invoice()
- if not current_invoice:
- self.current_invoice_start = nowdate()
- else:
- self.current_invoice_start = current_invoice.posting_date
+ self.current_invoice_start = nowdate()
def set_current_invoice_end(self):
if self.is_trialling():
self.current_invoice_end = self.trial_period_end
else:
- self.current_invoice_end = get_last_day(self.current_invoice_start)
+ billing_cycle_info = self.get_billing_cycle()
+ if billing_cycle_info:
+ self.current_invoice_end = add_to_date(self.current_invoice_start, **billing_cycle_info)
+ else:
+ self.current_invoice_end = get_last_day(self.current_invoice_start)
+
+ def get_billing_cycle(self):
+ return self.get_billing_cycle_data()
+
+ def validate_plans_billing_cycle(self, billing_cycle_data):
+ if billing_cycle_data and len(billing_cycle_data) != 1:
+ frappe.throw(_('You can only have Plans with the same billing cycle in a Subscription'))
+
+ def get_billing_cycle_and_interval(self):
+ plan_names = [plan.plan for plan in self.plans]
+ billing_info = frappe.db.sql(
+ 'select distinct `billing_interval`, `billing_interval_count` '
+ 'from `tabSubscription Plan` '
+ 'where name in %s',
+ (plan_names,), as_dict=1
+ )
+
+ return billing_info
+
+ def get_billing_cycle_data(self):
+ billing_info = self.get_billing_cycle_and_interval()
+
+ self.validate_plans_billing_cycle(billing_info)
+
+ if billing_info:
+ data = dict()
+ interval = billing_info[0]['billing_interval']
+ interval_count = billing_info[0]['billing_interval_count']
+ if interval not in ['Day', 'Week']:
+ data['days'] = -1
+ if interval == 'Day':
+ data['days'] = interval_count - 1
+ elif interval == 'Month':
+ data['months'] = interval_count
+ elif interval == 'Year':
+ data['years'] == interval_count
+ # todo: test week
+ elif interval == 'Week':
+ data['days'] = interval_count * 7 - 1
+
+ return data
def before_save(self):
self.set_status()
@@ -89,6 +130,7 @@
def validate(self):
self.validate_trial_period()
+ self.validate_plans_billing_cycle(self.get_billing_cycle_and_interval())
def validate_trial_period(self):
if self.trial_period_start and self.trial_period_end: