M TODO => TODO +2 -0
@@ 1,4 1,6 @@
[high]
+db.SetMaxOpenConns(1): test for deadlocks in DB code (remove all db.Query and
+ db.Exec in favor of t.Query and t.Exec)
Testing: 100% happy path and 80% total
Documentation
Doc pages: Contact, FAQ, Terms, Privacy, Tour
M internal/s/db/action.go => internal/s/db/action.go +17 -1
@@ 1,6 1,7 @@
package db
import (
+ "database/sql"
"time"
"git.sr.ht/~evanj/cms/internal/m/org"
@@ 12,6 13,21 @@ func (db *DB) ActionNew(o org.Org, at time.Time) error {
}
func (db *DB) ActionGetCount(o org.Org, from, to time.Time) (int, error) {
+ t, err := db.Begin()
+ if err != nil {
+ return 0, err
+ }
+ defer t.Rollback()
+
+ i, err := db.actionGetCount(t, o, from, to)
+ if err != nil {
+ return 0, err
+ }
+
+ return i, t.Commit()
+}
+
+func (db *DB) actionGetCount(t *sql.Tx, o org.Org, from, to time.Time) (int, error) {
var (
count int
q = "SELECT COUNT(*) FROM cms_action WHERE cms_action.ORG_ID=? AND AT>? AND AT<?"
@@ 19,7 35,7 @@ func (db *DB) ActionGetCount(o org.Org, from, to time.Time) (int, error) {
a := from.Format("2006-01-02 03:04:05")
b := to.Format("2006-01-02 03:04:05")
- if err := db.QueryRow(q, o.ID(), a, b).Scan(&count); err != nil {
+ if err := t.QueryRow(q, o.ID(), a, b).Scan(&count); err != nil {
return 0, err
}
M internal/s/db/content.go => internal/s/db/content.go +64 -19
@@ 532,7 532,7 @@ type ContentRefSet struct {
ContentTypeID, ContentID string
}
-func (db *DB) contentRefererList(contentID string, depth int) (ret []ContentRefSet, err error) {
+func (db *DB) contentRefererList(t *sql.Tx, contentID string, depth int) (ret []ContentRefSet, err error) {
// Cap total recursion to defaultDepth
if depth < 1 {
return ret, nil
@@ 559,19 559,27 @@ func (db *DB) contentRefererList(contentID string, depth int) (ret []ContentRefS
AND cms_value_reference_list_values.CONTENT_ID = ?
`
- rows, err := db.Query(refQ, valuetype.Reference, contentID)
+ rows, err := t.Query(refQ, valuetype.Reference, contentID)
if err != nil {
return ret, err
}
- defer rows.Close()
+ var refs []ContentRefSet
for rows.Next() {
var ref ContentRefSet
if err := rows.Scan(&ref.ContentTypeID, &ref.ContentID); err != nil {
return ret, err
}
- nested, err := db.contentRefererList(ref.ContentID, depth)
+ refs = append(refs, ref)
+ }
+
+ if err := rows.Close(); err != nil {
+ return ret, err
+ }
+
+ for _, ref := range refs {
+ nested, err := db.contentRefererList(t, ref.ContentID, depth)
if err != nil {
return nil, err
}
@@ 580,19 588,27 @@ func (db *DB) contentRefererList(contentID string, depth int) (ret []ContentRefS
ret = append(ret, nested...)
}
- rows, err = db.Query(refListQ, valuetype.ReferenceList, contentID)
+ rows, err = t.Query(refListQ, valuetype.ReferenceList, contentID)
if err != nil {
return ret, err
}
- defer rows.Close()
+ refs = refs[0:0] // Empty slice.
for rows.Next() {
var ref ContentRefSet
if err := rows.Scan(&ref.ContentTypeID, &ref.ContentID); err != nil {
return ret, err
}
- nested, err := db.contentRefererList(ref.ContentID, depth)
+ refs = append(refs, ref)
+ }
+
+ if err := rows.Close(); err != nil {
+ return ret, err
+ }
+
+ for _, ref := range refs {
+ nested, err := db.contentRefererList(t, ref.ContentID, depth)
if err != nil {
return nil, err
}
@@ 606,7 622,18 @@ func (db *DB) contentRefererList(contentID string, depth int) (ret []ContentRefS
// ContentRefererList will retreive all content IDs that references a given piece of content.
func (db *DB) ContentRefererList(contentID string) (ret []ContentRefSet, err error) {
- return db.contentRefererList(contentID, defaultDepth)
+ t, err := db.Begin()
+ if err != nil {
+ return ret, err
+ }
+ defer t.Rollback()
+
+ list, err := db.contentRefererList(t, contentID, defaultDepth)
+ if err != nil {
+ return ret, err
+ }
+
+ return list, t.Commit()
}
func (db *DB) contentUpdate(t *sql.Tx, space space.Space, ct contenttype.ContentType, content content.Content, newParams []ContentNewParam, updateParams []ContentUpdateParam) error {
@@ 847,8 874,8 @@ func (db *DB) contentPerContentType(t *sql.Tx, u user.User, space space.Space, c
if err != nil {
return nil, err
}
- defer rows.Close()
+ var ids []string
for i := 0; rows.Next(); i++ {
if i == perPage {
hasMore = true
@@ 859,7 886,15 @@ func (db *DB) contentPerContentType(t *sql.Tx, u user.User, space space.Space, c
return nil, err
}
- c, err := db.ContentGet(u, space, ct, tmpContentID)
+ ids = append(ids, tmpContentID)
+ }
+
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+
+ for _, id := range ids {
+ c, err := db.contentGet(t, space, ct, id, defaultDepth)
if err != nil {
return nil, err
}
@@ 951,9 986,8 @@ func (db *DB) ContentSearch(u user.User, space space.Space, ct contenttype.Conte
if err != nil {
return nil, err
}
- // Handled below.
- // defer rows.Close()
+ var ids []string
for i := 0; rows.Next(); i++ {
if i == perPage {
hasMore = true
@@ 964,7 998,15 @@ func (db *DB) ContentSearch(u user.User, space space.Space, ct contenttype.Conte
return nil, err
}
- c, err := db.ContentGet(u, space, ct, tmpContentID)
+ ids = append(ids, tmpContentID)
+ }
+
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+
+ for _, tmpContentID := range ids {
+ c, err := db.contentGet(t, space, ct, tmpContentID, defaultDepth)
if err != nil {
return nil, err
}
@@ 972,10 1014,6 @@ func (db *DB) ContentSearch(u user.User, space space.Space, ct contenttype.Conte
r = append(r, c)
}
- if err := rows.Close(); err != nil {
- return nil, err
- }
-
return newContentList(r, hasMore, tmpID), t.Commit()
}
@@ 987,18 1025,25 @@ func (db *DB) contentGet(t *sql.Tx, space space.Space, ct contenttype.ContentTyp
// TODO: For some reason t.Query(...) is causing errors here.
// See: https://github.com/go-sql-driver/mysql/issues/314
- rows, err := db.Query(queryValueListByContent, content.ID(), content.ID(), content.ID(), content.ID(), content.ID())
+ rows, err := t.Query(queryValueListByContent, content.ID(), content.ID(), content.ID(), content.ID(), content.ID())
if err != nil {
return nil, fmt.Errorf("failed to find value(s)")
}
- defer rows.Close()
+ var values []ContentValue
for rows.Next() {
var value ContentValue
if err := rows.Scan(&value.FieldID, &value.FieldType, &value.FieldName, &value.FieldValue); err != nil {
return nil, fmt.Errorf("failed to scan values(s)")
}
+ values = append(values, value)
+ }
+
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ for _, value := range values {
if err := db.contentValueAttachRef(t, &value, depth); err != nil {
return nil, err
}
M internal/s/db/contenttype.go => internal/s/db/contenttype.go +46 -3
@@ 86,6 86,7 @@ func (db *DB) ContentTypeNew(u user.User, space space.Space, name string, params
}
for _, item := range params {
+ db.log.Println(item, id)
if _, err := t.Exec(queryCreateContentTypeConnection, item.Name, id, item.Type); err != nil {
return nil, fmt.Errorf("failed to create field(s)")
}
@@ 108,6 109,10 @@ func (db *DB) ContentTypeNew(u user.User, space space.Space, name string, params
ct.ContentTypeFields = append(ct.ContentTypeFields, field)
}
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+
if len(ct.ContentTypeFields) != len(params) {
return nil, fmt.Errorf("failed to create all fields")
}
@@ 184,11 189,12 @@ func (db *DB) contentTypesPerSpace(t *sql.Tx, space space.Space, before int) (co
ORDER BY ID DESC LIMIT ?
`
- rows, err := db.Query(q, space.ID(), before, perPage+1)
+ rows, err := t.Query(q, space.ID(), before, perPage+1)
if err != nil {
return nil, err
}
+ var ids []int
for i := 0; rows.Next(); i++ {
if i == perPage {
hasMore = true
@@ 199,6 205,14 @@ func (db *DB) contentTypesPerSpace(t *sql.Tx, space space.Space, before int) (co
return nil, err
}
+ ids = append(ids, id)
+ }
+
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+
+ for _, id := range ids {
ct, err := db.contentTypeGet(t, space, strconv.Itoa(id))
if err != nil {
return nil, err
@@ 235,6 249,7 @@ func (db *DB) contentTypeGet(t *sql.Tx, space space.Space, contenttypeID string)
if err != nil {
return nil, fmt.Errorf("failed to find field(s)")
}
+
for rows.Next() {
var field ContentTypeField
if err := rows.Scan(&field.FieldID, &field.FieldName, &field.FieldType); err != nil {
@@ 243,6 258,10 @@ func (db *DB) contentTypeGet(t *sql.Tx, space space.Space, contenttypeID string)
ct.ContentTypeFields = append(ct.ContentTypeFields, field)
}
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+
return &ct, nil
}
@@ 264,6 283,21 @@ func (db *DB) ContentTypeGet(u user.User, space space.Space, contenttypeID strin
// TODO: Consolidate with other list function here. They are the same except for
// the query used.
func (db *DB) ContentTypeSearch(u user.User, space space.Space, query string, before int) (contenttype.ContentTypeList, error) {
+ t, err := db.Begin()
+ if err != nil {
+ return nil, err
+ }
+ defer t.Rollback()
+
+ list, err := db.contentTypeSearch(t, u, space, query, before)
+ if err != nil {
+ return nil, err
+ }
+
+ return list, t.Commit()
+}
+
+func (db *DB) contentTypeSearch(t *sql.Tx, u user.User, space space.Space, query string, before int) (contenttype.ContentTypeList, error) {
var (
r []contenttype.ContentType
id int
@@ 275,17 309,26 @@ func (db *DB) ContentTypeSearch(u user.User, space space.Space, query string, be
// TODO: May want to make a temp table for this query for proper ordering.
q := `SELECT ID FROM cms_contenttype WHERE NAME LIKE ? AND SPACE_ID = ? AND ID < ? ORDER BY ID DESC LIMIT ?`
- rows, err := db.Query(q, fmt.Sprintf("%%%s%%", query), space.ID(), before, perPage)
+ rows, err := t.Query(q, fmt.Sprintf("%%%s%%", query), space.ID(), before, perPage)
if err != nil {
return nil, err
}
+ var ids []int
for rows.Next() {
if err := rows.Scan(&id); err != nil {
return nil, err
}
- ct, err := db.ContentTypeGet(u, space, strconv.Itoa(id))
+ ids = append(ids, id)
+ }
+
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+
+ for _, id := range ids {
+ ct, err := db.contentTypeGet(t, space, strconv.Itoa(id))
if err != nil {
return nil, err
}
M internal/s/db/db.go => internal/s/db/db.go +1 -1
@@ 175,7 175,7 @@ func (db *DB) migrate() error {
for _, vt := range vtypes {
var count int
- if err := db.QueryRow("SELECT COUNT(*) FROM cms_valuetype WHERE VALUE=?", count).Scan(&count); err != nil {
+ if err := db.QueryRow("SELECT COUNT(*) FROM cms_valuetype WHERE VALUE=?", vt).Scan(&count); err != nil {
return err
}
M internal/s/db/hook.go => internal/s/db/hook.go +10 -1
@@ 111,11 111,12 @@ func (db *DB) hooksPerSpace(t *sql.Tx, space space.Space, before int) (hook.Hook
ORDER BY ID DESC LIMIT ?
`
- rows, err := db.Query(q, space.ID(), before, perPage+1)
+ rows, err := t.Query(q, space.ID(), before, perPage+1)
if err != nil {
return nil, err
}
+ var ids []int
for i := 0; rows.Next(); i++ {
if i == perPage {
hasMore = true
@@ 126,6 127,14 @@ func (db *DB) hooksPerSpace(t *sql.Tx, space space.Space, before int) (hook.Hook
return nil, err
}
+ ids = append(ids, id)
+ }
+
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+
+ for _, id := range ids {
ct, err := db.hookGet(t, space, strconv.Itoa(id))
if err != nil {
return nil, err
M internal/s/db/invite.go => internal/s/db/invite.go +23 -2
@@ 164,23 164,44 @@ func (db *DB) inviteAccept(t *sql.Tx, i invite.Invite, u, p, v string) (user.Use
}
func (db *DB) InviteList(u user.User, o org.Org) (r []invite.Invite, err error) {
+ t, err := db.Begin()
+ defer t.Rollback()
+
+ list, err := db.inviteList(t, u, o)
+ if err != nil {
+ return nil, err
+ }
+
+ return list, t.Commit()
+}
+
+func (db *DB) inviteList(t *sql.Tx, u user.User, o org.Org) (r []invite.Invite, err error) {
var (
now = time.Now().UTC()
from = now.Format(mysqlTimeLayout)
)
- rows, err := db.Query(queryList, o.ID(), from)
+ rows, err := t.Query(queryList, o.ID(), from)
if err != nil {
return nil, err
}
+ var ids []string
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
- i, err := db.InviteGet(u, id)
+ ids = append(ids, id)
+ }
+
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+
+ for _, id := range ids {
+ i, err := db.inviteGet(t, id)
if err != nil {
return nil, err
}
M internal/s/db/org.go => internal/s/db/org.go +32 -2
@@ 94,12 94,27 @@ func (db *DB) OrgUpdateTier(u user.User, o org.Org, t tier.Tier, paymentCustomer
}
func (db *DB) OrgGetSpaceCount(o org.Org) (int, error) {
+ t, err := db.Begin()
+ if err != nil {
+ return 0, err
+ }
+ defer t.Rollback()
+
+ i, err := db.orgGetSpaceCount(t, o)
+ if err != nil {
+ return 0, err
+ }
+
+ return i, t.Commit()
+}
+
+func (db *DB) orgGetSpaceCount(t *sql.Tx, o org.Org) (int, error) {
var (
count int
q = "SELECT COUNT(*) FROM cms_space WHERE cms_space.ORG_ID=?"
)
- if err := db.QueryRow(q, o.ID()).Scan(&count); err != nil {
+ if err := t.QueryRow(q, o.ID()).Scan(&count); err != nil {
return 0, err
}
@@ 107,12 122,27 @@ func (db *DB) OrgGetSpaceCount(o org.Org) (int, error) {
}
func (db *DB) OrgGetUserCount(o org.Org) (int, error) {
+ t, err := db.Begin()
+ if err != nil {
+ return 0, err
+ }
+ defer t.Rollback()
+
+ i, err := db.orgGetUserCount(t, o)
+ if err != nil {
+ return 0, err
+ }
+
+ return i, t.Commit()
+}
+
+func (db *DB) orgGetUserCount(t *sql.Tx, o org.Org) (int, error) {
var (
count int
q = "SELECT COUNT(*) FROM cms_user WHERE cms_user.ORG_ID=?"
)
- if err := db.QueryRow(q, o.ID()).Scan(&count); err != nil {
+ if err := t.QueryRow(q, o.ID()).Scan(&count); err != nil {
return 0, err
}
M internal/s/db/space.go => internal/s/db/space.go +10 -1
@@ 399,11 399,12 @@ func (db *DB) spacesPerUser(t *sql.Tx, user user.User, before int) (space.SpaceL
ORDER BY cms_space.ID DESC LIMIT ?
`
- rows, err := db.Query(q, user.ID(), before, perPage+1)
+ rows, err := t.Query(q, user.ID(), before, perPage+1)
if err != nil {
return nil, err
}
+ var ids []int
for i := 0; rows.Next(); i++ {
if i == perPage {
hasMore = true
@@ 414,6 415,14 @@ func (db *DB) spacesPerUser(t *sql.Tx, user user.User, before int) (space.SpaceL
return nil, err
}
+ ids = append(ids, id)
+ }
+
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+
+ for _, id := range ids {
s, err := db.spaceGet(t, user, strconv.Itoa(id))
if err != nil {
return nil, err
M internal/s/db/user.go => internal/s/db/user.go +34 -4
@@ 137,6 137,21 @@ func (db *DB) userGet(t *sql.Tx, username, password string) (user.User, error) {
}
func (db *DB) UserGetFromToken(token string) (user.User, error) {
+ t, err := db.Begin()
+ if err != nil {
+ return nil, err
+ }
+ defer t.Rollback()
+
+ user, err := db.userGetFromToken(t, token)
+ if err != nil {
+ return nil, err
+ }
+
+ return user, t.Commit()
+}
+
+func (db *DB) userGetFromToken(t *sql.Tx, token string) (user.User, error) {
tmap, err := db.sec.TokenFrom(token)
if err != nil {
return nil, fmt.Errorf("failed to decode user token")
@@ 148,7 163,7 @@ func (db *DB) UserGetFromToken(token string) (user.User, error) {
}
var user User
- if err := db.QueryRow(queryFindUserByID, id).Scan(
+ if err := t.QueryRow(queryFindUserByID, id).Scan(
&user.UserID, &user.UserName, &user.userHash, &user.userEmail,
&user.UserOrg.OrgID, &user.UserOrg.OrgBillingTierName, &user.UserOrg.OrgPaymentCustomer,
&user.UserRole.RoleID, &user.UserRole.RoleName,
@@ 176,8 191,23 @@ func (db *DB) UserSetEmail(u user.User, email string) (user.User, error) {
}
func (db *DB) UserSetPassword(u user.User, current, password, verifyPassword string) (user.User, error) {
+ t, err := db.Begin()
+ if err != nil {
+ return nil, err
+ }
+ defer t.Rollback()
+
+ user, err := db.userSetPassword(t, u, current, password, verifyPassword)
+ if err != nil {
+ return nil, err
+ }
+
+ return user, t.Commit()
+}
+
+func (db *DB) userSetPassword(t *sql.Tx, u user.User, current, password, verifyPassword string) (user.User, error) {
var currentHash string
- if err := db.QueryRow("SELECT HASH FROM cms_user WHERE ID=?", u.ID()).Scan(¤tHash); err != nil {
+ if err := t.QueryRow("SELECT HASH FROM cms_user WHERE ID=?", u.ID()).Scan(¤tHash); err != nil {
return nil, err
}
@@ 200,11 230,11 @@ func (db *DB) UserSetPassword(u user.User, current, password, verifyPassword str
return nil, fmt.Errorf("failed to create password hash")
}
- if _, err := db.Exec("UPDATE cms_user SET HASH=? WHERE ID=?", hash, u.ID()); err != nil {
+ if _, err := t.Exec("UPDATE cms_user SET HASH=? WHERE ID=?", hash, u.ID()); err != nil {
return nil, err
}
- return db.UserGet(u.Name(), password)
+ return db.userGet(t, u.Name(), password)
}
func (u *User) ID() string { return u.UserID }