~emersion/drmdb

00da602b77a521373b57bb76e2fdc592e70d3605 — Simon Ser 5 months ago 8b711b9 sqlite
Migrate to SQLite
5 files changed, 103 insertions(+), 120 deletions(-)

M database/db.go
R database/{fs.go => key.go}
M go.mod
M go.sum
M snapshot.go
M database/db.go => database/db.go +98 -44
@@ 1,79 1,133 @@
package database

import (
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"

	"github.com/mattn/go-sqlite3"

	"git.sr.ht/~emersion/drmdb/drmtree"
)

const cacheSize = 1000
var Filename = "main.db"

type cacheEntry struct {
	key  string
	node *drmtree.Node
}
var ErrStop = fmt.Errorf("drmdb: stop walking")

const schema = `
CREATE TABLE Snapshot (
	key VARCHAR PRIMARY KEY,
	raw TEXT NOT NULL
);
`

type DB struct {
	cache   []cacheEntry   // ring buffer
	indexes map[string]int // key → cache index
	cur     int
	db *sql.DB
}

func Open() (*DB, error) {
	if err := initDB(); err != nil {
	sqliteDB, err := sql.Open("sqlite3", Filename+"?cache=shared")
	if err != nil {
		return nil, err
	}
	return &DB{
		cache:   make([]cacheEntry, cacheSize),
		indexes: make(map[string]int, cacheSize),
	}, nil
	sqliteDB.SetMaxOpenConns(1)
	db := &DB{sqliteDB}
	return db, db.upgrade()
}

func (db *DB) storeCache(k string, n *drmtree.Node) {
	prev := db.cache[db.cur]
	delete(db.indexes, prev.key)
	db.cache[db.cur] = cacheEntry{key: k, node: n}
	db.indexes[k] = db.cur
	db.cur = (db.cur + 1) % len(db.cache)
}
func (db *DB) upgrade() error {
	var version int
	if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
		return fmt.Errorf("failed to query schema version: %v", err)
	}

	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer tx.Rollback()

func (db *DB) loadCache(k string) *drmtree.Node {
	i, ok := db.indexes[k]
	if !ok {
		return nil
	if version == 0 {
		if _, err := tx.Exec(schema); err != nil {
			return fmt.Errorf("failed to initialize schema: %v", err)
		}
	}

	// For some reason prepared statements don't work here
	_, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", 1))
	if err != nil {
		return fmt.Errorf("failed to bump schema version: %v", err)
	}
	return db.cache[i].node

	return tx.Commit()
}

func (db *DB) Store(n *drmtree.Node) (string, error) {
	k, err := store(n)
	if err == nil {
		db.storeCache(k, n)
	k, err := generateKey(n)
	if err != nil {
		return "", err
	}

	raw, err := json.Marshal(n)
	if err != nil {
		return "", err
	}

	var sqliteErr sqlite3.Error
	_, err = db.db.Exec(`INSERT INTO Snapshot(key, raw) VALUES (?, ?)`, k, raw)
	if errors.As(err, &sqliteErr) && sqliteErr.Code == sqlite3.ErrConstraint {
		return "", fmt.Errorf("snapshot has already been submitted")
	} else if err != nil {
		return "", err
	}
	return k, err

	return k, nil
}

func (db *DB) Load(k string) (*drmtree.Node, error) {
	if n := db.loadCache(k); n != nil {
		return n, nil
	}
	n, err := load(k)
	if err == nil {
		db.storeCache(k, n)
	var raw []byte
	err := db.db.QueryRow(`SELECT raw FROM Snapshot WHERE key = ?`, k).Scan(&raw)
	if err != nil {
		return nil, err
	}
	return n, err

	var node *drmtree.Node
	err = json.Unmarshal(raw, &node)
	return node, err
}

func (db *DB) Walk(fn func(k string, n *drmtree.Node) error) error {
	return walk(func(k string) error {
		n, err := db.Load(k)
		if err != nil {
	rows, err := db.db.Query(`SELECT key, raw FROM Snapshot`)
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var (
			k   string
			raw []byte
		)
		if err := rows.Scan(&k, &raw); err != nil {
			return err
		}

		var node *drmtree.Node
		if err := json.Unmarshal(raw, &node); err != nil {
			return err
		}

		if err := fn(k, node); err == ErrStop {
			break
		} else if err != nil {
			return err
		}
		return fn(k, n)
	})
	}

	return nil
}

func (db *DB) Close() error {
	db.cache = nil
	db.indexes = nil
	return nil
	return db.db.Close()
}

R database/fs.go => database/key.go +0 -74
@@ 4,20 4,11 @@ import (
	"bytes"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"strings"

	"git.sr.ht/~emersion/drmdb/drmtree"
)

var Dir = "db"

var ErrStop = fmt.Errorf("drmdb: stop walking")

func generateKey(n *drmtree.Node) (string, error) {
	if n.Driver == nil || n.Device == nil {
		return "", fmt.Errorf("node is missing driver/device")


@@ 64,68 55,3 @@ func generateKey(n *drmtree.Node) (string, error) {
	sum := sha256.Sum256(b.Bytes())
	return hex.EncodeToString(sum[:])[:12], nil
}

func initDB() error {
	return os.MkdirAll(Dir, 0755)
}

func store(n *drmtree.Node) (string, error) {
	k, err := generateKey(n)
	if err != nil {
		return "", err
	}

	p := filepath.Join(Dir, k+".json")
	f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666)
	if err != nil {
		if os.IsExist(err) {
			return "", fmt.Errorf("data has already been submitted")
		}
		return "", err
	}
	defer f.Close()

	if err := json.NewEncoder(f).Encode(n); err != nil {
		return "", err
	}

	return k, f.Close()
}

func load(k string) (*drmtree.Node, error) {
	p := filepath.Join(Dir, k+".json")
	f, err := os.Open(p)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var n drmtree.Node
	if err := json.NewDecoder(f).Decode(&n); err != nil {
		return nil, err
	}

	return &n, f.Close()
}

func walk(fn func(k string) error) error {
	files, err := ioutil.ReadDir(Dir)
	if err != nil {
		return err
	}

	for _, fi := range files {
		if !strings.HasSuffix(fi.Name(), ".json") {
			continue
		}
		k := strings.TrimSuffix(fi.Name(), ".json")

		if err := fn(k); err == ErrStop {
			return nil
		} else if err != nil {
			return err
		}
	}

	return nil
}

M go.mod => go.mod +1 -0
@@ 7,6 7,7 @@ require (
	git.sr.ht/~emersion/go-hwids v0.0.0-20190518090256-f59e5efa82bd
	github.com/labstack/echo/v4 v4.11.1
	github.com/labstack/gommon v0.4.0
	github.com/mattn/go-sqlite3 v1.14.17 // indirect
	github.com/mcuadros/go-version v0.0.0-20190830083331-035f6764e8d2
	github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
	golang.org/x/net v0.14.0 // indirect

M go.sum => go.sum +2 -0
@@ 29,6 29,8 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mcuadros/go-version v0.0.0-20190830083331-035f6764e8d2 h1:YocNLcTBdEdvY3iDK6jfWXvEaM5OCKkjxPKoJRdB3Gg=
github.com/mcuadros/go-version v0.0.0-20190830083331-035f6764e8d2/go.mod h1:76rfSfYPWj01Z85hUf/ituArm797mNKcvINh1OlsZKo=
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=

M snapshot.go => snapshot.go +2 -2
@@ 67,13 67,13 @@ func writeSnapshot(w io.Writer) error {
		return err
	}

	files, err := ioutil.ReadDir(database.Dir)
	files, err := ioutil.ReadDir(database.Filename) // TODO
	if err != nil {
		return err
	}

	for _, fi := range files {
		p := filepath.Join(database.Dir, fi.Name())
		p := filepath.Join(database.Filename, fi.Name())
		if err := writeSnapshotFile(tw, p, fi); err != nil {
			return err
		}