a39c4d00e3afa4b8e3c0e0fc751990886b2caead — Phil Hagelberg 1 year, 1 month ago c69a725 master
Prevent early evaluation in method calls and comparators.

Previously in certain situations we would use the internal compiler
helper function `once` in order to ensure a value only got evaluated
once, to prevent double-evaluation in method calls and multi-arity
comparators.

Unfortunately this had a bug: in an effort to prevent
double-evaluation, it would evaluate the value exactly once, rather than at
most once, because it would emit the expression evaluation and bind it
to a local in the parent context and then use the local in multiple
places.

This would cause problems because there are cases where it should not
be evaluated at all:

    (and false (: (foobar) :method?))

    (and false (< (a) (b) (c)))

Both these cases are special because they cannot compile into the
"lua-native" equivalent construct; in the first case because the
method name has a question mark in it, and in the second case because
the < operator has three args instead of two.

We fix this by compiling to an IIFE and using its arguments to refer
to values twice without danger of double-evaluation:

    (function(tgt, m, ...) return tgt[m](tgt, ...) end)(%s, %s)

Of course, this introduces some overhead, but it is the only safe way
to compile the code. In both cases we check to see if the lua-native
compilation approach is possible, and if it is we use it, so the
overhead will only be incurred when absolutely necessary.

Fixes #323
M changelog.md => changelog.md +3 -0
@@ 2,6 2,9 @@

## 0.5.1 / ???

* Fix a bug where method calls would early-evaluate their receiver.
* Fix a bug where multi-arity comparisons would early-evaluate their arguments.

## 0.5.0 / 2020-08-08

* Fix a bug where lambdas with no body would return true instead of nil.

