package db
import (
"database/sql"
"git.sr.ht/~evanj/cms/internal/m/org"
"git.sr.ht/~evanj/cms/internal/m/tier"
"git.sr.ht/~evanj/cms/internal/m/user"
)
// Org is the database-backed value returned as an org.Org by this package.
// It is populated by scanning cms_org.ID and cms_billing.TIER_NAME.
type Org struct {
	// OrgID holds the value scanned from cms_org.ID.
	OrgID string
	// OrgBillingTierName holds cms_billing.TIER_NAME. The queries LEFT JOIN
	// cms_billing, so this is NULL (Valid == false) when the org has no
	// billing row.
	OrgBillingTierName sql.NullString
}
// ID returns the org identifier that was scanned from cms_org.ID.
func (o Org) ID() string {
	return o.OrgID
}
// Tier resolves the org's billing tier from the stored tier name.
//
// Any name for which tier.ByName reports !ok falls back to tier.Free.
// Presumably this includes the empty string left in OrgBillingTierName.String
// when the billing column is NULL (confirm against tier.ByName).
func (o Org) Tier() tier.Tier {
	if resolved, found := tier.ByName(o.OrgBillingTierName.String); found {
		return resolved
	}
	return tier.Free
}
var (
	// queryOrgByID selects one org by cms_org.ID together with its billing
	// tier name. The LEFT JOIN keeps orgs that have no cms_billing row, in
	// which case TIER_NAME is NULL.
	queryOrgByID = `
SELECT cms_org.ID, cms_billing.TIER_NAME FROM cms_org
LEFT JOIN cms_billing ON cms_billing.ORG_ID=cms_org.ID
WHERE cms_org.ID=?
`
	// queryOrgByUserAndID has the same shape as queryOrgByID but only matches
	// when the user (first placeholder) is linked to the org (second
	// placeholder) through cms_user.ORG_ID.
	queryOrgByUserAndID = `
SELECT cms_org.ID, cms_billing.TIER_NAME FROM cms_org
LEFT JOIN cms_billing ON cms_billing.ORG_ID=cms_org.ID
JOIN cms_user ON cms_user.ORG_ID=cms_org.ID
WHERE cms_user.ID=? AND cms_org.ID=?
`
)
// OrgNew creates a new organization in its own transaction.
//
// The org is returned only after the transaction has committed. If Begin,
// the insert/readback, or Commit fails, the result is (nil, err) and the
// caller never receives an org that was not persisted.
func (db *DB) OrgNew() (org.Org, error) {
	tx, err := db.Begin()
	if err != nil {
		return nil, err
	}
	// Best-effort cleanup for every early return. After a successful Commit,
	// Rollback just returns sql.ErrTxDone, so its error is deliberately ignored.
	defer tx.Rollback()

	// Named o (not org) so the local does not shadow the org package.
	o, err := db.orgNew(tx)
	if err != nil {
		return nil, err
	}
	// A failed commit means the row does not exist, so do not return it.
	if err := tx.Commit(); err != nil {
		return nil, err
	}
	return o, nil
}
// orgNew inserts an empty cms_org row using the caller's transaction. It then
// reads the row back with queryOrgByID and returns it as an Org.
// It neither commits nor rolls back t; that is left to the caller.
func (db *DB) orgNew(t *sql.Tx) (org.Org, error) {
	result, err := t.Exec("INSERT INTO cms_org () VALUES ()")
	if err != nil {
		return nil, err
	}

	newID, err := result.LastInsertId()
	if err != nil {
		return nil, err
	}

	created := Org{}
	row := t.QueryRow(queryOrgByID, newID)
	if err = row.Scan(&created.OrgID, &created.OrgBillingTierName); err != nil {
		return nil, err
	}
	return created, nil
}
// OrgGet loads the org identified by orgID, but only if user u is linked to
// it. The check is done by queryOrgByUserAndID, which joins on cms_user.ORG_ID.
// Any error from the row scan, including a missing match, is returned unchanged.
func (db *DB) OrgGet(u user.User, orgID string) (org.Org, error) {
	var found Org
	err := db.QueryRow(queryOrgByUserAndID, u.ID(), orgID).
		Scan(&found.OrgID, &found.OrgBillingTierName)
	if err != nil {
		return nil, err
	}
	return found, nil
}
// OrgUpdateTier replaces o's billing record with a single row holding tier t
// and paymentCustomerID.
//
// The delete and the insert run in one transaction, so either both take
// effect or neither does.
//
// NOTE(review): u is not used here. Confirm whether an ownership check was
// intended.
func (db *DB) OrgUpdateTier(u user.User, o org.Org, t tier.Tier, paymentCustomerID string) error {
	const (
		deleteBilling = "DELETE FROM cms_billing WHERE ORG_ID=?"
		insertBilling = "INSERT INTO cms_billing (PAYMENT_CUSTOMER, TIER_NAME, ORG_ID) values(?, ?, ?)"
	)

	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// Undoes the delete if the insert or commit never happens.
	defer tx.Rollback()

	if _, err = tx.Exec(deleteBilling, o.ID()); err != nil {
		return err
	}
	if _, err = tx.Exec(insertBilling, paymentCustomerID, t.Name, o.ID()); err != nil {
		return err
	}
	return tx.Commit()
}
// OrgGetSpaceCount returns the number of cms_space rows whose ORG_ID is o's ID.
// On error it returns 0 and the error.
func (db *DB) OrgGetSpaceCount(o org.Org) (int, error) {
	const countSpaces = "SELECT COUNT(*) FROM cms_space WHERE cms_space.ORG_ID=?"

	var total int
	err := db.QueryRow(countSpaces, o.ID()).Scan(&total)
	if err != nil {
		return 0, err
	}
	return total, nil
}