~ninjin/julia-nix

72425df7ffb715de5feb1a831d1239bfbca9bf6b — Jameson Nash 2 years ago e126fcd
lowering: copy expand-if optimization to expand-while (#42194)

Co-authored-by: Simeon Schaub <simeondavidschaub99@gmail.com>
2 files changed, 59 insertions(+), 42 deletions(-)

M src/julia-syntax.scm
M test/compiler/inference.jl
M src/julia-syntax.scm => src/julia-syntax.scm +44 -42
@@ 1956,21 1956,28 @@
                (else
                 (error (string "invalid " syntax-str " \"" (deparse el) "\""))))))))

(define (expand-if e)
  (let* ((test (cadr e))
         (blk? (and (pair? test) (eq? (car test) 'block)))
         (stmts (if blk? (cdr (butlast test)) '()))
         (test  (if blk? (last test) test)))
(define (expand-condition cnd)
  (let* ((blk? (and (pair? cnd) (eq? (car cnd) 'block)))
         (stmts (if blk? (cdr (butlast cnd)) '()))
         (test  (if blk? (last cnd) cnd)))
    (if (and (pair? test) (memq (car test) '(&& |\|\||)))
        (let* ((clauses `(,(car test) ,@(map expand-forms (cdr (flatten-ex (car test) test)))))
               (clauses (if (null? (cdr clauses))
                            (if (eq? (car clauses) '&&) '(true) '(false))
                            clauses)))
          `(if ,(if blk?
                    `(block ,@(map expand-forms stmts) ,clauses)
                    clauses)
               ,@(map expand-forms (cddr e))))
        (cons (car e) (map expand-forms (cdr e))))))
          (if blk?
              `(block ,@(map expand-forms stmts) ,clauses)
              clauses))
        (expand-forms cnd))))

(define (expand-if e)
  (list* (car e) (expand-condition (cadr e)) (map expand-forms (cddr e))))

(define (expand-while e)
  `(break-block loop-exit
                (_while ,(expand-condition (cadr e))
                        (break-block loop-cont
                                     (scope-block ,(blockify (expand-forms (caddr e))))))))

(define (expand-vcat e
                     (vcat '((top vcat)))


@@ 2565,13 2572,7 @@

   'if expand-if
   'elseif expand-if

   'while
   (lambda (e)
     `(break-block loop-exit
                   (_while ,(expand-forms (cadr e))
                           (break-block loop-cont
                                        (scope-block ,(blockify (expand-forms (caddr e))))))))
   'while expand-while

   'break
   (lambda (e)


@@ 4205,6 4206,29 @@ f(x) = yt(x)
              (emit `(= ,tmp ,cnd))
              tmp)
            cnd)))
    (define (emit-cond cnd break-labels endl)
      (let* ((cnd (if (and (pair? cnd) (eq? (car cnd) 'block))
                       (begin (if (length> cnd 2) (compile (butlast cnd) break-labels #f #f))
                              (last cnd))
                       cnd))
             (or? (and (pair? cnd) (eq? (car cnd) '|\|\||)))
             (tests (if or?
                        (let ((short-circuit `(goto _)))
                          (for-each
                            (lambda (clause)
                              (let ((jmp (emit `(gotoifnot ,(compile-cond clause break-labels) ,endl))))
                                (emit short-circuit)
                                (set-car! (cddr jmp) (make&mark-label))))
                            (butlast (cdr cnd)))
                          (let ((last-jmp (emit `(gotoifnot ,(compile-cond (last (cdr cnd)) break-labels) ,endl))))
                            (set-car! (cdr short-circuit) (make&mark-label))
                            (list last-jmp)))
                        (map (lambda (clause)
                               (emit `(gotoifnot ,(compile-cond clause break-labels) ,endl)))
                             (if (and (pair? cnd) (eq? (car cnd) '&&))
                                 (cdr cnd)
                                 (list cnd))))))
          tests))
    (define (emit-assignment lhs rhs)
      (if rhs
          (if (valid-ir-rvalue? lhs rhs)


@@ 4345,28 4369,7 @@ f(x) = yt(x)
                 (compile (cadr e) break-labels value tail)
                 #f))
            ((if elseif)
             (let* ((cnd (cadr e))
                    (cnd (if (and (pair? cnd) (eq? (car cnd) 'block))
                              (begin (if (length> cnd 2) (compile (butlast cnd) break-labels #f #f))
                                     (last cnd))
                              cnd))
                    (or? (and (pair? cnd) (eq? (car cnd) '|\|\||)))
                    (tests (if or?
                               (let ((short-circuit `(goto _)))
                                 (for-each
                                   (lambda (clause)
                                     (let ((jmp (emit `(gotoifnot ,(compile-cond clause break-labels) _))))
                                       (emit short-circuit)
                                       (set-car! (cddr jmp) (make&mark-label))))
                                   (butlast (cdr cnd)))
                                 (let ((last-jmp (emit `(gotoifnot ,(compile-cond (last (cdr cnd)) break-labels) _))))
                                   (set-car! (cdr short-circuit) (make&mark-label))
                                   (list last-jmp)))
                               (map (lambda (clause)
                                      (emit `(gotoifnot ,(compile-cond clause break-labels) _)))
                                    (if (and (pair? cnd) (eq? (car cnd) '&&))
                                        (cdr cnd)
                                        (list cnd)))))
             (let* ((tests (emit-cond (cadr e) break-labels '_))
                    (end-jump `(goto _))
                    (val (if (and value (not tail)) (new-mutable-var) #f)))
               (let ((v1 (compile (caddr e) break-labels value tail)))


@@ 4388,9 4391,8 @@ f(x) = yt(x)
                   val))))
            ((_while)
             (let* ((endl (make-label))
                    (topl (make&mark-label))
                    (test (compile-cond (cadr e) break-labels)))
               (emit `(gotoifnot ,test ,endl))
                    (topl (make&mark-label)))
               (emit-cond (cadr e) break-labels endl)
               (compile (caddr e) break-labels #f #f)
               (emit `(goto ,topl))
               (mark-label endl))

M test/compiler/inference.jl => test/compiler/inference.jl +15 -0
@@ 2889,6 2889,21 @@ function symcmp36230(vec)
end
@test Base.return_types(symcmp36230, (Vector{Any},)) == Any[Bool]

function foo42190(r::Union{Nothing,Int}, n::Int)
    while r !== nothing && r < n
        return r # `r::Int`
    end
    return n
end
@test Base.return_types(foo42190, (Union{Nothing, Int}, Int)) == Any[Int]
function bar42190(r::Union{Nothing,Int}, n::Int)
    while r === nothing || r < n
        return n
    end
    return r # `r::Int`
end
@test Base.return_types(bar42190, (Union{Nothing, Int}, Int)) == Any[Int]

# Issue #36531, double varargs in abstract_iteration
f36531(args...) = tuple((args...)...)
@test @inferred(f36531(1,2,3)) == (1,2,3)