~evanj/cms

9d2dfb0b32eb388682adedc84de30562f01806be — Evan J 2 months ago 4e6befb
Feat(db): Mostly complete removing potentially deadlocking DB code. Only
db.Exec calls left to remove.
M TODO => TODO +2 -0
@@ 1,4 1,6 @@
[high]
db.SetMaxOpenConns(1): test for deadlocks in DB code (remove all db.Query and 
  db.Exec in favor of t.Query and t.Exec)
Testing: 100% happy path and 80% total
Documentation
Doc pages: Contact, FAQ, Terms, Privacy, Tour

M internal/s/db/action.go => internal/s/db/action.go +17 -1
@@ 1,6 1,7 @@
package db

import (
	"database/sql"
	"time"

	"git.sr.ht/~evanj/cms/internal/m/org"


@@ 12,6 13,21 @@ func (db *DB) ActionNew(o org.Org, at time.Time) error {
}

func (db *DB) ActionGetCount(o org.Org, from, to time.Time) (int, error) {
	t, err := db.Begin()
	if err != nil {
		return 0, err
	}
	defer t.Rollback()

	i, err := db.actionGetCount(t, o, from, to)
	if err != nil {
		return 0, err
	}

	return i, t.Commit()
}

func (db *DB) actionGetCount(t *sql.Tx, o org.Org, from, to time.Time) (int, error) {
	var (
		count int
		q     = "SELECT COUNT(*) FROM cms_action WHERE cms_action.ORG_ID=? AND AT>? AND AT<?"


@@ 19,7 35,7 @@ func (db *DB) ActionGetCount(o org.Org, from, to time.Time) (int, error) {

	a := from.Format("2006-01-02 03:04:05")
	b := to.Format("2006-01-02 03:04:05")
	if err := db.QueryRow(q, o.ID(), a, b).Scan(&count); err != nil {
	if err := t.QueryRow(q, o.ID(), a, b).Scan(&count); err != nil {
		return 0, err
	}


M internal/s/db/content.go => internal/s/db/content.go +64 -19
@@ 532,7 532,7 @@ type ContentRefSet struct {
	ContentTypeID, ContentID string
}

func (db *DB) contentRefererList(contentID string, depth int) (ret []ContentRefSet, err error) {
func (db *DB) contentRefererList(t *sql.Tx, contentID string, depth int) (ret []ContentRefSet, err error) {
	// Cap total recursion to defaultDepth
	if depth < 1 {
		return ret, nil


@@ 559,19 559,27 @@ func (db *DB) contentRefererList(contentID string, depth int) (ret []ContentRefS
		AND cms_value_reference_list_values.CONTENT_ID = ?
	`

	rows, err := db.Query(refQ, valuetype.Reference, contentID)
	rows, err := t.Query(refQ, valuetype.Reference, contentID)
	if err != nil {
		return ret, err
	}
	defer rows.Close()

	var refs []ContentRefSet
	for rows.Next() {
		var ref ContentRefSet
		if err := rows.Scan(&ref.ContentTypeID, &ref.ContentID); err != nil {
			return ret, err
		}

		nested, err := db.contentRefererList(ref.ContentID, depth)
		refs = append(refs, ref)
	}

	if err := rows.Close(); err != nil {
		return ret, err
	}

	for _, ref := range refs {
		nested, err := db.contentRefererList(t, ref.ContentID, depth)
		if err != nil {
			return nil, err
		}


@@ 580,19 588,27 @@ func (db *DB) contentRefererList(contentID string, depth int) (ret []ContentRefS
		ret = append(ret, nested...)
	}

	rows, err = db.Query(refListQ, valuetype.ReferenceList, contentID)
	rows, err = t.Query(refListQ, valuetype.ReferenceList, contentID)
	if err != nil {
		return ret, err
	}
	defer rows.Close()

	refs = refs[0:0] // Empty slice.
	for rows.Next() {
		var ref ContentRefSet
		if err := rows.Scan(&ref.ContentTypeID, &ref.ContentID); err != nil {
			return ret, err
		}

		nested, err := db.contentRefererList(ref.ContentID, depth)
		refs = append(refs, ref)
	}

	if err := rows.Close(); err != nil {
		return ret, err
	}

	for _, ref := range refs {
		nested, err := db.contentRefererList(t, ref.ContentID, depth)
		if err != nil {
			return nil, err
		}


@@ 606,7 622,18 @@ func (db *DB) contentRefererList(contentID string, depth int) (ret []ContentRefS

// ContentRefererList will retreive all content IDs that references a given piece of content.
func (db *DB) ContentRefererList(contentID string) (ret []ContentRefSet, err error) {
	return db.contentRefererList(contentID, defaultDepth)
	t, err := db.Begin()
	if err != nil {
		return ret, err
	}
	defer t.Rollback()

	list, err := db.contentRefererList(t, contentID, defaultDepth)
	if err != nil {
		return ret, err
	}

	return list, t.Commit()
}

func (db *DB) contentUpdate(t *sql.Tx, space space.Space, ct contenttype.ContentType, content content.Content, newParams []ContentNewParam, updateParams []ContentUpdateParam) error {


@@ 847,8 874,8 @@ func (db *DB) contentPerContentType(t *sql.Tx, u user.User, space space.Space, c
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var ids []string
	for i := 0; rows.Next(); i++ {
		if i == perPage {
			hasMore = true


@@ 859,7 886,15 @@ func (db *DB) contentPerContentType(t *sql.Tx, u user.User, space space.Space, c
			return nil, err
		}

		c, err := db.ContentGet(u, space, ct, tmpContentID)
		ids = append(ids, tmpContentID)
	}

	if err := rows.Close(); err != nil {
		return nil, err
	}

	for _, id := range ids {
		c, err := db.contentGet(t, space, ct, id, defaultDepth)
		if err != nil {
			return nil, err
		}


@@ 951,9 986,8 @@ func (db *DB) ContentSearch(u user.User, space space.Space, ct contenttype.Conte
	if err != nil {
		return nil, err
	}
	// Handled below.
	// defer rows.Close()

	var ids []string
	for i := 0; rows.Next(); i++ {
		if i == perPage {
			hasMore = true


@@ 964,7 998,15 @@ func (db *DB) ContentSearch(u user.User, space space.Space, ct contenttype.Conte
			return nil, err
		}

		c, err := db.ContentGet(u, space, ct, tmpContentID)
		ids = append(ids, tmpContentID)
	}

	if err := rows.Close(); err != nil {
		return nil, err
	}

	for _, tmpContentID := range ids {
		c, err := db.contentGet(t, space, ct, tmpContentID, defaultDepth)
		if err != nil {
			return nil, err
		}


@@ 972,10 1014,6 @@ func (db *DB) ContentSearch(u user.User, space space.Space, ct contenttype.Conte
		r = append(r, c)
	}

	if err := rows.Close(); err != nil {
		return nil, err
	}

	return newContentList(r, hasMore, tmpID), t.Commit()
}



@@ 987,18 1025,25 @@ func (db *DB) contentGet(t *sql.Tx, space space.Space, ct contenttype.ContentTyp

	// TODO: For some reason t.Query(...) is causing errors here.
	// See: https://github.com/go-sql-driver/mysql/issues/314
	rows, err := db.Query(queryValueListByContent, content.ID(), content.ID(), content.ID(), content.ID(), content.ID())
	rows, err := t.Query(queryValueListByContent, content.ID(), content.ID(), content.ID(), content.ID(), content.ID())
	if err != nil {
		return nil, fmt.Errorf("failed to find value(s)")
	}
	defer rows.Close()

	var values []ContentValue
	for rows.Next() {
		var value ContentValue
		if err := rows.Scan(&value.FieldID, &value.FieldType, &value.FieldName, &value.FieldValue); err != nil {
			return nil, fmt.Errorf("failed to scan values(s)")
		}
		values = append(values, value)
	}

	if err := rows.Close(); err != nil {
		return nil, err
	}

	for _, value := range values {
		if err := db.contentValueAttachRef(t, &value, depth); err != nil {
			return nil, err
		}

M internal/s/db/contenttype.go => internal/s/db/contenttype.go +46 -3
@@ 86,6 86,7 @@ func (db *DB) ContentTypeNew(u user.User, space space.Space, name string, params
	}

	for _, item := range params {
		db.log.Println(item, id)
		if _, err := t.Exec(queryCreateContentTypeConnection, item.Name, id, item.Type); err != nil {
			return nil, fmt.Errorf("failed to create field(s)")
		}


@@ 108,6 109,10 @@ func (db *DB) ContentTypeNew(u user.User, space space.Space, name string, params
		ct.ContentTypeFields = append(ct.ContentTypeFields, field)
	}

	if err := rows.Close(); err != nil {
		return nil, err
	}

	if len(ct.ContentTypeFields) != len(params) {
		return nil, fmt.Errorf("failed to create all fields")
	}


@@ 184,11 189,12 @@ func (db *DB) contentTypesPerSpace(t *sql.Tx, space space.Space, before int) (co
		ORDER BY ID DESC LIMIT ?
	`

	rows, err := db.Query(q, space.ID(), before, perPage+1)
	rows, err := t.Query(q, space.ID(), before, perPage+1)
	if err != nil {
		return nil, err
	}

	var ids []int
	for i := 0; rows.Next(); i++ {
		if i == perPage {
			hasMore = true


@@ 199,6 205,14 @@ func (db *DB) contentTypesPerSpace(t *sql.Tx, space space.Space, before int) (co
			return nil, err
		}

		ids = append(ids, id)
	}

	if err := rows.Close(); err != nil {
		return nil, err
	}

	for _, id := range ids {
		ct, err := db.contentTypeGet(t, space, strconv.Itoa(id))
		if err != nil {
			return nil, err


@@ 235,6 249,7 @@ func (db *DB) contentTypeGet(t *sql.Tx, space space.Space, contenttypeID string)
	if err != nil {
		return nil, fmt.Errorf("failed to find field(s)")
	}

	for rows.Next() {
		var field ContentTypeField
		if err := rows.Scan(&field.FieldID, &field.FieldName, &field.FieldType); err != nil {


@@ 243,6 258,10 @@ func (db *DB) contentTypeGet(t *sql.Tx, space space.Space, contenttypeID string)
		ct.ContentTypeFields = append(ct.ContentTypeFields, field)
	}

	if err := rows.Close(); err != nil {
		return nil, err
	}

	return &ct, nil
}



@@ 264,6 283,21 @@ func (db *DB) ContentTypeGet(u user.User, space space.Space, contenttypeID strin
// TODO: Consolidate with other list function here. They are the same except for
// the query used.
func (db *DB) ContentTypeSearch(u user.User, space space.Space, query string, before int) (contenttype.ContentTypeList, error) {
	t, err := db.Begin()
	if err != nil {
		return nil, err
	}
	defer t.Rollback()

	list, err := db.contentTypeSearch(t, u, space, query, before)
	if err != nil {
		return nil, err
	}

	return list, t.Commit()
}

func (db *DB) contentTypeSearch(t *sql.Tx, u user.User, space space.Space, query string, before int) (contenttype.ContentTypeList, error) {
	var (
		r       []contenttype.ContentType
		id      int


@@ 275,17 309,26 @@ func (db *DB) ContentTypeSearch(u user.User, space space.Space, query string, be
	// TODO: May want to make a temp table for this query for proper ordering.
	q := `SELECT ID FROM cms_contenttype WHERE NAME LIKE ? AND SPACE_ID = ? AND ID < ? ORDER BY ID DESC LIMIT ?`

	rows, err := db.Query(q, fmt.Sprintf("%%%s%%", query), space.ID(), before, perPage)
	rows, err := t.Query(q, fmt.Sprintf("%%%s%%", query), space.ID(), before, perPage)
	if err != nil {
		return nil, err
	}

	var ids []int
	for rows.Next() {
		if err := rows.Scan(&id); err != nil {
			return nil, err
		}

		ct, err := db.ContentTypeGet(u, space, strconv.Itoa(id))
		ids = append(ids, id)
	}

	if err := rows.Close(); err != nil {
		return nil, err
	}

	for _, id := range ids {
		ct, err := db.contentTypeGet(t, space, strconv.Itoa(id))
		if err != nil {
			return nil, err
		}

M internal/s/db/db.go => internal/s/db/db.go +1 -1
@@ 175,7 175,7 @@ func (db *DB) migrate() error {

	for _, vt := range vtypes {
		var count int
		if err := db.QueryRow("SELECT COUNT(*) FROM cms_valuetype WHERE VALUE=?", count).Scan(&count); err != nil {
		if err := db.QueryRow("SELECT COUNT(*) FROM cms_valuetype WHERE VALUE=?", vt).Scan(&count); err != nil {
			return err
		}


M internal/s/db/hook.go => internal/s/db/hook.go +10 -1
@@ 111,11 111,12 @@ func (db *DB) hooksPerSpace(t *sql.Tx, space space.Space, before int) (hook.Hook
		ORDER BY ID DESC LIMIT ?
	 `

	rows, err := db.Query(q, space.ID(), before, perPage+1)
	rows, err := t.Query(q, space.ID(), before, perPage+1)
	if err != nil {
		return nil, err
	}

	var ids []int
	for i := 0; rows.Next(); i++ {
		if i == perPage {
			hasMore = true


@@ 126,6 127,14 @@ func (db *DB) hooksPerSpace(t *sql.Tx, space space.Space, before int) (hook.Hook
			return nil, err
		}

		ids = append(ids, id)
	}

	if err := rows.Close(); err != nil {
		return nil, err
	}

	for _, id := range ids {
		ct, err := db.hookGet(t, space, strconv.Itoa(id))
		if err != nil {
			return nil, err

M internal/s/db/invite.go => internal/s/db/invite.go +23 -2
@@ 164,23 164,44 @@ func (db *DB) inviteAccept(t *sql.Tx, i invite.Invite, u, p, v string) (user.Use
}

func (db *DB) InviteList(u user.User, o org.Org) (r []invite.Invite, err error) {
	t, err := db.Begin()
	defer t.Rollback()

	list, err := db.inviteList(t, u, o)
	if err != nil {
		return nil, err
	}

	return list, t.Commit()
}

func (db *DB) inviteList(t *sql.Tx, u user.User, o org.Org) (r []invite.Invite, err error) {
	var (
		now  = time.Now().UTC()
		from = now.Format(mysqlTimeLayout)
	)

	rows, err := db.Query(queryList, o.ID(), from)
	rows, err := t.Query(queryList, o.ID(), from)
	if err != nil {
		return nil, err
	}

	var ids []string
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return nil, err
		}

		i, err := db.InviteGet(u, id)
		ids = append(ids, id)
	}

	if err := rows.Close(); err != nil {
		return nil, err
	}

	for _, id := range ids {
		i, err := db.inviteGet(t, id)
		if err != nil {
			return nil, err
		}

M internal/s/db/org.go => internal/s/db/org.go +32 -2
@@ 94,12 94,27 @@ func (db *DB) OrgUpdateTier(u user.User, o org.Org, t tier.Tier, paymentCustomer
}

func (db *DB) OrgGetSpaceCount(o org.Org) (int, error) {
	t, err := db.Begin()
	if err != nil {
		return 0, err
	}
	defer t.Rollback()

	i, err := db.orgGetSpaceCount(t, o)
	if err != nil {
		return 0, err
	}

	return i, t.Commit()
}

func (db *DB) orgGetSpaceCount(t *sql.Tx, o org.Org) (int, error) {
	var (
		count int
		q     = "SELECT COUNT(*) FROM cms_space WHERE cms_space.ORG_ID=?"
	)

	if err := db.QueryRow(q, o.ID()).Scan(&count); err != nil {
	if err := t.QueryRow(q, o.ID()).Scan(&count); err != nil {
		return 0, err
	}



@@ 107,12 122,27 @@ func (db *DB) OrgGetSpaceCount(o org.Org) (int, error) {
}

func (db *DB) OrgGetUserCount(o org.Org) (int, error) {
	t, err := db.Begin()
	if err != nil {
		return 0, err
	}
	defer t.Rollback()

	i, err := db.orgGetUserCount(t, o)
	if err != nil {
		return 0, err
	}

	return i, t.Commit()
}

func (db *DB) orgGetUserCount(t *sql.Tx, o org.Org) (int, error) {
	var (
		count int
		q     = "SELECT COUNT(*) FROM cms_user WHERE cms_user.ORG_ID=?"
	)

	if err := db.QueryRow(q, o.ID()).Scan(&count); err != nil {
	if err := t.QueryRow(q, o.ID()).Scan(&count); err != nil {
		return 0, err
	}


M internal/s/db/space.go => internal/s/db/space.go +10 -1
@@ 399,11 399,12 @@ func (db *DB) spacesPerUser(t *sql.Tx, user user.User, before int) (space.SpaceL
		ORDER BY cms_space.ID DESC LIMIT ?
	`

	rows, err := db.Query(q, user.ID(), before, perPage+1)
	rows, err := t.Query(q, user.ID(), before, perPage+1)
	if err != nil {
		return nil, err
	}

	var ids []int
	for i := 0; rows.Next(); i++ {
		if i == perPage {
			hasMore = true


@@ 414,6 415,14 @@ func (db *DB) spacesPerUser(t *sql.Tx, user user.User, before int) (space.SpaceL
			return nil, err
		}

		ids = append(ids, id)
	}

	if err := rows.Close(); err != nil {
		return nil, err
	}

	for _, id := range ids {
		s, err := db.spaceGet(t, user, strconv.Itoa(id))
		if err != nil {
			return nil, err

M internal/s/db/user.go => internal/s/db/user.go +34 -4
@@ 137,6 137,21 @@ func (db *DB) userGet(t *sql.Tx, username, password string) (user.User, error) {
}

func (db *DB) UserGetFromToken(token string) (user.User, error) {
	t, err := db.Begin()
	if err != nil {
		return nil, err
	}
	defer t.Rollback()

	user, err := db.userGetFromToken(t, token)
	if err != nil {
		return nil, err
	}

	return user, t.Commit()
}

func (db *DB) userGetFromToken(t *sql.Tx, token string) (user.User, error) {
	tmap, err := db.sec.TokenFrom(token)
	if err != nil {
		return nil, fmt.Errorf("failed to decode user token")


@@ 148,7 163,7 @@ func (db *DB) UserGetFromToken(token string) (user.User, error) {
	}

	var user User
	if err := db.QueryRow(queryFindUserByID, id).Scan(
	if err := t.QueryRow(queryFindUserByID, id).Scan(
		&user.UserID, &user.UserName, &user.userHash, &user.userEmail,
		&user.UserOrg.OrgID, &user.UserOrg.OrgBillingTierName, &user.UserOrg.OrgPaymentCustomer,
		&user.UserRole.RoleID, &user.UserRole.RoleName,


@@ 176,8 191,23 @@ func (db *DB) UserSetEmail(u user.User, email string) (user.User, error) {
}

func (db *DB) UserSetPassword(u user.User, current, password, verifyPassword string) (user.User, error) {
	t, err := db.Begin()
	if err != nil {
		return nil, err
	}
	defer t.Rollback()

	user, err := db.userSetPassword(t, u, current, password, verifyPassword)
	if err != nil {
		return nil, err
	}

	return user, t.Commit()
}

func (db *DB) userSetPassword(t *sql.Tx, u user.User, current, password, verifyPassword string) (user.User, error) {
	var currentHash string
	if err := db.QueryRow("SELECT HASH FROM cms_user WHERE ID=?", u.ID()).Scan(&currentHash); err != nil {
	if err := t.QueryRow("SELECT HASH FROM cms_user WHERE ID=?", u.ID()).Scan(&currentHash); err != nil {
		return nil, err
	}



@@ 200,11 230,11 @@ func (db *DB) UserSetPassword(u user.User, current, password, verifyPassword str
		return nil, fmt.Errorf("failed to create password hash")
	}

	if _, err := db.Exec("UPDATE cms_user SET HASH=? WHERE ID=?", hash, u.ID()); err != nil {
	if _, err := t.Exec("UPDATE cms_user SET HASH=? WHERE ID=?", hash, u.ID()); err != nil {
		return nil, err
	}

	return db.UserGet(u.Name(), password)
	return db.userGet(t, u.Name(), password)
}

func (u *User) ID() string     { return u.UserID }