Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Segmented prime sieve for fun and profit #87

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/Primes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ using Base: BitSigned
using Base.Checked: checked_neg

export isprime, primes, primesmask, factor, ismersenneprime, isrieselprime,
nextprime, nextprimes, prevprime, prevprimes, prime, prodfactors, radical, totient
nextprime, nextprimes, prevprime, prevprimes, prime, prodfactors, radical, totient,
SegmentedSieve

include("factorization.jl")
include("segmented_sieve/SegmentedSieve.jl")

# Primes generating functions
# https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes
Expand Down
44 changes: 44 additions & 0 deletions src/segmented_sieve/SegmentedSieve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
module SegmentedSieve

const ps = (1, 7, 11, 13, 17, 19, 23, 29)

"""
Population count of a vector of UInt8s for counting prime numbers.
See https://github.com/JuliaLang/julia/issues/34059
"""
function vec_count_ones(xs::Union{Vector{UInt8}, Base.FastContiguousSubArray{UInt8}})
n = length(xs)
count = 0
chunks = n ÷ sizeof(UInt)
GC.@preserve xs begin
ptr = Ptr{UInt}(pointer(xs))
for i in 1:chunks
count += count_ones(unsafe_load(ptr, i))
end
end

@inbounds for i in 8chunks+1:n
count += count_ones(xs[i])
end

count
end

function to_idx(x)
x == 1 && return 1
x == 7 && return 2
x == 11 && return 3
x == 13 && return 4
x == 17 && return 5
x == 19 && return 6
x == 23 && return 7
return 8
end

include("generate_sieving_loop.jl")
include("sieve_small.jl")
include("presieve.jl")
include("siever.jl")
include("sieve.jl")

end # module
145 changes: 145 additions & 0 deletions src/segmented_sieve/generate_sieving_loop.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
For a prime number p and a multiple q, the wheel index encodes the prime number index of
p and q modulo 30, which can be encoded to a single number from 1 ... 64. This allows us
to jump immediately into the correct loop at the correct offset.
"""
create_jump(wheel_index, i) = :($wheel_index === $i && @goto $(Symbol(:x, i)))

create_label(wheel_index) = :(@label $(Symbol(:x, wheel_index)))

wheel_mask(prime_mod_30)::UInt8 = ~(0x01 << (to_idx(prime_mod_30) - 1))

"""
For any prime number `p` we compute its prime number index modulo 30 (here `wheel`) and we
generate the loop that crosses of the next 8 multiples that, modulo 30, are
p * {1, 7, 11, 13, 17, 19, 23, 29}.
"""
function unrolled_loop(p_idx)
p = ps[p_idx]

# First push the stopping criterion
unrolled_loop_body = Any[:(byte_idx > unrolled_max && break)]

# Cross off the 8 next multiples
for q in ps
div, rem = divrem(p * q, 30)
bit = wheel_mask(rem)
push!(unrolled_loop_body, :(xs[byte_idx + increment * $(q - 1) + $div] &= $bit))
end

# Increment the byte index to where the next / 9th multiple is located
push!(unrolled_loop_body, :(byte_idx += increment * 30 + $p))

quote
while true
$(unrolled_loop_body...)
end
end
end

"""
The fan-in / fan-out phase that crosses off one multiple and then checks bounds; this is
before and after the unrolled loop starts and finishes respectively.
"""
function single_loop_item_not_unrolled(p_idx, q_idx, save_on_exit = true)
# Our prime number modulo 30
p = ps[p_idx]

ps_next = (1, 7, 11, 13, 17, 19, 23, 29, 31)

# Label name
jump_idx = 8 * (p_idx - 1) + q_idx

# Current and next multiplier modulo 30
q_curr, q_next = ps_next[q_idx], ps_next[q_idx + 1]

# Get the bit mask for crossing off p * q_curr
div_curr, rem_curr = divrem(p * q_curr, 30)
bit = wheel_mask(rem_curr)

# Compute the increments for the byte index for the next multiple
incr_bytes = p * q_next ÷ 30 - div_curr
incr_multiple = q_next - q_curr

quote
# Todo: this generates an extra jump, maybe conditional moves are possible?
if byte_idx > n_bytes

# For a segmented sieve we store where we exit the loop, since that is the
# entrypoint in the next loop; to avoid modulo computation to find the offset
$(save_on_exit ? :(last_idx = $jump_idx) : nothing)
@goto out
end

# Cross off the multiple
xs[byte_idx] &= $bit

# Increment the byte index to where the next multiple is located
byte_idx += increment * $incr_multiple + $incr_bytes
end
end

"""
Full loop generates a potentially unrolled loop for a particular wheel
that may or may not save the exit point.
"""
function full_loop_for_wheel(wheel, unroll = true, save_on_exit = true)
loop_statements = []

for i = 1 : 8
push!(loop_statements, create_label(8 * (wheel - 1) + i))
unroll && i == 1 && push!(loop_statements, unrolled_loop(wheel))
push!(loop_statements, single_loop_item_not_unrolled(wheel, i, save_on_exit))
end

quote
while true
$(loop_statements...)
end
end
end

"""
Generates a sieving loop that crosses off multiples of a given prime number.

