// sql - A tool for querying databases.
// Copyright (C) 2020 Noel Cower
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
// Command sql is a tool for querying relational databases, such as MySQL and
// Postgres.
package main
import (
"bufio"
"context"
"database/sql"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/adrg/xdg"
"github.com/google/shlex"
"go.spiff.io/go-ini"
"go.spiff.io/sql/driver"
"go.spiff.io/sql/internal/cli"
"go.spiff.io/sql/vdb"
)
// recorderFunc adapts an ordinary function into a value with an Add method,
// which is how loadConfigFile hands its key/value callback to ini.Reader.Read.
type recorderFunc func(key, value string) error

// Add forwards a single key/value pair to the wrapped function and returns
// whatever error it reports.
func (record recorderFunc) Add(key, value string) error { return record(key, value) }
// loadConfigFile reads the INI file at path and applies its settings to the
// flags registered on fs.
//
// Only keys of the form "sql.NAME" (NAME longer than one character, matched
// after lower-casing) are used; any other key is logged and skipped. The key
// "sql.dsn" stores its value in *dsn, but only while *dsn is still empty. Any
// other NAME is applied with fs.Set unless that flag was already marked as set
// on fs when this function was called. Flag names fs does not know are logged
// and skipped.
//
// An error opening the file is wrapped with %w, so callers can check it with
// errors.Is (main skips files matching os.ErrNotExist). A failure inside the
// key callback is returned as-is; other parse errors are wrapped with the path.
func loadConfigFile(path string, dsn *string, fs *flag.FlagSet) error {
	file, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("failed to open config file %q: %w", path, err)
	}
	defer file.Close()

	// Snapshot the flags that are already set; those keep their current value.
	alreadySet := make(map[string]bool)
	fs.Visit(func(fl *flag.Flag) {
		alreadySet[fl.Name] = true
	})

	apply := recorderFunc(func(key, value string) error {
		name := strings.TrimPrefix(key, "sql.")
		switch {
		case name == key || len(name) <= 1:
			log.Printf("%s: Unrecognized config key %q", path, key)
		case name == "dsn":
			if *dsn == "" {
				*dsn = value
			}
		case alreadySet[name]:
			// Set before this file was read: leave it alone.
		case fs.Lookup(name) == nil:
			log.Printf("%s: Unrecognized config value for key %q", path, key)
		default:
			if setErr := fs.Set(name, value); setErr != nil {
				return fmt.Errorf("%s: error parsing config value for key %q: %w", path, key, setErr)
			}
		}
		return nil
	})

	reader := ini.Reader{
		Separator: ".",
		Casing:    ini.LowerCase,
		True:      "true",
	}
	err = reader.Read(bufio.NewReader(file), apply)
	if err == nil {
		return nil
	}
	// Errors produced by apply already carry the path; unwrap them as-is.
	var recErr *ini.RecordingError
	if errors.As(err, &recErr) {
		return recErr.Err
	}
	return fmt.Errorf("error parsing config file %q: %w", path, err)
}
// main parses flags and config files, opens the database named by the DSN
// argument, runs each query (optionally inside one transaction), and writes
// every result to standard output as JSON.
//
// Positional arguments are a DSN followed by query groups delimited by
// -s/--sep. In each group the first argument is the query text and the rest
// are argument sets delimited by -d/--arg-sep (see parseQueryArgs).
func main() {
	var configFilePaths []string
	noConfig := false
	argv := os.Args[1:]
	dsnArg := ""
	sep := ","
	argSep := ",,"
	inTransaction := false
	printArray := false
	isolationLevel := sql.LevelDefault
	opts := vdb.QueryOptions{
		TimeFormat: vdb.TimeString,
	}

	fs := flag.NewFlagSet("sql", flag.ContinueOnError)
	fs.Usage = func() {
		fmt.Fprint(fs.Output(),
			`Usage: sql [OPTIONS] <DSN> {QUERY [ARGS...]...}

Options:
  -f, --config FILE
      Load INI configuration for subsequent flags from a config file.
      May be given more than once. Defaults to searching for .sqlrc in the
      current directory and its parents, $XDG_CONFIG_HOME/sql/sqlrc, and
      $HOME/.sqlrc.
  -F, --no-config
      Disable config file loading, including for any config files passed
      via -f, --config.
  -s, --sep SEP (default: ",")
      Delimit queries with the given separator.
      Must not be an empty string or the same as -d, --arg-sep.
  -d, --arg-sep SEP (default: ",,")
      Delimit groups of arguments to a query with the given separator.
      Must not be an empty string or the same as -s, --sep.
  -c, --compact
      Print compacted JSON output, with each object or array separated
      only by a newline.
  -p, --pretty-print (default)
      Print formatted and indented JSON output.
  -a, --array
      Write output rows as JSON arrays, with one array per query
      execution.
  -A, --no-array (default)
      Write output rows as a stream of JSON objects.
  -j, --json
      Attempt to parse all fields as JSON. By default, only column types
      recognized as JSON are parsed as JSON.
  -J, --no-json
      Parse no fields as JSON. If set, columns with known JSON types will
      not be parsed as JSON.
  -t, --time-format FORMAT
      Set the time format to one of the following time formats:
      * ts, rfc3339, str -- RFC 3339 formatted timestamp.
      * unix, s, sec -- A Unix integer timestamp in seconds.
      * unixms, ms, msec -- A Unix integer timestamp in milliseconds.
      * unixus, us, usec -- A Unix integer timestamp in microseconds.
      * unixns, ns, nsec -- A Unix integer timestamp in nanoseconds.
      * unixf, fs, fsec, float -- A Unix timestamp in floating point,
        with sub-seconds represented as a fraction.
      * format:LAYOUT, +LAYOUT -- Render the timestamp using the given
        LAYOUT. Must be a valid Go time layout.
  -x, --transaction
      Run all queries inside of a transaction.
  -X, --no-transaction (default)
      Do not run queries inside of a transaction.
  -i, --isolation-level LEVEL (default: DB preference)
      Set the transaction isolation level for all queries.
  -u, --username USERNAME
  -P, --password PASSWORD
      Set or override the username or password on the DB URL.
  -h, --help
      Print this usage text.
`)
	}

	// -s, --sep SEP
	fs.Var(cli.String(&sep).WithFile(), "sep",
		"set the query `separator`")
	fs.Var(cli.Forward(fs, "sep"), "s",
		"set the query `separator`")
	// -d, --arg-sep SEP
	fs.Var(cli.String(&argSep).WithFile(), "arg-sep",
		"set the argument `separator` (to run the same query with different args)")
	fs.Var(cli.Forward(fs, "arg-sep"), "d",
		"set the argument `separator` (to run the same query with different args)")
	// -c, --compact
	fs.Var(cli.Bool(&opts.Compact).WithFile(), "compact",
		"print compact JSON")
	fs.Var(cli.Forward(fs, "compact"), "c",
		"print compact JSON")
	// -p, --pretty-print
	fs.Var(cli.NegBool(fs, "compact").WithFile(), "pretty-print",
		"pretty-print JSON")
	fs.Var(cli.Forward(fs, "pretty-print"), "p",
		"pretty-print JSON")
	// -a, --array
	fs.Var(cli.Bool(&printArray).WithFile(), "array",
		"print output as an array of objects")
	fs.Var(cli.Forward(fs, "array"), "a",
		"print output as an array of objects")
	// -A, --no-array
	fs.Var(cli.NegBool(fs, "array").WithFile(), "no-array",
		"print output as a JSON object stream")
	fs.Var(cli.Forward(fs, "no-array"), "A",
		"print output as a JSON object stream")
	// -j, --json
	fs.Var(cli.Bool(&opts.TryJSON).WithFile(), "json",
		"attempt to parse columns as JSON")
	fs.Var(cli.Forward(fs, "json"), "j",
		"attempt to parse columns as JSON")
	// -J, --no-json
	fs.Var(cli.Bool(&opts.SkipJSON).WithFile(), "no-json",
		"do not attempt to parse any column as JSON")
	fs.Var(cli.Forward(fs, "no-json"), "J",
		"do not attempt to parse any column as JSON")
	// -x, --transaction / -X, --no-transaction
	fs.Var(cli.Bool(&inTransaction).WithFile(), "transaction",
		"run queries inside of a transaction")
	fs.Var(cli.Forward(fs, "transaction"), "x",
		"run queries inside of a transaction")
	fs.Var(cli.NegBool(fs, "transaction").WithFile(), "X",
		"do not run queries inside of a transaction")
	fs.Var(cli.NegBool(fs, "transaction").WithFile(), "no-transaction",
		"do not run queries inside of a transaction")

	// -i, --isolation-level LEVEL
	setIsolationLevel := func(s string) error {
		orig := s
		s = strings.ToLower(s)
		s = strings.ReplaceAll(s, "_", "-")
		switch s {
		case "default", "d", "def":
			isolationLevel = sql.LevelDefault
		case "read-uncommitted", "readuncommitted", "ru", "dirty", "dirty-read", "dirty-reads":
			isolationLevel = sql.LevelReadUncommitted
		case "read-committed", "readcommitted", "rc":
			isolationLevel = sql.LevelReadCommitted
		case "write-committed", "writecommitted", "wc":
			isolationLevel = sql.LevelWriteCommitted
		case "repeatable-read", "repeatable", "repeatableread", "rr":
			isolationLevel = sql.LevelRepeatableRead
		case "snapshot", "snap":
			isolationLevel = sql.LevelSnapshot
		case "serializable", "s", "sz", "max":
			isolationLevel = sql.LevelSerializable
		case "linearizable", "l", "lin", "lz":
			isolationLevel = sql.LevelLinearizable
		default:
			return fmt.Errorf("invalid isolation level %q", orig)
		}
		return nil
	}
	fs.Var(cli.NewFlagFunc("default", false,
		setIsolationLevel).WithFile(), "isolation-level", "set the `isolation` level of the transaction (-x, -transaction only)")
	fs.Var(cli.Forward(fs, "isolation-level"), "i", "set the `isolation` level of the transaction (-x, -transaction only)")

	// -f, --config FILE (repeatable)
	setConfigFile := func(s string) error {
		configFilePaths = append(configFilePaths, s)
		return nil
	}
	fs.Var(cli.NewFlagFunc("", false, setConfigFile), "config", "load configuration from a config `file`")
	fs.Var(cli.Forward(fs, "config"), "f", "load configuration from a config `file`")
	// -F, --no-config
	fs.BoolVar(&noConfig, "no-config", noConfig, "disable config file loading")
	fs.Var(cli.Forward(fs, "no-config"), "F", "disable config file loading")

	var (
		username    string
		usernameOpt = cli.Optional(cli.String(&username).WithFile())
		password    string
		passwordOpt = cli.Optional(cli.String(&password).WithFile())
	)
	// -u, --username USERNAME
	fs.Var(usernameOpt, "u", "username")
	fs.Var(usernameOpt, "username", "username")
	// -P, --password PASSWORD
	fs.Var(passwordOpt, "P", "password")
	fs.Var(passwordOpt, "password", "password")

	// -t, --time-format FORMAT
	timeFlag := cli.File(cli.NewFlagFunc("rfc3339", false, func(s string) error {
		err := opts.TimeFormat.UnmarshalText([]byte(s))
		if err == nil {
			return nil
		}
		if opts.TimeFormat == vdb.TimeCustom {
			// Ignore TimeCustom because it's for programmatic use
			// (i.e., below).
			return fmt.Errorf("invalid time format: %q", s)
		}
		if strings.HasPrefix(s, "+") {
			opts.TimeFormat, opts.TimeLayout = vdb.TimeCustom, s[1:]
			return nil
		} else if strings.HasPrefix(s, "format:") {
			opts.TimeFormat, opts.TimeLayout = vdb.TimeCustom, s[7:]
			return nil
		}
		return fmt.Errorf("invalid time format: %w", err)
	}))
	fs.Var(timeFlag, "t", "set the `format` of parsed times")
	fs.Var(timeFlag, "time-format", "set the `format` of parsed times")

	if err := fs.Parse(argv); errors.Is(err, flag.ErrHelp) {
		os.Exit(2)
	} else if err != nil {
		log.Fatalf("Error parsing arguments: %v", err)
	}

	switch {
	case noConfig:
		// -F/--no-config disables every config file, including -f/--config.
		configFilePaths = nil
	case len(configFilePaths) == 0:
		// No -f/--config given: search for config files, nearest first.
		xdgConfig := ""
		if p, err := xdg.SearchConfigFile("sql/sqlrc"); err == nil && p != "" {
			xdgConfig = p
		}
		homeConfig := ""
		if home, ok := os.LookupEnv("HOME"); ok && home != "" {
			homeConfig = filepath.Join(home, ".sqlrc")
		}
		// Walk from the working directory up to the filesystem root.
		wd, err := os.Getwd()
		for prev := ""; err == nil && wd != prev; prev, wd = wd, filepath.Dir(wd) {
			p := filepath.Join(wd, ".sqlrc")
			if p == homeConfig || p == xdgConfig {
				continue // Added below, in their own position.
			}
			// Skip anything that cannot be stat'd (not only ErrNotExist):
			// st is nil whenever statErr is non-nil.
			if st, statErr := os.Stat(p); statErr != nil || st.IsDir() {
				continue
			}
			configFilePaths = append(configFilePaths, p)
		}
		if xdgConfig != "" {
			configFilePaths = append(configFilePaths, xdgConfig)
		}
		if homeConfig != "" {
			configFilePaths = append(configFilePaths, homeConfig)
		}
	}

	// loadConfigFile skips flags that are already set, so flags given on the
	// command line and files earlier in this list take precedence. Missing
	// files are ignored.
	// NOTE(review): whether a short alias (e.g. -s) marks its long flag as set
	// depends on cli.Forward — confirm.
	for _, p := range configFilePaths {
		if err := loadConfigFile(p, &dsnArg, fs); errors.Is(err, os.ErrNotExist) {
			continue
		} else if err != nil {
			log.Fatalf("Fatal error loading config file: %v", err)
		}
	}

	opts.TryJSON = opts.TryJSON && !opts.SkipJSON

	// Enforce the separator rules documented in the usage text.
	switch {
	case sep == "":
		log.Fatalf("Separator (-s, --sep) cannot be empty")
	case sep[0] == '@':
		log.Fatalf("Separator (-s, --sep) may not begin with @")
	case argSep == "":
		log.Fatalf("Argument separator (-d, --arg-sep) cannot be empty")
	case argSep[0] == '@':
		log.Fatalf("Argument separator (-d, --arg-sep) may not begin with @")
	case argSep == sep:
		log.Fatalf("Argument separator (-d, --arg-sep) must differ from the separator (-s, --sep)")
	}
	// A backslash-prefixed separator is passed through as a literal argument.
	escSep := `\` + sep
	escArgSep := `\` + argSep

	argv = fs.Args()
	if len(argv) == 0 {
		log.Fatalf("No DSN argument provided.")
	}
	// A DSN argument of "-" selects the DSN from a config file, if one set it.
	if dsnArg == "" || argv[0] != "-" {
		dsnArg = argv[0]
	}
	argv = argv[1:]

	var u *url.URL
	err := cli.File(cli.NewFlagFunc("", false, func(v string) error {
		dsnURL, err := url.Parse(v)
		if err != nil {
			// err is not wrapped: *url.Error messages quote the input URL,
			// which may contain a password.
			return errors.New("error parsing DSN URL")
		}
		u = dsnURL
		return nil
	})).Set(dsnArg)
	if err != nil {
		panic(err)
	}

	// Explicit -u/-P values override credentials embedded in the DSN.
	if u.User != nil {
		if !usernameOpt.IsSet() {
			username = u.User.Username()
		}
		if !passwordOpt.IsSet() {
			password, _ = u.User.Password()
		}
	}
	if usernameOpt.IsSet() || passwordOpt.IsSet() {
		u.User = url.UserPassword(username, password)
	}

	driverName, dsn, bindType, err := driver.DSNFromURL(u)
	if err != nil {
		panic(err)
	}
	opts.BindType = bindType

	db, err := sql.Open(driverName, dsn)
	if err != nil {
		panic(err)
	}
	defer db.Close()

	bufout := bufio.NewWriter(os.Stdout)
	flush := func() {
		if err := bufout.Flush(); err != nil {
			log.Printf("error flushing output: %v", err)
		}
	}
	defer flush()

	enc := json.NewEncoder(bufout)
	if !opts.Compact {
		enc.SetIndent("", "  ")
	}

	ctx := context.Background()

	// Split the remaining arguments into query groups on sep. Empty groups
	// (leading or doubled separators) are skipped, just like a trailing
	// separator, instead of producing a query with no statement.
	var argSets [][]string
	start := 0
	for i, arg := range argv {
		switch arg {
		case sep:
			if i > start {
				argSets = append(argSets, argv[start:i])
			}
			start = i + 1
		case escSep:
			argv[i] = sep
		}
	}
	if start < len(argv) {
		argSets = append(argSets, argv[start:])
	}

	var queries []*vdb.Execution
	for i, args := range argSets {
		n := i + 1 // 1-based, matching the execution error messages below.

		// args[0] is the query; the rest are argument sets split on argSep.
		var sets [][]string
		begin := 1
		for j := 1; j < len(args); j++ {
			switch args[j] {
			case escArgSep:
				args[j] = argSep
			case argSep:
				sets = append(sets, args[begin:j])
				begin = j + 1
			}
		}
		if begin < len(args) {
			sets = append(sets, args[begin:])
		}

		queryStmt, err := cli.FromFile(args[0])
		if err != nil {
			panic(fmt.Errorf("error parsing query %d statement(s): %w", n, err))
		}
		qargs, err := parseQueryArgs(sets)
		if err != nil {
			panic(fmt.Errorf("error parsing query %d arguments: %w", n, err))
		}
		queries = append(queries, &vdb.Execution{
			Query:   queryStmt,
			Args:    qargs,
			Options: &opts,
		})
	}

	var conn vdb.DB = db
	// done is deferred with a pointer to the execution error. Without a
	// transaction it does nothing.
	done := func(*error) {}
	if inTransaction {
		tx, err := db.BeginTx(ctx, &sql.TxOptions{
			Isolation: isolationLevel,
		})
		if err != nil {
			panic(fmt.Errorf("failed to initiate transaction: %w", err))
		}
		conn = tx
		done = func(err *error) {
			// recover works here because done is the deferred function itself.
			rc := recover()
			if rc != nil {
				// Re-raise the panic after the rollback below.
				defer panic(rc)
			}
			if *err != nil || rc != nil {
				if rbErr := tx.Rollback(); rbErr != nil {
					log.Printf("failed to rollback transaction: %v", rbErr)
				}
				return
			}
			// A failed commit means nothing was persisted: report it as the
			// run's error instead of only logging it.
			if cErr := tx.Commit(); cErr != nil {
				*err = fmt.Errorf("failed to commit transaction: %w", cErr)
			}
		}
	}

	err = func() (err error) {
		defer done(&err)
		for i, query := range queries {
			n := i + 1
			results, execErr := query.Exec(ctx, conn)
			if execErr != nil {
				return fmt.Errorf("error executing query %d: %w", n, execErr)
			}
			for ri, recs := range results {
				if printArray {
					if encErr := enc.Encode(recs); encErr != nil {
						return fmt.Errorf("error encoding result set %d:%d: %w", n, ri, encErr)
					}
					continue
				}
				for rci, rec := range recs {
					if encErr := enc.Encode(rec); encErr != nil {
						return fmt.Errorf("error encoding result %d:%d:%d: %w", n, ri, rci, encErr)
					}
				}
			}
			flush()
		}
		return nil
	}()
	if err != nil {
		panic(fmt.Errorf("fatal error: %w", err))
	}
}
// parseQueryArgs turns argument sets (already split on -d, --arg-sep) into
// the per-execution argument lists stored on a vdb.Execution.
//
// Every token is converted with toConcrete, except the group markers "-{" and
// "}-". Tokens between a matching pair are collected into a single
// []interface{} argument. A group opened inside another group does not nest:
// its values are merged into the enclosing group. Each top-level argument is
// finally passed through flatten.
//
// An unmatched "}-" or a missing "}-" is an error; set and argument indices in
// error messages are 0-based. When sets is empty, the result still contains
// one (nil) argument list — presumably so the query runs once without
// arguments (TODO confirm against vdb.Execution).
func parseQueryArgs(sets [][]string) ([][]interface{}, error) {
	const (
		groupOpen  = "-{"
		groupClose = "}-"
	)
	// frame is one level of "-{ ... }-" nesting. Only the root frame of a set
	// has a nil parent.
	type frame struct {
		parent *frame
		values []interface{}
	}

	count := len(sets)
	if count == 0 {
		count = 1
	}
	result := make([][]interface{}, count)

	for setIdx, set := range sets {
		root := &frame{values: make([]interface{}, 0, len(set))}
		cur := root
		for argIdx, token := range set {
			if token == groupOpen {
				// A group opened at the top level starts fresh. A group
				// opened inside another continues the enclosing group's values.
				values := cur.values
				if cur == root {
					values = make([]interface{}, 0, 8)
				}
				cur = &frame{parent: cur, values: values}
				continue
			}
			if token == groupClose {
				if cur == root {
					return nil, fmt.Errorf("unexpected %q in arg set %d, arg %d: missing leading \\ to escape?", token, setIdx, argIdx)
				}
				closed := cur.values
				cur = cur.parent
				if cur == root {
					// Closing an outermost group yields one slice argument.
					cur.values = append(cur.values, closed)
				} else {
					// Closing an inner group hands the merged values back.
					cur.values = closed
				}
				continue
			}
			value, err := toConcrete(token)
			if err != nil {
				return nil, fmt.Errorf("error parsing %q in arg set %d, arg %d: %w", token, setIdx, argIdx, err)
			}
			cur.values = append(cur.values, value)
		}
		if cur != root {
			return nil, fmt.Errorf("missing %q in arg set %d: %q", groupClose, setIdx, set)
		}
		for k := range root.values {
			root.values[k] = flatten(root.values[k])
		}
		result[setIdx] = root.values
	}
	return result, nil
}
// flatten normalizes a query argument value.
//
// A []interface{} is flattened recursively: any element that is itself a
// []interface{} (after flattening) is spliced into the result, so the output
// never contains nested []interface{}. A map[string]interface{} is replaced by
// its JSON encoding as a string. Every other value is returned unchanged.
func flatten(v interface{}) interface{} {
	if obj, isObj := v.(map[string]interface{}); isObj {
		// The Marshal error is ignored, as before; this is expected to hold
		// JSON-decoded data, which re-encodes cleanly.
		encoded, _ := json.Marshal(obj)
		return string(encoded)
	}
	list, isList := v.([]interface{})
	if !isList {
		return v
	}
	out := make([]interface{}, 0, len(list))
	for _, item := range list {
		switch inner := flatten(item).(type) {
		case []interface{}:
			out = append(out, inner...)
		default:
			out = append(out, inner)
		}
	}
	return out
}
// toConcrete converts one command-line query argument into the value passed
// to the database driver.
//
// An argument starting with a backslash, or one without a "TYPE::" prefix (a
// "::" at position 0 does not count), is returned as cli.FromFile(v).
// Otherwise the text after the first "::" is resolved with cli.FromFile and
// converted according to TYPE, matched case-insensitively:
//
//   - str, s: string
//   - bytes, bs: []byte
//   - int, i, l: int64 via strconv.ParseInt with base 0 (prefixes like 0x)
//   - uint, u, ul: uint64 via strconv.ParseUint with base 0
//   - float, double, single, real, d, f: float64 via strconv.ParseFloat
//   - bool, boolean, b: bool via strconv.ParseBool
//   - array, a: a JSON array as []interface{}; nested objects and arrays are
//     re-encoded as JSON strings
//   - json, j: a stream of JSON values; a lone value is returned by itself,
//     otherwise a []interface{}; objects and arrays become JSON strings
//   - sh: []string split with shell-style quoting
//   - fields, fs: shell-split fields, each converted with toConcrete
//   - openfile, of: the named file's contents as a string
//   - rawfile, rf: the named file's contents as []byte
//   - fd: everything read from the given file descriptor, as []byte; the
//     descriptor is closed afterwards
//
// Any other TYPE is not a conversion: v is returned unchanged, prefix
// included.
func toConcrete(v string) (interface{}, error) {
	const (
		sep = `::`
		esc = `\`
	)
	if strings.HasPrefix(v, esc) {
		return cli.FromFile(v)
	}
	tsep := strings.Index(v, sep)
	if tsep <= 0 {
		return cli.FromFile(v)
	}
	typ, raw := strings.ToLower(v[:tsep]), v[tsep+len(sep):]

	// Choose the conversion before resolving raw with cli.FromFile. An
	// unrecognized TYPE then returns v untouched, rather than first running
	// FromFile (which can fail, and presumably reads @file references — see
	// internal/cli) on text whose result would be thrown away.
	var convert func(str string) (interface{}, error)
	switch typ {
	case "str", "s":
		convert = func(str string) (interface{}, error) { return str, nil }
	case "bytes", "bs":
		convert = func(str string) (interface{}, error) { return []byte(str), nil }
	case "int", "i", "l":
		convert = func(str string) (interface{}, error) { return strconv.ParseInt(str, 0, 64) }
	case "uint", "u", "ul":
		convert = func(str string) (interface{}, error) { return strconv.ParseUint(str, 0, 64) }
	case "float", "double", "single", "real", "d", "f":
		convert = func(str string) (interface{}, error) { return strconv.ParseFloat(str, 64) }
	case "bool", "boolean", "b":
		convert = func(str string) (interface{}, error) { return strconv.ParseBool(str) }
	case "array", "a": // JSON array.
		convert = func(str string) (interface{}, error) {
			ary := []interface{}{}
			if err := json.Unmarshal([]byte(str), &ary); err != nil {
				return nil, err
			}
			for i, p := range ary {
				switch p.(type) {
				case map[string]interface{}, []interface{}:
					// Freshly decoded JSON always re-encodes.
					b, _ := json.Marshal(p)
					ary[i] = string(b)
				}
			}
			return ary, nil
		}
	case "json", "j": // JSON stream -- not an object, necessarily.
		convert = func(str string) (interface{}, error) {
			ary := make([]interface{}, 0, 8)
			dec := json.NewDecoder(strings.NewReader(str))
			for i := 1; ; i++ {
				var p interface{}
				if err := dec.Decode(&p); errors.Is(err, io.EOF) {
					break
				} else if err != nil {
					return nil, fmt.Errorf("error parsing %d-th JSON item from argument: %w", i, err)
				}
				switch p.(type) {
				case map[string]interface{}, []interface{}:
					// Freshly decoded JSON always re-encodes.
					b, _ := json.Marshal(p)
					p = string(b)
				}
				ary = append(ary, p)
			}
			if len(ary) == 1 {
				return ary[0], nil
			}
			return ary, nil
		}
	case "sh":
		convert = func(str string) (interface{}, error) {
			strs, err := shlex.Split(str)
			if err != nil {
				return nil, err
			}
			return strs, nil
		}
	case "fields", "fs":
		convert = func(str string) (interface{}, error) {
			strs, err := shlex.Split(str)
			if err != nil {
				return nil, err
			}
			ary := make([]interface{}, len(strs))
			for i, fstr := range strs {
				if ary[i], err = toConcrete(fstr); err != nil {
					return nil, fmt.Errorf("error expanding field %d in arguments: %w", i+1, err)
				}
			}
			return ary, nil
		}
	case "openfile", "of":
		convert = func(str string) (interface{}, error) {
			data, err := ioutil.ReadFile(str)
			if err != nil {
				return nil, fmt.Errorf("unable to read file %q: %w", str, err)
			}
			return string(data), nil
		}
	case "rawfile", "rf":
		convert = func(str string) (interface{}, error) {
			data, err := ioutil.ReadFile(str)
			if err != nil {
				return nil, fmt.Errorf("unable to read file %q: %w", str, err)
			}
			return data, nil
		}
	case "fd":
		convert = func(str string) (interface{}, error) {
			fd, err := strconv.ParseUint(str, 10, 64)
			if err != nil {
				return nil, fmt.Errorf("unable to parse file descriptor %q: %w", str, err)
			}
			f := os.NewFile(uintptr(fd), v)
			defer f.Close()
			return ioutil.ReadAll(f)
		}
	default:
		return v, nil // Do not parse -- there is no associated type here.
	}

	str, err := cli.FromFile(raw)
	if err != nil {
		return nil, err
	}
	return convert(str)
}