~ninjin/julia-nix

68e08136550c928de028d036413775e74ca90c3e — Jameson Nash 2 years ago 4a048d3
fix collect on stateful iterators

Generalization of #41919
Fixes #42168
5 files changed, 52 insertions(+), 34 deletions(-)

M base/array.jl
M base/dict.jl
M base/set.jl
M src/julia-syntax.scm
M test/iterators.jl
M base/array.jl => base/array.jl +39 -22
@@ 643,23 643,38 @@ julia> collect(Float64, 1:2:5)
"""
collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr))

# NOTE(review): this listing looks like a rendered diff with the +/- markers stripped, so
# superseded lines appear next to their replacements (e.g. the two `push!` lines below
# differ only in spacing) — confirm against the real base/array.jl before editing code.

# Known length: allocate a `Vector{T}` of `length(itr)` elements and copy `itr` into it.
_collect(::Type{T}, itr, isz::HasLength) where {T} = copyto!(Vector{T}(undef, Int(length(itr)::Integer)), itr)
# Known shape: allocate an `Array{T}` over `axes(itr)` and copy `itr` into it.
_collect(::Type{T}, itr, isz::HasShape) where {T}  = copyto!(similar(Array{T}, axes(itr)), itr)
# Both sized cases expressed through `_array_for` and `_similar_shape`.
_collect(::Type{T}, itr, isz::Union{HasLength,HasShape}) where {T} =
    copyto!(_array_for(T, isz, _similar_shape(itr, isz)), itr)
# Unknown size: start from an empty `Vector{T}` and append each element in turn.
function _collect(::Type{T}, itr, isz::SizeUnknown) where T
    a = Vector{T}()
    for x in itr
        push!(a,x)
        push!(a, x)
    end
    return a
end

# make a collection similar to `c` and appropriate for collecting `itr`
# For `AbstractArray` containers the size trait selects the allocation:
# zero elements (SizeUnknown), `length(itr)` elements (HasLength), or `axes(itr)` (HasShape).
_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown) where {T} = similar(c, T, 0)
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength) where {T} =
    similar(c, T, Int(length(itr)::Integer))
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape) where {T} =
    similar(c, T, axes(itr))
# Fallbacks for any other container: the size arguments are ignored and `similar(c, T)` is used.
_similar_for(c, ::Type{T}, itr, isz) where {T} = similar(c, T)
_similar_for(c, ::Type{T}, itr, isz, shp) where {T} = similar(c, T)

# Shape information for `itr`, chosen by its size trait:
# `nothing` for SizeUnknown, `length(itr)` for HasLength, `axes(itr)` for HasShape.
_similar_shape(itr, ::SizeUnknown) = nothing
_similar_shape(itr, ::HasLength) = length(itr)::Integer
_similar_shape(itr, ::HasShape) = axes(itr)

# Five-argument forms for `AbstractArray` containers: the last argument is an
# already-computed shape (`nothing`, a length, or axes), so `itr` itself is not queried here.
_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown, ::Nothing) where {T} =
    similar(c, T, 0)
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength, len::Integer) where {T} =
    similar(c, T, len)
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape, axs) where {T} =
    similar(c, T, axs)

# make a collection appropriate for collecting `itr::Generator`
# The third argument is an already-computed shape: `nothing` gives an empty `Vector{T}`,
# an integer length gives a `Vector{T}` of that length, and axes give an `Array{T,N}`.
_array_for(::Type{T}, ::SizeUnknown, ::Nothing) where {T} = Vector{T}(undef, 0)
_array_for(::Type{T}, ::HasLength, len::Integer) where {T} = Vector{T}(undef, Int(len))
_array_for(::Type{T}, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)

# used by syntax lowering for simple typed comprehensions
# (computes the shape from `itr` via `_similar_shape`, then defers to the methods above)
_array_for(::Type{T}, itr, isz) where {T} = _array_for(T, isz, _similar_shape(itr, isz))


"""
    collect(collection)


@@ 698,10 713,10 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A)
# Collect `itr` into a container modelled on `cont`, dispatching on both the
# element-type trait and the size trait of `itr`.
collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr))