@sieve_loop :unroll :save_on_exit
@sieve_loop
"""
macro sieve_loop(options...)
unroll, save_on_exit = :(:unroll) in options, :(:save_on_exit) in options

# When crossing off p * q where `p` is the siever prime and `q` the current multiplier
# we have that p and q are {1, 7, 11, 13, 17, 19, 23, 29} mod 30.
# For each of these 8 possibilities for `p` we create a loop, and per loop we
# create 8 entrypoints to jump into. The first entrypoint is the unrolled loop for
# whenever we can remove 8 multiples at the same time when all 8 fit in the interval
# between byte_start:byte_next_start-1. Otherwise we can only remove one multiple at
# a time. With 8 loops and 8 entrypoints per loop we have 64 different labels, numbered
# x1 ... x64.

# As an example, take p = 7 as a prime number and q = 23 as the first multiplier, and
# assume our number line starts at 1 (so byte 1 represents 1:30, byte 2 represent 31:60).
# We have to cross off 7 * 23 = 161 first, which has byte index 6. Our prime number `p`
# is in the 2nd spoke of the wheel and q is in the 7th spoke. This means we have to jump
# to the 7th label in the 2nd loop; that is label 8 * (2 - 1) + 7 = 15. There we cross
# off the multiple (since 161 % 30 = 11 is the 3rd spoke, we "and" the byte with 0b11011111)
# Then we move to 7 * 29 (increment the byte index accordingly), cross it off as well.
# And now we enter the unrolled loop where 7 * {31, 37, ..., 59} are crossed off, then
# 7 * {61, 67, ..., 89} etc. Lastly we reach the end of the sieving interval, we cross
# off the remaining multiples one by one, until the byte index is passed the end.
# When that is the case, we save at which multiple / label we exited, so we can jump
# there without computation when the next interval of the number line is sieved.

esc(quote
$(unroll ? :(unrolled_max = n_bytes - increment * 28 - 28) : nothing)

# Create jumps inside loops
$([create_jump(:wheel_idx, i) for i = 1 : 64]...)

# # Create loops
$([full_loop_for_wheel(wheel, unroll, save_on_exit) for wheel in 1 : 8]...)

# Point of exit
@label out
end)
end
62 changes: 62 additions & 0 deletions src/segmented_sieve/presieve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
The {2, 3, 5}-wheel is efficient because it compresses memory
perfectly (1 byte = 30 numbers) and removes 4/15th of all the
multiples already. We don't get that memory efficiency when
extending the wheel to {2, 3, 5, 7}, since we need to store
48 bits per 210 numbers, which is could be done with one 64-bit
integer per 210 numbers, which in fact compresses worse
(64 bits / 210 numbers) than the {2, 3, 5}-wheel
(8 bits / 30 numbers).

What we can do however, is compute the repeating pattern that
the first `n` primes create, and copy that pattern over. That
is, we look at a the numbers modulo p₁ * p₂ * ⋯ * pₙ.

For instance, when presieving all multiples of {2, 3, ..., 19}
we allocate a buffer for the range 1 : 2 * 3 * ... * 19 =
1:9_699_690. In a {2, 3, 5} wheel this means a buffer of
9_699_690 ÷ 30 = 323_323 bytes.
"""
function create_presieve_buffer()
n_bytes = 7 * 11 * 13 * 17
xs = fill(0xFF, n_bytes)

@inbounds for p in (7, 11, 13, 17)
p² = p * p
byte_idx = p² ÷ 30 + 1
wheel = to_idx(p)
wheel_idx = 8 * (wheel - 1) + wheel
increment = 0
@sieve_loop :unroll
end

@inbounds xs[1] = 0b11100001 # remove 7, 11, 13 and 17
return xs
end

"""
When applying the presieve buffer, we have to compute the offset in
"""
function apply_presieve_buffer!(xs::Vector{UInt8}, buffer::Vector{UInt8}, byte_start, byte_stop)

