~evanj/cms

ref: 9d7fbcbdf7e7332ea56493092efe5cffbadccdd6 cms/internal/s/db/db.go -rw-r--r-- 4.5 KiB
9d7fbcbdEvan M Jones Fix(go.mod/vendor): Attempting to fix vendoring for deploys. 2 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
package db

import (
	"database/sql"
	"log"
	"sort"
	"strconv"
	"strings"

	"git.sr.ht/~evanj/cms/internal/m/valuetype"
	"git.sr.ht/~evanj/security"
	_ "github.com/go-sql-driver/mysql"
)

//go:generate embed -pattern */* -id migrations

// newMigrationSlice flattens the embedded migrations map (name -> SQL text)
// into [name, sql] pairs sorted by name, so that migrations are applied in a
// deterministic lexical order regardless of Go's random map iteration order.
// A nil or empty map yields an empty, non-nil slice.
func newMigrationSlice(m map[string]string) [][]string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}

	// sort.Strings uses a strict "<" ordering. The previous sort.Slice
	// comparator (strings.Compare(a, b) < 1, i.e. "a <= b") violated the
	// strict-weak-ordering contract sort.Slice requires of its less function.
	sort.Strings(keys)

	migrations := make([][]string, len(keys))
	for i, key := range keys {
		migrations[i] = []string{key, m[key]}
	}

	return migrations
}

// perPage is the default page size used in LIMIT/OFFSET pagination.
const perPage = 25

// defaultDepth is the depth used when fetching reference types.
// A maxDepth of 3 has been considered but is not used yet.
const defaultDepth = 2

// maxUint is the largest value representable by this platform's uint.
const maxUint = ^uint(0)

// maxInt is the largest value representable by this platform's int.
const maxInt = int(maxUint >> 1)

// maxBefore is maxInt rendered as a decimal string: the "before" pagination
// value to fall back on when the user supplied the zero value.
var maxBefore = strconv.Itoa(maxInt)

// zero is the int zero value, i.e. the "not specified" pagination input.
var zero int

// beformat normalizes a "before" pagination value. Any value other than the
// zero value is passed through unchanged; the zero value (nothing specified)
// is replaced with maxInt, the largest int.
func beformat(before int) int {
	if before != zero {
		return before
	}
	return maxInt
}

// DB embeds a *sql.DB (so methods such as QueryRow, Exec, Begin and Ping are
// available directly on *DB) and carries the logger and securer handed to
// New, presumably for use by data-access methods elsewhere in this package.
type DB struct {
	*sql.DB
	log      *log.Logger
	sec      securer
	// setupErr holds the error returned by sql.Open in New (if any); Setup
	// returns it before doing anything else.
	setupErr error
}

// securer is the security dependency of DB. It covers two concerns:
//   - passwords: HashCreate derives a hash from a salt and a password, and
//     HashCompare presumably checks a salt+password against a stored hash;
//   - user tokens (for use in cookies or elsewhere): TokenCreate encodes a
//     security.TokenMap into a string, and TokenFrom presumably decodes one
//     back into a TokenMap.
type securer interface {
	HashCreate(salt, pass string) (string, error)
	HashCompare(salt, pass, hash string) error

	TokenCreate(val security.TokenMap) (string, error)
	TokenFrom(tokenString string) (security.TokenMap, error)
}

// New builds a *DB for the given database/sql driver name (typ) and data
// source name (creds), keeping log and sec for later use.
//
// New does not ping the database and always returns a non-nil *DB. It only
// calls sql.Open, which may merely validate its arguments without
// connecting. Any error from sql.Open is stored and returned by Setup, which
// performs the ping and migrations and should be called before the DB is
// used.
func New(log *log.Logger, typ, creds string, sec securer) *DB {
	db, err := sql.Open(typ, creds)
	return &DB{DB: db, log: log, sec: sec, setupErr: err}
}

// Setup completes initialization of a DB returned by New. If sql.Open failed
// in New, that error is returned immediately. Otherwise Setup pings the
// database, runs migrate, and then sets the connection-pool limits (10 idle,
// 100 open). The first error encountered is returned and later steps are
// skipped.
func (db *DB) Setup() error {
	if db.setupErr != nil {
		return db.setupErr
	}

	// Order matters: the database must be reachable before migrating.
	steps := []func() error{db.Ping, db.migrate}
	for _, step := range steps {
		if err := step(); err != nil {
			return err
		}
	}

	// TODO: Best numbers?
	db.SetMaxIdleConns(10)
	db.SetMaxOpenConns(100)

	return nil
}