# Known eltype and known size: allocate the destination with `_similar_for`, then `copyto!`.
# NOTE(review): the two `copyto!` lines below look like the before/after pair of a diff
# whose +/- markers were stripped — confirm which one is live in the real file.
_collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) =
    copyto!(_similar_for(cont, eltype(itr), itr, isz), itr)
    copyto!(_similar_for(cont, eltype(itr), itr, isz, _similar_shape(itr, isz)), itr)

function _collect(cont, itr, ::HasEltype, isz::SizeUnknown)
    a = _similar_for(cont, eltype(itr), itr, isz)
    a = _similar_for(cont, eltype(itr), itr, isz, nothing)
    for x in itr
        push!(a,x)
    end


@@ 759,24 774,19 @@ else
    end
end

# Three-argument forms read the length or axes from `itr`; four-argument forms take them
# as an argument and allocate a `Vector{T}` or an `Array{T,N}` respectively.
# NOTE(review): the hunk header (+39 -22) and the differently-shaped `_array_for` methods
# elsewhere in this listing suggest these are the removed side of the diff — confirm.
_array_for(::Type{T}, itr, isz::HasLength) where {T} = _array_for(T, itr, isz, length(itr))
_array_for(::Type{T}, itr, isz::HasShape{N}) where {T,N} = _array_for(T, itr, isz, axes(itr))
_array_for(::Type{T}, itr, ::HasLength, len) where {T} = Vector{T}(undef, len)
_array_for(::Type{T}, itr, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)

function collect(itr::Generator)
    isz = IteratorSize(itr.iter)
    et = @default_eltype(itr)
    if isa(isz, SizeUnknown)
        return grow_to!(Vector{et}(), itr)
    else
        shape = isz isa HasLength ? length(itr) : axes(itr)
        shp = _similar_shape(itr, isz)
        y = iterate(itr)
        if y === nothing
            return _array_for(et, itr.iter, isz)
            return _array_for(et, isz, shp)
        end
        v1, st = y
        dest = _array_for(typeof(v1), itr.iter, isz, shape)
        dest = _array_for(typeof(v1), isz, shp)
        # The typeassert gives inference a helping hand on the element type and dimensionality
        # (work-around for #28382)
        et′ = et <: Type ? Type : et


@@ 786,15 796,22 @@ function collect(itr::Generator)
end

# Unknown eltype and unknown size: start from an empty container typed by
# `@default_eltype(itr)` and let `grow_to!` fill it.
# NOTE(review): the two `grow_to!` lines look like a before/after diff pair — confirm.
_collect(c, itr, ::EltypeUnknown, isz::SizeUnknown) =
    grow_to!(_similar_for(c, @default_eltype(itr), itr, isz), itr)
    grow_to!(_similar_for(c, @default_eltype(itr), itr, isz, nothing), itr)

# Unknown eltype but known size: allocate a destination typed by the first element and
# hand the rest of the iteration to `collect_to_with_first!`.
# NOTE(review): this listing interleaves what look like removed and added diff lines
# (the two `return _similar_for` lines, and the bare `collect_to_with_first!` call next to
# the `dest = ...` form) — confirm against the real file.
function _collect(c, itr, ::EltypeUnknown, isz::Union{HasLength,HasShape})
    et = @default_eltype(itr)
    # The shape is queried before the first `iterate` call; presumably so that a stateful
    # iterator is measured before it is advanced (commit title: "fix collect on stateful
    # iterators") — TODO confirm.
    shp = _similar_shape(itr, isz)
    y = iterate(itr)
    if y === nothing
        # Empty iterator: return an empty container using the default element type.
        return _similar_for(c, @default_eltype(itr), itr, isz)
        return _similar_for(c, et, itr, isz, shp)
    end
    v1, st = y
    collect_to_with_first!(_similar_for(c, typeof(v1), itr, isz), v1, itr, st)
    dest = _similar_for(c, typeof(v1), itr, isz, shp)
    # The typeassert gives inference a helping hand on the element type and dimensionality
    # (work-around for #28382)
    et′ = et <: Type ? Type : et
    RT = dest isa AbstractArray ? AbstractArray{<:et′, ndims(dest)} : Any
    collect_to_with_first!(dest, v1, itr, st)::RT
end

function collect_to_with_first!(dest::AbstractArray, v1, itr, st)