len = byte_stop - byte_start + 1

# todo, clean this up a bit.
from_idx = (byte_start - 1) % length(buffer) + 1
to = min(len, length(buffer) - from_idx + 1)

# First copy the remainder of buffer at the front
copyto!(view(xs, Base.OneTo(to)), view(buffer, from_idx:from_idx + to - 1))
from = to + 1

# Then copy buffer multiple times
while from + length(buffer) - 1 <= len
copyto!(view(xs, from : from + length(buffer) - 1), buffer)
from += length(buffer)
end

# And finally copy the remainder of buffer again
copyto!(view(xs, from:len), view(buffer, Base.OneTo(length(from:len))))

xs
end
124 changes: 124 additions & 0 deletions src/segmented_sieve/sieve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import Base: iterate

export SegmentedSieve

function generate_siever_primes(small_sieve::SmallSieve, segment_lo)
xs = small_sieve.xs
sievers = Vector{Siever}(undef, vec_count_ones(xs))
j = 0
@inbounds for i = eachindex(xs)
x = xs[i]
while x != 0x00
sievers[j += 1] = Siever(compute_prime(x, i), segment_lo)
x &= x - 0x01
end
end
return sievers
end

struct SegmentIterator{T<:AbstractUnitRange}
range::T
segment_length::Int
first_byte::Int
last_byte::Int
sievers::Vector{Siever}
presieve_buffer::Vector{UInt8}
segment::Vector{UInt8}
end

struct Segment{Tr,Ts}
range::Tr
segment::Ts
end

function Base.show(io::IO, s::Segment)
# compute left padding
padding = floor(Int, log10(last(s.range))) + 1

padding_str = " " ^ padding

print(io, padding_str, " ")
for p in ps
print(io, lpad(p, 2, "0"), " ")
end

println()

for (start, byte) in zip(s.range, s.segment)
mask = 0b00000001
print(lpad(start, padding, "0"), " ")
for i = 1 : 8
print(io, (byte & mask) == mask ? " x " : " . ")
mask <<= 1
end
println()
end

io
end

function SegmentIterator(range::T, segment_length::Integer) where {T<:AbstractUnitRange}
from, to = first(range), last(range)
first_byte, last_byte = cld(first(range), 30), cld(last(range), 30)
sievers = generate_siever_primes(SmallSieve(isqrt(to)), 30 * (first_byte - 1) + 1)
presieve_buffer = create_presieve_buffer()
xs = zeros(UInt8, segment_length)

return SegmentIterator{T}(range, segment_length, first_byte, last_byte, sievers, presieve_buffer, xs)
end

function iterate(iter::SegmentIterator, segment_index_start = iter.first_byte)
@inbounds begin
if segment_index_start ≥ iter.last_byte
return nothing
end

from, to = first(iter.range), last(iter.range)

segment_index_next = min(segment_index_start + iter.segment_length, iter.last_byte + 1)
segment_curr_len = segment_index_next - segment_index_start

# Presieve
apply_presieve_buffer!(iter.segment, iter.presieve_buffer, segment_index_start, segment_index_next - 1)

# Set the preceding so many bits before `from` to 0
if segment_index_start == iter.first_byte
if iter.first_byte === 1
iter.segment[1] = 0b11111110 # just make 1 not a prime.
end
for i = 1 : 8
30 * (segment_index_start - 1) + ps[i] >= from && break
iter.segment[1] &= wheel_mask(ps[i])
end
end

# Set the remaining so many bits after `to` to 0
if segment_index_next == iter.last_byte + 1
for i = 8 : -1 : 1
to ≥ 30 * (segment_index_next - 2) + ps[i] && break
iter.segment[segment_curr_len] &= wheel_mask(ps[i])
end
end

# Sieve the interval, but skip the pre-sieved primes
xs = iter.segment

for p_idx in 5:length(iter.sievers)
p = iter.sievers[p_idx]
last_idx = 0
n_bytes = segment_index_next - segment_index_start
byte_idx = p.byte_index - segment_index_start + 1
wheel_idx = p.wheel_index
increment = p.prime_div_30
@sieve_loop :unroll :save_on_exit
iter.sievers[p_idx] = Siever(increment, segment_index_start + byte_idx - 1, last_idx)
end

segment_start = 30 * (segment_index_start - 1)
segment_stop = 30 * (segment_index_next - 1) - 1

segment_index_start += iter.segment_length

return Segment(segment_start:30:segment_stop, view(xs, Base.OneTo(segment_curr_len))), segment_index_start
end
end
Loading