M src/fennel/compiler.fnl => src/fennel/compiler.fnl +3 -3
@@ 65,7 65,7 @@ The ast arg should be unmodified so that its first element is the form called."
(fn global-mangling [str]
  "Mangler for global symbols. Does not protect against collisions,
but makes them unlikely. This is the mangling that is exposed to to the world."
  (if (utils.is-valid-lua-identifier str)
  (if (utils.valid-lua-identifier? str)
      (.. "__fnl_global__"
          (: str :gsub "[^%w]" #(: "_%02x" :format (: $ "byte"))))))

@@ 127,7 127,7 @@ these new manglings instead of the current manglings."
  "Combine parts of a symbol."
  (var ret (or (. scope.manglings (. parts 1)) (global-mangling (. parts 1))))
  (for [i 2 (# parts) 1]
    (if (utils.is-valid-lua-identifier (. parts i))
    (if (utils.valid-lua-identifier? (. parts i))
        (if (and parts.multi-sym-method-call (= i (# parts)))
            (set ret (.. ret ":" (. parts i)))
            (set ret (.. ret "." (. parts i))))

@@ 518,7 518,7 @@ which we have to do if we don't know."
            (when (or (not= (type k) "number")
                      (not= (math.floor k) k)
                      (< k 1) (> k (# ast)))
              (if (and (= (type k) "string") (utils.is-valid-lua-identifier k))
              (if (and (= (type k) "string") (utils.valid-lua-identifier? k))
                  [k k]
                  (let [[compiled] (compile1 k scope parent {:nval 1})
                        kstr (.. "[" (tostring compiled) "]")]

M src/fennel/macros.fnl => src/fennel/macros.fnl +6 -6
@@ 136,15 136,15 @@ that argument name begins with ?."
      (if (table? a)
          (each [_ a (pairs a)]
            (check! a))
          (and (not (: (tostring a) :match "^?"))
          (and (not (string.match (tostring a) "^?"))
               (not= (tostring a) "&")
               (not= (tostring a) "..."))
          (table.insert args arity-check-position
                        `(assert (not= nil ,a)
                                 (: "Missing argument %s on %s:%s"
                                    :format ,(tostring a)
                                    ,(or a.filename "unknown")
                                    ,(or a.line "?"))))))
                                 (string.format "Missing argument %s on %s:%s"
                                                ,(tostring a)
                                                ,(or a.filename "unknown")
                                                ,(or a.line "?"))))))
    (assert (= :table (type arglist)) "expected arg list")
    (each [_ a (ipairs arglist)]
      (check! a))

@@ 222,7 222,7 @@ introduce for the duration of the body if it does match."
        (sym? pattern)
        (let [wildcard? (= (tostring pattern) "_")]
          (if (not wildcard?) (tset unifications (tostring pattern) val))
          (values (if (or wildcard? (: (tostring pattern) :find "^?"))
          (values (if (or wildcard? (string.find (tostring pattern) "^?"))
                      true `(not= ,(sym :nil) ,val))
                  [pattern val]))
        ;; guard clause

M src/fennel/specials.fnl => src/fennel/specials.fnl +77 -55
@@ 271,7 271,7 @@ and lacking args will be nil, use lambda for arity-checked functions."))
          (for [i 3 len 1]
            (var index (. ast i))
            (if (and (= (type index) "string")
                     (utils.is-valid-lua-identifier index))
                     (utils.valid-lua-identifier? index))
                (table.insert indices (.. "." index))
                  (set index (. (compiler.compile1 index scope parent {:nval 1})

@@ 570,43 570,51 @@ order, but can be used with any iterator.")
 "Numeric loop construct.
Evaluates body once for each value between start and stop (inclusive).")

(fn once [val ast scope parent]
  (if (or (= val.type "statement") (= val.type "expression"))
      (let [s (compiler.gensym scope)]
        (compiler.emit parent (: "local %s = %s" :format s (tostring val)) ast)
        (utils.expr s "sym"))
(fn native-method-call [ast scope parent target args]
  "Prefer native Lua method calls when method name is a valid Lua identifier.
Emits `target:method(args)` using the method name found at position 3 of ast.
The target is parenthesized when it is a literal or a compound expression so
that the method call applies to the whole target. args is a list of
already-compiled argument strings. Returns a statement-typed expression."
  (let [[_ _ method-string] ast
        call-string (if (or (= target.type :literal) (= target.type :expression))
                        "(%s):%s(%s)"
                        "%s:%s(%s)")]
    (utils.expr (string.format call-string (tostring target) method-string
                               (table.concat args ", ")) "statement")))

(fn nonnative-method-call [ast scope parent target args]
  "When we don't have to protect against double-evaluation, it's not so bad.
Emits `target[method](target, args...)`, where method is the compiled form of
the third element of ast. The target text appears twice in the output, so
this is only meant for targets that are safe to mention twice. args is a list
of already-compiled argument strings; it is modified in place."
  (let [[method] (compiler.compile1 (. ast 3) scope parent {:nval 1})]
    ;; The receiver must be the FIRST argument of the call, ahead of the
    ;; user-supplied arguments, to mirror what `tgt:method(...)` would pass.
    (table.insert args 1 (tostring target))
    ;; Exactly three values for the three %s slots; the joined argument list
    ;; (receiver included) fills the call parentheses.
    (utils.expr (string.format "%s[%s](%s)" (tostring target) (tostring method)
                               (table.concat args ", ")) "statement")))

(fn double-eval-protected-method-call [ast scope parent target args]
  "When double-evaluation is a concern, we have to wrap an IIFE.
The target and the compiled method expression are passed as the first two
arguments of an immediately-invoked function, so each appears exactly once in
the emitted code and is only evaluated if the call itself is reached. args is
a list of already-compiled argument strings; it is modified in place.
Returns a statement-typed expression."
  (let [method-string (tostring (. (compiler.compile1 (. ast 3) scope parent
                                                      {:nval 1}) 1))
        call "(function(tgt, m, ...) return tgt[m](tgt, ...) end)(%s, %s)"]
    ;; The IIFE signature is (tgt, m, ...), so the method must come right
    ;; after the target and BEFORE the user-supplied arguments.
    (table.insert args 1 method-string)
    (utils.expr (string.format call (tostring target) (table.concat args ", "))
                "statement")))

(fn method-call [ast scope parent]
  (compiler.assert (>= (# ast) 3) "expected at least 2 arguments" ast)
  ;; Compile object
  (var objectexpr (. (compiler.compile1 (. ast 2) scope parent {:nval 1}) 1))
  (var (methodident methodstring) false)
  (if (and (= (type (. ast 3)) "string") (utils.is-valid-lua-identifier (. ast 3)))
        (set methodident true)
        (set methodstring (. ast 3)))
        (set methodstring (tostring (. (compiler.compile1 (. ast 3) scope parent
                                                          {:nval 1}) 1)))
        (set objectexpr (once objectexpr (. ast 2) scope parent))))
  (let [args []] ; compile arguments
    (for [i 4 (# ast) 1]
  (compiler.assert (< 2 (# ast)) "expected at least 2 arguments" ast)
  (let [[target] (compiler.compile1 (. ast 2) scope parent {:nval 1})
        args []]
    (for [i 4 (# ast)]
      (let [subexprs (compiler.compile1 (. ast i) scope parent
                                        {:nval (if (not= i (# ast)) 1 nil)})]
                                        {:nval (if (not= i (# ast)) 1)})]
        (utils.map subexprs tostring args)))
    (var fstring nil)
    (if (not methodident)
        (do ; make object the first argument
          (table.insert args 1 (tostring objectexpr))
          (set fstring (if (= objectexpr.type "sym")
        (or (= objectexpr.type "literal") (= objectexpr.type "expression"))
        (set fstring "(%s):%s(%s)")
        (set fstring "%s:%s(%s)"))
    (utils.expr (: fstring :format (tostring objectexpr) methodstring
                   (table.concat args ", ")) "statement")))
    (if (and (= (type (. ast 3)) :string) (utils.valid-lua-identifier? (. ast 3)))
        (native-method-call ast scope parent target args)
        (= target.type :sym)
        (nonnative-method-call ast scope parent target args)
        ;; When the target is an expression, we can't use the naive
        ;; nonnative-method-call approach, because it will cause the target
        ;; to be evaluated twice. This is fine if it's a symbol but if it's
        ;; the result of a function call, that function could have side-effects.
        ;; See test-short-circuit in test/misc.fnl for an example of the problem.
        (double-eval-protected-method-call ast scope parent target args))))

(tset SPECIALS ":" method-call)

@@ 723,30 731,44 @@ Method name doesn't have to be known at compile-time; if it is, use
(doc-special ".." ["a" "b" "..."]
            "String concatenation operator; works the same as Lua but accepts more arguments.")

(fn define-comparator-special [name realop chain-op]
  (let [op (or realop name)]
(fn native-comparator [op [_ lhs-ast rhs-ast] scope parent]
  "Emit a plain two-operand Lua comparison: `(lhs op rhs)`.
Both operands are compiled for a single value each, left one first."
  (let [operand (fn [node]
                  (. (compiler.compile1 node scope parent {:nval 1}) 1))
        left (operand lhs-ast)
        right (operand rhs-ast)]
    (string.format "(%s %s %s)" (tostring left) op (tostring right))))

(fn double-eval-protected-comparator [op chain-op ast scope parent]
  "Compile a comparison with more than two operands into chained binary Lua
comparisons wrapped in an immediately-invoked function.
Each operand is bound to a fresh parameter name, adjacent parameters are
compared with op, and the pairwise results are joined with chain-op
(defaulting to \"and\")."
  (let [params []
        operands []
        pairwise []
        joiner (.. " " (or chain-op "and") " ")]
    ;; One generated parameter name and one compiled operand per argument.
    (for [i 2 (# ast)]
      (let [param (tostring (compiler.gensym scope))
            [operand] (compiler.compile1 (. ast i) scope parent {:nval 1})]
        (table.insert params param)
        (table.insert operands (tostring operand))))
    ;; Compare every parameter with its predecessor.
    (for [i 2 (# params)]
      (table.insert pairwise (string.format "(%s %s %s)"
                                            (. params (- i 1)) op
                                            (. params i))))
    ;; Wrapping in a function costs a call, but operands referenced by two
    ;; comparisons are then only written (and evaluated) once, and nothing is
    ;; hoisted into the parent chunk where it would run even if an enclosing
    ;; form short-circuits. See test-short-circuit in test/misc.fnl.
    (string.format "(function(%s) return %s end)(%s)"
                   (table.concat params ",")
                   (table.concat pairwise joiner)
                   (table.concat operands ","))))

(fn define-comparator-special [name lua-op chain-op]
  (let [op (or lua-op name)]
    (fn opfn [ast scope parent]
      (local len (# ast))
      (compiler.assert (> len 2) "expected at least two arguments" ast)
      (local lhs (. (compiler.compile1 (. ast 2) scope parent {:nval 1}) 1))
      (var lastval (. (compiler.compile1 (. ast 3) scope parent {:nval 1}) 1))
      (when (> len 3) ; avoid double-eval by adding locals for side-effects
        (set lastval (once lastval (. ast 3) scope parent)))
      (var out (: "(%s %s %s)" :format (tostring lhs) op (tostring lastval)))
      (when (> len 3)
        (for [i 4 len] ; variadic comparison
          (let [nextval (once (. (compiler.compile1 (. ast i)
                                                    scope parent
                                                    {:nval 1}) 1)
                              (. ast i) scope parent)]
            (set out (: (.. out " %s (%s %s %s)") :format (or chain-op "and")
                        (tostring lastval) op (tostring nextval)))
            (set lastval nextval)))
        (set out (.. "(" out ")")))
      (compiler.assert (< 2 (# ast)) "expected at least two arguments" ast)
      (if (= 3 (# ast))
          (native-comparator op ast scope parent)
          (double-eval-protected-comparator op chain-op ast scope parent)))
    (tset SPECIALS name opfn))
  (doc-special name ["a" "b" "..."]
     "Comparison operator; works the same as Lua but accepts more arguments."))
               "Comparison operator; works the same as Lua but accepts more arguments."))

(define-comparator-special ">")
(define-comparator-special "<")

M src/fennel/utils.fnl => src/fennel/utils.fnl +2 -2
@@ 206,7 206,7 @@ When f returns a truthy value, recursively walks the children."
(each [i v (ipairs lua-keywords)]
  (tset lua-keywords v i))

(fn is-valid-lua-identifier [str]
(fn valid-lua-identifier? [str]
  (and (str:match "^[%a_][%w_]*$") (not (. lua-keywords str))))

(local propagated-options [:allowedGlobals :indent :correlate :useMetadata :env])

@@ 233,7 233,7 @@ has options calls down into compile."
 : is-expr : is-list : is-multi-sym : is-sequence : is-sym : is-table : is-varg

 ;; other
 : is-valid-lua-identifier : lua-keywords
 : valid-lua-identifier? : lua-keywords
 : propagate-options : root : debug-on
 :path (table.concat (doto ["./?.fnl" "./?/init.fnl"]
                       (table.insert (getenv "FENNEL_PATH"))) ";")}

M test/misc.fnl => test/misc.fnl +11 -2
@@ 70,11 70,20 @@
    (l.assertNotNil broken-code "code should compile")
    (l.assertError broken-code "code should fail at runtime")))

(fn test-short-circuit []
  ;; Operands of a short-circuited `and` must never be evaluated, even when
  ;; they are method calls with non-Lua-identifier names or comparisons with
  ;; more than two operands.
  (let [;; Evaluates to the value of shorted?, which only becomes true if the
        ;; method call's receiver was (wrongly) evaluated.
        method-code "(var shorted? false)
              (fn set-shorted! [] (set shorted? true) {:f! (fn [])})
              (and false (: (set-shorted!) :f!))
              shorted?"
        ;; Raises an error if the middle operand is evaluated at all.
        comparator-code "(and false (< 1 (error :nein!) 3))"]
    (l.assertFalse (fennel.eval method-code))
    (l.assertFalse (fennel.eval comparator-code))))

{: test-empty-values
 : test-env-iteration
 : test-global-mangling
 : test-include
 : test-leak
 : test-runtime-quote
 : test-traceback}

 : test-traceback
 : test-short-circuit}