// migrate brings the schema up to date and then seeds cms_valuetype.
//
// Migrations from the embedded migrations map are visited in lexical order
// of their name (see newMigrationSlice). A migration is skipped when
// cms_migrate already holds a row with its name. Otherwise its SQL is split
// on ";" and every non-empty statement is executed, followed by the
// cms_migrate insert, all inside one transaction. On a fresh database the
// cms_migrate lookup fails with a "doesn't exist" error. That error is treated
// as "not applied" so an early migration can create the table.
//
// NOTE: MySQL implicitly commits most DDL statements, so a migration that
// fails part-way may still leave earlier DDL statements applied.
//
// Afterwards each value type listed below is inserted into cms_valuetype
// unless a row with that VALUE already exists.
func (db *DB) migrate() error {
	// applyMigration runs a single migration in its own transaction. It is a
	// closure so its deferred Rollback fires as soon as this migration is
	// done, instead of piling up until migrate returns.
	applyMigration := func(name, body string) error {
		tx, err := db.Begin()
		if err != nil {
			return err
		}
		// After a successful Commit this is a no-op (returns sql.ErrTxDone).
		defer tx.Rollback()

		// Naive split: assumes ";" never appears inside a string literal or
		// comment in the migration SQL.
		for _, q := range strings.Split(body, ";") {
			q = strings.TrimSpace(q)
			if q == "" {
				continue
			}
			if _, err := tx.Exec(q); err != nil {
				return err
			}
		}

		// Record the migration in the same transaction so its (non-DDL)
		// changes and the bookkeeping row are committed together.
		if _, err := tx.Exec("INSERT INTO cms_migrate (NAME) VALUES (?)", name); err != nil {
			return err
		}
		return tx.Commit()
	}

	for _, migrationSet := range newMigrationSlice(migrations) {
		key, m := migrationSet[0], migrationSet[1]

		var count int
		if err := db.QueryRow("SELECT COUNT(*) FROM cms_migrate WHERE NAME=?", key).Scan(&count); err != nil {
			// Catch first error of DB setup: cms_migrate does not exist yet.
			// count stays 0, so the migration is applied below.
			if !strings.Contains(err.Error(), "cms_migrate' doesn't exist") {
				return err
			}
		}

		if count > 0 {
			continue
		}

		if err := applyMigration(key, m); err != nil {
			return err
		}
	}

	vtypes := []valuetype.ValueTypeEnum{
		valuetype.StringSmall,
		valuetype.StringBig,
		valuetype.InputHTML,
		valuetype.InputMarkdown,
		valuetype.File,
		valuetype.Date,
		valuetype.Reference,
		valuetype.ReferenceList,
	}

	for _, vt := range vtypes {
		var count int
		// Filter on the value type itself. This previously passed count
		// (always 0), so existing rows were never detected and were
		// re-inserted on every startup.
		if err := db.QueryRow("SELECT COUNT(*) FROM cms_valuetype WHERE VALUE=?", vt).Scan(&count); err != nil {
			return err
		}

		if count > 0 {
			continue
		}

		if _, err := db.Exec(`INSERT INTO cms_valuetype (VALUE) values (?);`, vt); err != nil {
			return err
		}
	}

	return nil
}

// FileExists reports whether URL is still owned by some space and content, so
// that files belonging to deleted spaces cannot be served.
//
// It looks for a File-typed small-string value equal to URL. That value must
// belong to a content type that lives in an existing space and has at least
// one content row.
//
// On a match it returns (true, nil). Otherwise it returns false together with
// the error from Scan, which is sql.ErrNoRows when nothing matches. Callers
// must therefore treat any non-nil error as "do not serve".
func (db *DB) FileExists(URL string) (bool, error) {
	const query = `
		SELECT cms_space.ID FROM cms_value_string_small 
		JOIN cms_value ON cms_value.VALUE_ID = cms_value_string_small.ID
		JOIN cms_contenttype_to_valuetype ON cms_value.CONTENTTYPE_TO_VALUETYPE_ID = cms_contenttype_to_valuetype.ID 
		JOIN cms_valuetype ON cms_contenttype_to_valuetype.VALUETYPE_ID = cms_valuetype.ID
		JOIN cms_contenttype ON cms_contenttype_to_valuetype.CONTENTTYPE_ID = cms_contenttype.ID
		JOIN cms_content ON cms_content.CONTENTTYPE_ID = cms_contenttype.ID
		JOIN cms_space ON cms_contenttype.SPACE_ID = cms_space.ID
		WHERE cms_valuetype.VALUE = ? AND cms_value_string_small.VALUE = ?
	`

	var owningSpace string
	err := db.QueryRow(query, valuetype.File, URL).Scan(&owningSpace)
	return err == nil, err
}