M base/dict.jl => base/dict.jl +2 -2
@@ 826,6 826,6 @@ length(t::ImmutableDict) = count(Returns(true), t)
isempty(t::ImmutableDict) = !isdefined(t, :parent)
empty(::ImmutableDict, ::Type{K}, ::Type{V}) where {K, V} = ImmutableDict{K,V}()

# Destination containers for collecting into a dictionary: a `Pair{K,V}` element type
# yields `empty(c, K, V)`; any other element type throws an `ArgumentError`.
# NOTE(review): four- and five-argument signatures are interleaved here, which looks like
# a before/after diff pair with the +/- markers stripped — confirm.
_similar_for(c::Dict, ::Type{Pair{K,V}}, itr, isz) where {K, V} = empty(c, K, V)
_similar_for(c::AbstractDict, ::Type{T}, itr, isz) where {T} =
_similar_for(c::AbstractDict, ::Type{Pair{K,V}}, itr, isz, len) where {K, V} = empty(c, K, V)
_similar_for(c::AbstractDict, ::Type{T}, itr, isz, len) where {T} =
    throw(ArgumentError("for AbstractDicts, similar requires an element type of Pair;\n  if calling map, consider a comprehension instead"))

M base/set.jl => base/set.jl +1 -1
@@ 44,7 44,7 @@ empty(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}()
# by default, a Set is returned
emptymutable(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}()

# Destination container for collecting into a set: always `empty(c, T)`; the size
# arguments (`isz`, `len`) are ignored.
_similar_for(c::AbstractSet, ::Type{T}, itr, isz) where {T} = empty(c, T)
_similar_for(c::AbstractSet, ::Type{T}, itr, isz, len) where {T} = empty(c, T)

function show(io::IO, s::Set)
    if isempty(s)

M src/julia-syntax.scm => src/julia-syntax.scm +3 -5
@@ 2734,7 2734,7 @@
  (check-no-return expr)
  (if (has-break-or-continue? expr)
      (error "break or continue outside loop"))
  (let ((result    (gensy))
  (let ((result    (make-ssavalue))
        (idx       (gensy))
        (oneresult (make-ssavalue))
        (prod      (make-ssavalue))


@@ 2758,16 2758,14 @@
    (let ((overall-itr (if (length= itrs 1) (car iv) prod)))
      `(scope-block
        (block
         (local ,result) (local ,idx)
         (local ,idx)
         ,.(map (lambda (v r) `(= ,v ,(caddr r))) iv itrs)
         ,.(if (length= itrs 1)
               '()
               `((= ,prod (call (top product) ,@iv))))
         (= ,isz (call (top IteratorSize) ,overall-itr))
         (= ,szunk (call (core isa) ,isz (top SizeUnknown)))
         (if ,szunk
             (= ,result (call (curly (core Array) ,ty 1) (core undef) 0))
             (= ,result (call (top _array_for) ,ty ,overall-itr ,isz)))
         (= ,result (call (top _array_for) ,ty ,overall-itr ,isz))
         (= ,idx (call (top first) (call (top LinearIndices) ,result)))
         ,(construct-loops (reverse itrs) (reverse iv))
         ,result)))))

M test/iterators.jl => test/iterators.jl +7 -4
@@ 293,11 293,14 @@ let (a, b) = (1:3, [4 6;
end

# collect stateful iterator
# NOTE(review): two `let` openers and repeated assertions appear below with a single `end`;
# this looks like removed and added diff lines shown together — confirm.
let
    itr = (i+1 for i in Base.Stateful([1,2,3]))
let itr
    itr = Iterators.Stateful(Iterators.map(identity, 1:5))
    @test collect(itr) == 1:5
    @test collect(itr) == Int[] # Stateful does not preserve shape
    itr = (i+1 for i in Base.Stateful([1, 2, 3]))
    @test collect(itr) == [2, 3, 4]
    A = zeros(Int, 0, 0)
    itr = (i-1 for i in Base.Stateful(A))
    @test collect(itr) == Int[] # Stateful does not preserve shape
    itr = (i-1 for i in Base.Stateful(zeros(Int, 0, 0)))
    @test collect(itr) == Int[] # Stateful does not preserve shape
end