~ehmry/nim_sphincs

22ae4e427a46f4026879432c229a8ee668998d94 — Emery Hemingway 3 years ago
Initial commit
A  => .gitignore +6 -0
@@ 1,6 @@
*.html
nimcache
profile_results.txt
sphincs+-reference-implementation*
test
*.req

A  => Makefile +42 -0
@@ 1,42 @@
# Makefile for generating test data

REF_VERSION = 20180313
REF_PKG_DIR := sphincs+-reference-implementation-$(REF_VERSION)
CRYPTO_SIGN_DIR := $(REF_PKG_DIR)/crypto_sign

TEST_RSPS := $(foreach S,128f 128s,tests/shake256-$(S)/PQCsignKAT_64.rsp)
TEST_RSPS += $(foreach S,192f 192s,tests/shake256-$(S)/PQCsignKAT_96.rsp)
TEST_RSPS += $(foreach S,256f 256s,tests/shake256-$(S)/PQCsignKAT_128.rsp)

RNG_SOURCES = tests/rng.c tests/rng.h

default: $(TEST_RSPS) $(RNG_SOURCES)

tests/shake256-%/PQCsignKAT_64.rsp: \
		$(CRYPTO_SIGN_DIR)/sphincs-shake256-%/PQCgenKAT_sign
	mkdir -p $(@D)
	cd $(@D) && ../../$<

tests/shake256-%/PQCsignKAT_96.rsp: \
		$(CRYPTO_SIGN_DIR)/sphincs-shake256-%/PQCgenKAT_sign
	mkdir -p $(@D)
	cd $(@D) && ../../$<

tests/shake256-%/PQCsignKAT_128.rsp: \
		$(CRYPTO_SIGN_DIR)/sphincs-shake256-%/PQCgenKAT_sign
	mkdir -p $(@D)
	cd $(@D) && ../../$<

$(CRYPTO_SIGN_DIR)/sphincs-shake256-%/PQCgenKAT_sign: $(REF_PKG_DIR)
	make -C $(@D)

$(RNG_SOURCES): $(REF_PKG_DIR)
	cp $</crypto_sign/sphincs-haraka-128f/$(notdir $@) $@
	git add $@

$(REF_PKG_DIR): $(REF_PKG_DIR).tar.bz2
	tar xf $<
	touch $@

$(REF_PKG_DIR).tar.bz2:
	wget https://sphincs.org/data/$@

A  => README.md +39 -0
@@ 1,39 @@
SPHINCS⁺ is a stateless hash-based signature scheme that
has been submitted to the NIST post-quantum crypto project.

This library contains Nim implementations for the six SHAKE256
variants of SPHINCS⁺. Performance will be abysmal until the
Keccak implementation is optimized.

Each signature scheme is implemented as a seperate module.
Multiple scheme modules may be imported at once, the correct
procedures will be deduced from the keypair type.

A procedure for supplying random bytes must be provided
for key generation and non-deterministic signing. This
allows the compiler to enforce the `noSideEffect` pragma
for all exported procedures. This helps mitigate (or
implement) side-channels.

```
import sphincs/shake256_128s
import sphincs/shake256_256f

proc genEntropy(p: pointer; size: int) =
  ## Don't try this at home.
  zeroMem(p, size)

let
  pair1 = shake256_128s.generateKeypair(genEntropy)
  pair2 = shake256_256f.generateKeypair(genEntropy)
  sig1 = pair1.sign("foo", genEntropy)
  sig2 = pair2.sign("bar", genEntropy)
  (valid1, msg1) = pair1.pk.verify(sig1)
  (valid2, msg2) = pair2.pk.verify(sig2)
assert(valid1)
assert(msg2 == "bar")
```

Tests are available via the `nimble test` task and a makefile
is provided for recreating the test-vectors from the reference
implementation.

A  => doc/sphincs+-specification.pdf +0 -0
A  => sphincs.nimble +11 -0
@@ 1,11 @@
# Package

version       = "0.1.0"
author        = "Emery Hemingway"
description   = "SPHINCS⁺ stateless hash-based signature scheme"
license       = "MIT"
srcDir        = "src"

# Dependencies

requires "nim >= 0.17.3"

A  => src/sphincs/private/sha3.nim +406 -0
@@ 1,406 @@
{.push checks: off.}

import bitops

type
   SHA3* = object
      hash: array[25, uint64]
      buffer: array[168, uint8]
      buffer_idx: int
      max_idx: int
      rounds: int
      delim: uint8
      hash_size: int
   Kangaroo12* = object
      outer: SHA3
      inner: SHA3
      key: cstring
      key_size: int
      chunks: uint64
      current: uint64

   SHA3_HASH* = enum
      SHA3_224 = 224,
      SHA3_256 = 256,
      SHA3_384 = 384,
      SHA3_512 = 512
   SHA3_SHAKE* = enum
      SHA3_SHAKE128 = 128, 
      SHA3_SHAKE256 = 256

const 
   RC = [
      0x0000000000000001'u64, 0x0000000000008082'u64,
      0x800000000000808a'u64, 0x8000000080008000'u64,
      0x000000000000808b'u64, 0x0000000080000001'u64,
      0x8000000080008081'u64, 0x8000000000008009'u64,
      0x000000000000008a'u64, 0x0000000000000088'u64,
      0x0000000080008009'u64, 0x000000008000000a'u64,
      0x000000008000808b'u64, 0x800000000000008b'u64,
      0x8000000000008089'u64, 0x8000000000008003'u64,
      0x8000000000008002'u64, 0x8000000000000080'u64,
      0x000000000000800a'u64, 0x800000008000000a'u64,
      0x8000000080008081'u64, 0x8000000000008080'u64,
      0x0000000080000001'u64, 0x8000000080008008'u64 ]

iterator pilAndRotc(): (int, int) =
  yield (10, 1)
  yield (7, 3)
  yield (11, 6)
  yield (17, 10)
  yield (18, 15)
  yield (3, 21)
  yield (5, 28)
  yield (16, 36)
  yield (8, 45)
  yield (21, 55)
  yield (24, 2)
  yield (4, 14)
  yield (15, 27)
  yield (23, 41)
  yield (19, 56)
  yield (13, 8)
  yield (12, 25)
  yield (2, 43)
  yield (20, 62)
  yield (14, 18)
  yield (22, 39)
  yield (9, 61)
  yield (6, 20)
  yield (1, 44)

proc right_encode(n: uint64): seq[uint8] =
   var z = n
   result = @[]
   var i: uint8 = 0
   while (z > 0'u64):
      result.insert(uint8(`mod`(z, 256)), 0)
      inc(i)
      z = `div`(z, 256)
   result.add(i)

proc thera(h: var array[25, uint64]) {.inline.} =
   var
      a, b: array[5, uint64]
   for i in 0..<5:
      a[i] = h[i] xor h[i + 5] xor h[i + 10] xor
             h[i + 15] xor h[i + 20]
   b[0] = rotateLeftBits(a[1], 1) xor a[4]
   b[1] = rotateLeftBits(a[2], 1) xor a[0]
   b[2] = rotateLeftBits(a[3], 1) xor a[1]
   b[3] = rotateLeftBits(a[4], 1) xor a[2]
   b[4] = rotateLeftBits(a[0], 1) xor a[3]
   for i in 0..<5:
      h[i]      = h[i]      xor b[i]
      h[i + 5]  = h[i + 5]  xor b[i]
      h[i + 10] = h[i + 10] xor b[i]
      h[i + 15] = h[i + 15] xor b[i]
      h[i + 20] = h[i + 20] xor b[i]

proc rho_pi(h: var array[25, uint64]) {.inline.} =
   var a, b: uint64
   a = h[1]
   for p, r in pilAndRotc():
      b = h[p]
      h[p] = rotateLeftBits(a, r)
      a = b

proc chi(h: var array[25, uint64]) {.inline.} =
   var a, b: uint64
   for i in countup(0, 20, 5):
      a = h[i]
      b = h[i + 1]
      h[i]     = h[i]     xor not(b) and h[i + 2]
      h[i + 1] = h[i + 1] xor not(h[i + 2]) and h[i + 3]
      h[i + 2] = h[i + 2] xor not(h[i + 3]) and h[i + 4]
      h[i + 3] = h[i + 3] xor not(h[i + 4]) and a
      h[i + 4] = h[i + 4] xor not(a) and b

proc xor_buffer(c: var SHA3) {.inline.} =
   for i in 0..<c.max_idx div 8:
      c.hash[i] = c.hash[i] xor 
                  cast[ptr uint64](addr(c.buffer[i*8]))[]
 
proc keccakf(c: var SHA3) =
   var y = 0
   if (c.rounds == 12): y = c.rounds
   for i in 0..<c.rounds:
      thera(c.hash)
      rho_pi(c.hash)
      chi(c.hash)
      c.hash[0] = c.hash[0] xor RC[i+y]
   c.buffer_idx = 0

proc sha3_update*(c: var SHA3, data: cstring|string|seq|uint8,
                  data_size: int) =
   for i in 0..<data_size:
      when data is cstring or data is string:
         c.buffer[c.buffer_idx] = data[i].uint8
      elif data is seq:
         c.buffer[c.buffer_idx] = data[i]
      else:
         c.buffer[c.buffer_idx] = data
      inc(c.buffer_idx)
      if (c.buffer_idx >= c.max_idx):
         xor_buffer(c)
         keccakf(c)

proc sha3_update*(c: var SHA3, data: openArray[byte]) =
  for i in 0..data.high:
    c.buffer[c.buffer_idx] = data[i]
    inc(c.buffer_idx)
    if (c.buffer_idx >= c.max_idx):
      xor_buffer(c)
      keccakf(c)

template sha3_init*(c: var SHA3, hash: typed, size: int = 0) =
   let rate = ord(hash) div 8
   if (size == 0):
      c.hash_size = rate
   else:
      c.hash_size = size
   if (hash is SHA3_HASH):
      assert(c.hash_size >= 1 and c.hash_size <= rate)
   c.rounds = 24
   c.max_idx = 200 - 2 * rate
   if (hash is SHA3_SHAKE):
      c.delim = 31
   else: 
      c.delim = 6

proc sha3_final*(c: var SHA3; result: var openArray[byte])=
   assert(result.len == c.hash_size)
   c.buffer[c.buffer_idx] = c.delim
   inc(c.buffer_idx)
   for i in c.buffer_idx..<c.max_idx: c.buffer[i] = 0
   c.buffer[c.max_idx - 1] = c.buffer[c.max_idx - 1] xor 128
   xor_buffer(c)
   keccakf(c)
   var j: int
   while (c.hash_size > 0):
      let block_size = min(c.hash_size, c.max_idx)
      for i in 0..<block_size:
         result[j] = (cast[uint8]((c.hash[i div 8] shr 
                    (8 * (i and 7)) and 0xFF)))
         inc j
      dec(c.hash_size, block_size)
      if (c.hash_size > 0): keccakf(c)
   zeroMem(addr(c), sizeof(c))
   assert(j == result.len)

proc sha3_final*(c: var SHA3): seq[uint8] =
   newSeq(result, c.hash_size)
   sha3_final(c, result)

proc sha3_init*(c: var Kangaroo12, size: int,
                key: cstring = nil, key_size: int = 0) =
   sha3_init(c.outer, SHA3_SHAKE128, size)
   sha3_init(c.inner, SHA3_SHAKE128, 32)
   c.outer.rounds = 12
   # c.outer.delim in sha3_final()
   c.inner.rounds = 12
   c.inner.delim = 11
   c.key = key
   c.key_size = key_size

proc sha3_update*(c: var Kangaroo12, data: cstring|string|seq|uint8,
                  data_size: int) =
   let P = @[3'u8, 0'u8, 0'u8, 0'u8, 0'u8, 0'u8, 0'u8, 0'u8]
   for i in 0..<data_size:
      if (c.current == 8192):
         if (c.chunks == 0):
            sha3_update(c.outer, P, 8)
         else:
            sha3_update(c.outer, sha3_final(c.inner), 32)
            sha3_init(c.inner, SHA3_SHAKE128, 32)
            c.inner.rounds = 12
            c.inner.delim = 11
         c.current = 0   
         inc(c.chunks)

      if (c.chunks == 0):
         when data is cstring or data is string:
            sha3_update(c.outer, ($data[i]).cstring, 1)
         elif data is seq:
            sha3_update(c.outer, data[i], 1)
         else:
            sha3_update(c.outer, data, 1)
      else:
         when data is cstring or data is string:
            sha3_update(c.inner, ($data[i]).cstring, 1)
         elif data is seq:
            sha3_update(c.inner, data[i], 1)
         else:
            sha3_update(c.inner, data, 1)

      inc(c.current)

proc sha3_final*(c: var Kangaroo12): seq[uint8] =
   let P = @[255'u8, 255'u8]
   var R: seq[uint8]
   result = @[]

   sha3_update(c, c.key, c.key_size)
   R = right_encode(uint64(c.key_size))
   sha3_update(c, R, len(R))

   if (c.chunks == 0):
      c.outer.delim = 7
   else:
      sha3_update(c.outer, sha3_final(c.inner), 32)
      R = right_encode(c.chunks)
      sha3_update(c.outer, R, len(R))
      sha3_update(c.outer, P, 2)
      c.outer.delim = 6
   result = sha3_final(c.outer)
   zeroMem(addr(c), sizeof(c))

proc `$`*(d: seq[uint8]): string =
  const digits = "0123456789abcdef"
  result = ""
  for i in 0..high(d):
    add(result, digits[(d[i].int shr 4) and 0xF])
    add(result, digits[d[i].int and 0xF])

template getSHA3*(h: typed, s: string, hash_size: int = 0): string =
   var ctx: SHA3
   sha3_init(ctx, h, hash_size)
   sha3_update(ctx, s, len(s))
   $sha3_final(ctx)

proc getSHA3*(s: string, hash_size: int = 0): string =
   var ctx: Kangaroo12
   sha3_init(ctx, hash_size)
   sha3_update(ctx, s, len(s))
   $sha3_final(ctx)

when isMainModule:
   import strutils

   var
      ktx: Kangaroo12
      f: File
      splt: seq[string]
      Aaa: cstring

   proc hex2str(s: string): string =
      result = ""
      for i in countup(0, high(s), 2):
         add(result, chr(parseHexInt(s[i] & s[i+1])))
   f = open("kangaroo-K12.rsp", fmRead)
   while true:
      try:
         sha3_init(ktx, 16, "abc", 3)
         splt = split(f.readLine(), ':')
         Aaa = cstring(repeatStr(parseInt(splt[0]), "abc"))
         for i in 0..high(Aaa):
            sha3_update(ktx, ($Aaa[i]).cstring, 1)
         assert($sha3_final(ktx) == toLowerAscii(splt[1]))
      except IOError: break
   close(f)
   assert(getSHA3("a", 16) == "9ead6b5332e658d12672d3ab0de17f12")
   assert(getSHA3(nil, 16) == "1ac2d450fc3b4205d19da7bfca1b3751")
   assert(getSHA3("abc", 4) == "ab174f32")
   assert(getSHA3("The quick brown fox jumps over the lazy dog") == "b4f249b4f77c58df170aa4d1723db112")
   
   var ctx: SHA3

   sha3_init(ctx, SHA3_224)
   assert("6b4e03423667dbb73b6e15454f0eb1abd4597f9a1b078e3f5b5a6bc7" == 
           $sha3_final(ctx))
   sha3_init(ctx, SHA3_256)
   assert("a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a" ==
           $sha3_final(ctx))
   sha3_init(ctx, SHA3_384)
   assert("0c63a75b845e4f7d01107d852e4c2485c51a50aaaa94fc61995e71bbee983a2ac3713831264adb47fb6bd1e058d5f004" ==
           $sha3_final(ctx))
   sha3_init(ctx, SHA3_512)
   assert("a69f73cca23a9ac5c8b567dc185a756e97c982164fe25859e0d1dcc1475c80a615b2123af1f5f94c11e3e9402c3ac558f500199d95b6d3e301758586281dcd26" ==
           $sha3_final(ctx)) 
   sha3_init(ctx, SHA3_SHAKE128)
   assert("7f9c2ba4e88f827d616045507605853e" ==
           $sha3_final(ctx))
   sha3_init(ctx, SHA3_SHAKE256)
   assert("46b9dd2b0ba88d13233b3feb743eeb243fcd52ea62b81b82b50c27646ed5762f" ==
           $sha3_final(ctx))

   sha3_init(ctx, SHA3_224, 4)
   sha3_update(ctx, "a", 1)
   sha3_update(ctx, "b", 1)
   sha3_update(ctx, "c", 1)
   assert(getSHA3(SHA3_224, "abc", 4) == $sha3_final(ctx))
   sha3_init(ctx, SHA3_256, 4)
   sha3_update(ctx, "a", 1)
   sha3_update(ctx, "b", 1)
   sha3_update(ctx, "c", 1)
   assert(getSHA3(SHA3_256, "abc", 4) == $sha3_final(ctx))
   sha3_init(ctx, SHA3_384, 4)
   sha3_update(ctx, "a", 1)
   sha3_update(ctx, "b", 1)
   sha3_update(ctx, "c", 1)
   assert(getSHA3(SHA3_384, "abc", 4) == $sha3_final(ctx))
   sha3_init(ctx, SHA3_512, 4)
   sha3_update(ctx, "a", 1)
   sha3_update(ctx, "b", 1)
   sha3_update(ctx, "c", 1)
   assert(getSHA3(SHA3_512, "abc", 4) == $sha3_final(ctx))
   assert("f4202e3c5852f9182a0430fd8144f0a7" ==
          getSHA3(SHA3_SHAKE128, "The quick brown fox jumps over the lazy dog"))
   assert("853f4538be0db9621a6cea659a06c110" ==
          getSHA3(SHA3_SHAKE128, "The quick brown fox jumps over the lazy dof"))
   assert("5881092dd818bf5cf8a3ddb793fbcba74097d5c526a6d35f97b83351940f2cc844c50af32acd3f2cdd066568706f509bc1bdde58295dae3f891a9a0fca5783789a41f8611214ce612394df286a62d1a2252aa94db9c538956c717dc2bed4f232a0294c857c730aa16067ac1062f1201fb0d377cfb9cde4c63599b27f3462bba4a0ed296c801f9ff7f57302bb3076ee145f97a32ae68e76ab66c48d51675bd49acc29082f5647584e6aa01b3f5af057805f973ff8ecb8b226ac32ada6f01c1fcd4818cb006aa5b4cdb3611eb1e533c8964cacfdf31012cd3fb744d02225b9" ==
          getSHA3(SHA3_SHAKE128, "abc", 222))

   var
      data, hash: string
      tst = newSeq[tuple[h: SHA3_HASH, f: string]]()
      shake = newSeq[tuple[h: SHA3_SHAKE, f:string]]()
   tst.add((h:SHA3_224, f:"SHA3_224ShortMsg.rsp"))
   tst.add((h:SHA3_224, f:"SHA3_224LongMsg.rsp"))
   tst.add((h:SHA3_256, f:"SHA3_256ShortMsg.rsp"))
   tst.add((h:SHA3_256, f:"SHA3_256LongMsg.rsp"))
   tst.add((h:SHA3_384, f:"SHA3_384ShortMsg.rsp"))
   tst.add((h:SHA3_384, f:"SHA3_384LongMsg.rsp"))
   tst.add((h:SHA3_512, f:"SHA3_512ShortMsg.rsp"))
   tst.add((h:SHA3_512, f:"SHA3_512LongMsg.rsp"))
   shake.add((h:SHA3_SHAKE128, f:"SHAKE128ShortMsg.rsp"))
   shake.add((h:SHA3_SHAKE128, f:"SHAKE128LongMsg.rsp"))
   shake.add((h:SHA3_SHAKE256, f:"SHAKE256ShortMsg.rsp"))
   shake.add((h:SHA3_SHAKE256, f:"SHAKE256LongMsg.rsp"))
   for t in tst:
      f = open(t[1], fmRead)
      while true:
         try:
            discard f.readLine()
            data = f.readLine()
            data = hex2str(data[6..^0])
            hash = f.readLine()
            hash = hash[5..^0]
            assert(getSHA3(t[0], data) == hash)
            sha3_init(ctx, t[0])
            for i in 0..high(data):
               sha3_update(ctx, ($data[i]).cstring, 1)
            assert($sha3_final(ctx) == hash)
            discard f.readLine()
         except IOError: break
      close(f)
   for t in shake:
      f = open(t[1], fmRead)
      while true:
         try:
            discard f.readLine()
            data = f.readLine()
            data = hex2str(data[6..^0])
            hash = f.readLine()
            hash = hash[9..^0]
            assert(getSHA3(t[0], data) == hash)
            sha3_init(ctx, t[0])
            for i in 0..high(data):
               sha3_update(ctx, ($data[i]).cstring, 1)
            assert($sha3_final(ctx) == hash)
            discard f.readLine()
         except IOError: break
      close(f)
   echo "ok"

{.pop.}

A  => src/sphincs/private/sphincs_shake256.nim +136 -0
@@ 1,136 @@
{.push checks: off.} # painfully slow

include ./sphincsbase.nim

import sha3

proc sha3_update(ctx: var SHA3; adrs: Address) =
  var buf: array[4, byte]
  for w in adrs.items:
    w.toByte(buf)
    sha3_update(ctx, buf)

proc F(pk: PK; adrs: Address; M1: Nbytes): Nbytes =
  var ctx: SHA3
  sha3_init(ctx, SHA3_SHAKE256, n)
  sha3_update(ctx, pk.seed)
  sha3_update(ctx, adrs)
  sha3_final(ctx, result)
    # create a bitmask in result space
  for i in 0..<n:
    result[i] = result[i] xor M1[i]
    # apply bitmask to message
  sha3_init(ctx, SHA3_SHAKE256, n)
  sha3_update(ctx, pk.seed)
  sha3_update(ctx, adrs)
  sha3_update(ctx, result)
    # hash again with bitmasked message
  sha3_final(ctx, result)

proc H(pk: PK; adrs: Address; M1, M2: Nbytes): Nbytes =
  var
    bitmasked: array[2*n, byte]
    ctx: SHA3
  sha3_init(ctx, SHA3_SHAKE256, n*2)
  sha3_update(ctx, pk.seed)
  sha3_update(ctx, adrs)
  sha3_final(ctx, bitmasked)
    # create bitmask
  for i in 0..<n:
    bitmasked[i] = bitmasked[i] xor M1[i]
  for i in n..<n*2:
    bitmasked[i] = bitmasked[i] xor M2[i-n]
    # apply bitmask to messages
  sha3_init(ctx, SHA3_SHAKE256, n)
  sha3_update(ctx, pk.seed)
  sha3_update(ctx, adrs)
  sha3_update(ctx, bitmasked)
    # hash again with bitmasked message
  sha3_final(ctx, result)

proc T_k(pk: PK; adrs: Address; M: array[k, Nbytes]): Nbytes =
  var
    bitmasked: array[M.len*n, byte]
    ctx: SHA3
  sha3_init(ctx, SHA3_SHAKE256, M.len*n)
  sha3_update(ctx, pk.seed)
  sha3_update(ctx, adrs)
  sha3_final(ctx, bitmasked)
    # create bitmask
  var off: int
  for i in 0..<M.len:
    for j in 0..<n:
      bitmasked[off] = bitmasked[off] xor M[i][j]
      inc off
    # apply bitmask to messages
  sha3_init(ctx, SHA3_SHAKE256, n)
  sha3_update(ctx, pk.seed)
  sha3_update(ctx, adrs)
  sha3_update(ctx, bitmasked)
    # hash again with bitmasked message
  sha3_final(ctx, result)

proc T_len(pk: PK; adrs: Address; M: array[wotsLen, Nbytes]): Nbytes =
  var
    bitmasked: array[M.len*n, byte]
    ctx: SHA3
  sha3_init(ctx, SHA3_SHAKE256, M.len*n)
  sha3_update(ctx, pk.seed)
  sha3_update(ctx, adrs)
  sha3_final(ctx, bitmasked)
    # create bitmask
  var off: int
  for i in 0..<M.len:
    for j in 0..<n:
      bitmasked[off] = bitmasked[off] xor M[i][j]
      inc off
    # apply bitmask to messages
  sha3_init(ctx, SHA3_SHAKE256, n)
  sha3_update(ctx, pk.seed)
  sha3_update(ctx, adrs)
  sha3_update(ctx, bitmasked)
    # hash again with bitmasked message
  sha3_final(ctx, result)

proc PRFmsg(sk: SK; optRand: Nbytes; M: string|openArray[byte]): Nbytes =
  ## Pseudorandom function to generate randomness for message compression.
  var ctx: SHA3
  sha3_init(ctx, SHA3_SHAKE256, n)
  sha3_update(ctx, sk.prf)
  sha3_update(ctx, optRand)
  sha3_update(ctx, M, M.len)
  sha3_final(ctx, result)

proc Hmsg(R: Nbytes; pk: PK; M: string|openArray[byte]): (array[partialDigestBytes,byte], int64, int32) =
  ## Keyed hash funcion for compression messages to be signed.
  var
    digest: array[m, byte]
    ctx: SHA3
  # TODO: the reference implementation does this wrong,
  # they overrun and hash some extra memory
  sha3_init(ctx, SHA3_SHAKE256, m)
  sha3_update(ctx, R)
  sha3_update(ctx, pk.seed)
  sha3_update(ctx, pk.root)
  sha3_update(ctx, M, M.len)
  sha3_final(ctx, digest)

  copyMem(result[0].addr, digest.addr, partialDigestBytes)

  # XXX
  # take the last bits from these regions, not the first bits as the spec might decribe
  bigEndian64(result[1].addr, digest[digest.len-8-leafIndexBytes].addr)
  when h - h div d < 64:
    result[1] = result[1] and (not(int64.high shl (h - h div d)))
  bigEndian32(result[2].addr, digest[digest.len-4].addr)
  result[2] = result[2] and (not(int32.high shl (h div d)))

proc PRF(sk: SK; adrs: Address): Nbytes =
  ## Pseudorandom function for key generation.
  var ctx: SHA3
  sha3_init(ctx, SHA3_SHAKE256, n)
  sha3_update(ctx, sk.seed)
  sha3_update(ctx, adrs)
  sha3_final(ctx, result)

include ./sphincsinstantiate.nim

A  => src/sphincs/private/sphincsbase.nim +147 -0
@@ 1,147 @@
import endians, math

proc lg(x: Natural): int =
  ## For x a non-negative real number returns the logarithm to base 2 of x.
  var x = x
  while x > 1:
    inc result
    x = x shr 1

proc toByte(x: Natural; result: var openArray[byte]) =
  ## 2.4. Integer to Byte Conversion
  ## For x and y non-negative integers, we define Z = toByte(x, y) to be the y-byte string
  ## containing the binary representation of x in big-endian byte-order.
  for i in countdown(result.high, 0):
    result[result.high-i] = byte((x shr (8*i)) and 0xff)

proc base_w(result: var openArray[int]; x: openarray[byte]; w: Natural) =
  ## 2.5. Strings of Base-w Numbers (Function base_w)
  ## A byte string can be considered as a string of base w numbers, i.e. integers in the set {0, . . . , w−
  ## 1}. The correspondence is defined by the function base_w(X, w, out_len) as follows. Let X be
  ## a len_X- byte string, and w is an element of the set {4, 16, 256}, then base_w(X, w, out_len)
  ## outputs an array of out_len integers between 0 and w − 1 (Figure 1). The length out_len is
  ## REQUIRED to be less than or equal to 8 ∗ len_X/ log(w).
  assert(w in {4, 16, 256})
  assert(result.len <= (8 * x.len div lg(w)))
  var
    bytesIn, bytesOut, bits: int
    total: uint
  for i in 0..result.high:
    if bits == 0:
      total = (uint)x[bytesIn]
      inc bytesIn
      bits.inc 8
    bits.dec lg(w)
    result[bytesOut] = (int)(total shr bits) and (w-1)
    inc bytesOut

proc bits(M: openArray[byte]; start, count: int): int64 =
  let
    A = start
    B = start + count
  for i in (A div 8)..(B div 8):
    result = (result shl 8) or M[i].int64
  let b = B mod 8
  if b != 0: result = result shr (8-b)
  let mask = not (result.high shl (B-A))
  result = result and mask

type
  AddressWord = uint32
  Address = array[8, AddressWord]
    ## SPHINCS⁺ tree address
  AddressType = enum
    WOTS_HASH = 0, WOTS_PK = 1, TREE = 2, FORS_TREE = 3, FORS_ROOTS = 4.AddressWord

proc initAdrs(t: AddressType): Address =
  result[4] = t.AddressWord

proc getLayerAddress(a: Address): int = a[0].int
proc setLayerAddress(a: var Address; i: int) = a[0] = i.AddressWord

proc getTreeAddress(a: Address): int = a[3].int
proc setTreeAddress(a: var Address; i: int64) =
  a[2] = (AddressWord)i shr 32
  a[3] = (AddressWord)i

proc getType(adrs: Address): AddressType = adrs[4].AddressType

proc setType(adrs: var Address; t: AddressType) =
  ## Change the type word of an address.
  adrs[4] = t.AddressWord
  adrs[5] = 0
  adrs[6] = 0
  adrs[7] = 0

proc setChainAddress(a: var Address; address: int) =
  assert(a.getType == WOTS_HASH)
  a[6] = (AddressWord)address

proc getKeyPairAddress(a: Address): int =
  assert(a.getType != TREE)
  a[5].int

proc setKeyPairAddress(a: var Address; keyPair: SomeInteger) =
  assert(a.getType != TREE)
  a[5] = (AddressWord)keyPair

proc setHashAddress(a: var Address; i: int) =
  assert(a.getType == WOTS_HASH)
  a[7] = i.AddressWord

proc getTreeHeight(a: Address): int =
  assert(a.getType() in {TREE, FORS_TREE})
  a[6].int

proc setTreeHeight(a: var Address; i : int) =
  assert(a.getType() in {TREE, FORS_TREE})
  a[6] = i.AddressWord

proc getTreeIndex(a: var Address): int =
  assert(a.getType() in {TREE, FORS_TREE})
  a[7].int

proc setTreeIndex(a: var Address; i: int) =
  assert(a.getType() in {TREE, FORS_TREE})
  a[7] = i.AddressWord

proc copySubTree(x: var Address; y: Address) =
  for i in 0..3:
    x[i] = y[i]

proc copyKeyPair(x: var Address; y: Address) =
  for i in 0..3:
    x[i] = y[i]
  x[5] = y[5]

type
  Nbytes* = array[n, byte]

  SK* = object {.packed.}
    seed*: Nbytes
    prf*: Nbytes
      ## Secret key

  PK* = object {.packed.}
    seed*: Nbytes
    root*: Nbytes
      ## Public key

  KeyPair* = object {.packed.}
    sk*: SK
    pk*: PK
      ## Secret and public Keypair.
      ## Both keys need to be retained for signing,
      ## the public key is not fully derived from the secret.

const
  wotsLen1 = (int)ceil(8*n / lg(w))
  wotsLen2 = (int)floor(lg(wotsLen1*(w-1)) / lg(w)) + 1
  wotsLen = wotsLen1+wotsLen2

  partialDigestBytes = (int)floor((k*a + 7) / 8)
  treeIndexBytes = (int)floor((h - h/d + 7) / 8)
  leafIndexBytes = (int)floor((h/d + 7) / 8)

  m = partialDigestBytes + treeIndexBytes + leafIndexBytes
    ## Output length of Hmsg in bytes.

A  => src/sphincs/private/sphincsinstantiate.nim +406 -0
@@ 1,406 @@
#
# WOTS + One-Time Signatures
#

proc chain(input: Nbytes, i, s: int; pk: PK; adrs: Address): Nbytes =
  ## Compute an iteration of F on a n-byte input using
  ## a WOTS+ hash address and a public seed.
  var adrs = adrs
  assert((i+s) <= (w-1))
  result = input
  for j in i..<min(i+s, w):
    adrs.setHashAddress(j)
    result = F(pk, adrs, result)

proc wots_SKgen(sk: SK; adrs: Address): Nbytes =
  var adrs = adrs
  adrs.setHashAddress(0)
  PRF(sk, adrs)

proc wots_PKgen(sk: Sk; pk: PK; adrs: Address): array[wotsLen, Nbytes] =
  ## Generate a WOTS+ public key.
  var adrs = adrs
  for i in 0..<wotsLen:
    adrs.setChainAddress(i)
    let sk = wots_SKgen(sk, adrs)
    result[i] = chain(sk, 0, w - 1, pk, adrs)

proc wots_checksum(lengths: openArray[int]): array[wotsLen2, int] =
  var csum = 0

  for i in 0..<wotsLen1:
    csum = csum + w - 1 - lengths[i]
    # compute checksum

  csum = csum shl (8 - ((wotsLen2 * lg(w) ) mod 8))
    # convert csum to base w

  const len2Bytes = (int)ceil( float( wotsLen2 * lg(w) ) / 8)
  var b: array[len2Bytes, byte]
  csum.toByte(b)
  base_w(result, b, w)

proc wots_sign(M: Nbytes; sk: SK; pk: PK; adrs: Address): array[wotsLen, Nbytes] =
  ## Generate a WOTS+ signature on message M.
  var
    adrs = adrs
    lengths: array[wotsLen1, int]
  base_w(lengths, M, w)
    # convert message to base w

  for i in 0..<wotsLen1:
    adrs.setChainAddress(i)
    let prf = PRF(sk, adrs)
    result[i] = chain(prf, 0, lengths[i], pk, adrs)

  let csum = lengths.wots_checksum
  for i in wotsLen1..<wotsLen:
    adrs.setChainAddress(i)
    let prf = PRF(sk, adrs)
    result[i] = chain(prf, 0, csum[i-wotsLen1], pk, adrs)

proc wots_pkFromSig(sig: array[wotsLen,Nbytes]; M: Nbytes; pk: PK; adrs: var Address): array[wotsLen, Nbytes] =
  ## Compute a WOTS+ public key from a message and its signature.
  var lengths: array[wotsLen1, int]
  base_w(lengths, M, w)
    # convert message to base w

  for i in 0..<wotsLen1:
    adrs.setChainAddress(i)
    result[i] = chain(sig[i], lengths[i], w - 1 - lengths[i], pk, adrs)

  let csum = lengths.wots_checksum
  for i in wotsLen1..<wotsLen:
    adrs.setChainAddress(i)
    result[i] = chain(sig[i], csum[i-wotsLen1], w - 1 - csum[i-wotsLen1], pk, adrs)

#
# The SPHINCS + Hypertree
#

type GenLeafProc = proc(sk: SK; pk: PK; idx: int; adrs: Address): Nbytes

proc isOdd(x: int): bool {.inline.} = bool(x and 1)

type
  TreeNode = tuple[node: Nbytes; height: int]

proc treeHash(root: var Nbytes; authPath: var openArray[Nbytes];
              stack: var openArray[TreeNode];
              sk: SK; pk: PK, leafIdx, idxOffset: int;
              genLeaf: GenLeafProc; adrs: Address) =
  ## For a given leaf index, computes the authentication path and
  ## the resulting root node using Merkle's TreeHash algorithm.
  var
    treeAdrs = adrs
    offset, idx, treeIdx: int

  for idx in 0..<(1 shl (stack.len-1)):
    stack[offset].node = genLeaf(sk, pk, idx+idxOffset, treeAdrs)
      # Add the next leaf node to the stack
    stack[offset].height = 0
    inc offset
    if (leafIdx xor 0x1) == idx:
      # if this is a node we need it for the auth path
      authPath[0] = stack[offset-1].node
    while (1 < offset) and (stack[offset-1].height == stack[offset-2].height):
      # while the top-most nodes are of equal height...
      treeIdx = idx shr (stack[offset-1].height+1)
        # compute index of the new node, in the new layer

      treeAdrs.setTreeHeight(stack[offset-1].height + 1)
      treeAdrs.setTreeIndex(treeIdx + (idxOffset shr (stack[offset-1].height + 1)))
        # set the address of the node we're creating
      stack[offset-2].node = H(pk, treeAdrs, stack[offset-2].node, stack[offset-1].node)
        # hash the top-most nodes from the stack together
      dec offset
      inc stack[offset-1].height
        # note that the top-most node is now one layer higher
      if ((leafIdx shr stack[offset-1].height) xor 0x1) == treeIdx:
        # if this is a node we need for the auth path...
        authPath[stack[offset - 1].height] = stack[offset-1].node
  root = stack[0].node

proc wotsGenLeaf(sk: SK; pk: PK; adrsIdx: int; treeAdrs: Address): Nbytes =
  ## Computes the leaf at a given address. First generates the WOTS
  ## key pair, then computes leaf by hashing horizontally.
  var
    wotsAdrs = initAdrs(WOTS_HASH)
    wotsPkAdrs = initAdrs(WOTS_PK)
  wotsAdrs.copySubTree(treeAdrs)
  wotsAdrs.setKeyPairAddress(adrsIdx)
  wotsPkAdrs.copyKeyPair(wotsAdrs)
  let wotsPk = wots_PKgen(sk, pk, wotsAdrs)
  T_len(pk, wotsPkAdrs, wotsPk)

proc xmss_PKgen(sk: SK; pk: PK; adrs: Address): Nbytes =
  ## Generate an XMSS public key.
  # 4.1.4. XMSS Public Key Generation
  const height = h div d
  var auth: array[height, Nbytes]
    # not used, but `treeHash` computes both a root and authPath
  var treeStack: array[height+1, TreeNode]
  treeHash(result, auth, treeStack, sk, pk, 0, 0, wotsGenLeaf, adrs)

type XmssSignature = object {.packed.}
  # 4.1.5. XMSS Signature
  sig: array[wotsLen, Nbytes]
  auth: array[h div d, Nbytes]

# HT: The Hypertee (sic)

proc ht_PKgen(sk: SK; pk: PK): Nbytes =
  ## Generate an HT public key.
  var adrs = initAdrs(TREE)
  adrs.setLayerAddress(d-1)
  xmss_PKgen(sk, pk, adrs)

type HtSignature = array[d, XmssSignature]

#
# FORS: Forest Of Random Subsets
#

proc computeRoot(leaf: Nbytes; leafIdx, idxOffset: int;
                 authPath: openArray[Nbytes],
                 height: int;
                 pk: PK; adrs: Address): Nbytes =
  var
    nodes: (Nbytes, Nbytes)
    adrs = adrs
    leafIdx = leafIdx
    idxOffset = idxOffset

  if leafIdx.isOdd:
    # If leafIdx is odd, current path element is a right child
    # and authPath has to go left. 
    nodes[1] = leaf
    nodes[0] = authPath[0]
  else:
    # Otherwise it is the other way around.
    nodes[0] = leaf
    nodes[1] = authPath[0]

  for i in 0..(height-2):
    leafIdx = leafIdx shr 1
    idxOffset = idxOffset shr 1
    adrs.setTreeHeight(i+1)
    adrs.setTreeIndex(leafIdx+idxOffset)
    if leafIdx.isOdd:
      nodes[1] = H(pk, adrs, nodes[0], nodes[1])
      nodes[0] = authPath[i+1]
    else:
      nodes[0] = H(pk, adrs, nodes[0], nodes[1])
      nodes[1] = authPath[i+1]
  leafIdx = leafIdx shr 1
  idxOffset = idxOffset shr 1
  adrs.setTreeHeight(height)
  adrs.setTreeIndex(leafIdx+idxOffset)
  H(pk, adrs, nodes[0], nodes[1])
    # the last iteration is exceptional; we do not copy an authPath node

const
  forsHeight = a
  forsTrees = k
  forsMsgBytes = (forsHeight*forsTrees+7) div 8

proc fors_SKgen(sk: SK; adrs: Address): Nbytes =
  ## Compute a FORS private key value
  # 5.2. FORS Private Key
  PRF(sk, adrs)

proc messageIndices(msg: openArray[byte]): array[forsTrees, int] =
  #assert(msg.len > forsHeight*forsTrees div 8)
  var offset: int
  for i in 0..<forsTrees:
    for _ in 1..forsHeight:
      result[i] =  (result[i] shl 1) xor ((msg[offset shr 3].int shr (offset and 0x7)) and 0x1)
      inc offset

type ForsSignature = array[k, tuple[
  key: Nbytes,
  auth: array[a, Nbytes]]]

proc fors_SKtoLeaf(pk: PK, adrs: var Address; sk: Nbytes): Nbytes =
  F(pk, adrs, sk)

proc forsGenLeaf(sk: SK; pk: PK; addrIdx: int; adrs: Address): Nbytes =
  ## Procedure for generating leaves of FORS tree.
  var forsLeafAdrs = adrs
  forsLeafAdrs.setType(FORS_TREE)
  forsLeafAdrs.setTreeIndex(addrIdx)
  forsLeafAdrs.setKeyPairAddress(adrs.getKeyPairAddress)
  result = fors_SKtoLeaf(pk, forsLeafAdrs, fors_SKgen(sk, forsLeafAdrs))

proc forsSign(sig: var ForsSignature; public: var Nbytes; msg: openArray[byte]; sk: SK; pk: PK; adrs: Address) =
  ## Generate a FORS signature and public key on n-byte string M.
  let indices = messageIndices msg
  var
    roots: array[forsTrees, Nbytes]
    forsTreeAdrs = adrs
    forsPkAdrs = adrs
  forsTreeAdrs.setType(FORS_TREE)
  forsTreeAdrs.setKeyPairAddress(adrs.getKeyPairAddress)
  forsPkAdrs.setType(FORS_ROOTS)
  forsPkAdrs.setKeyPairAddress(adrs.getKeyPairAddress)

  for i in 0..<k:
    let idxOff = i * (1 shl forsHeight)
    forsTreeAdrs.setTreeHeight(0)
    forsTreeAdrs.setTreeIndex(indices[i] + idxOff)
    sig[i].key = fors_SKgen(sk, forsTreeAdrs)
    var treeStack: array[forsHeight+1, TreeNode]
    treeHash(roots[i], sig[i].auth, treeStack, sk, pk, indices[i], idxOff, forsGenLeaf, forsTreeAdrs)

  public = T_k(pk, forsPkAdrs, roots)
    # Hash horizontally across all tree roots to derive the public key.

proc fors_pkFromSig(sig: ForsSignature; msg: openArray[byte]; pk: PK; adrs: var Address): Nbytes =
  ## Compute a FORS public key from a FORS signature
  let indices = messageIndices msg
  var
    roots: array[forsTrees, Nbytes]
    forsTreeAdrs = adrs
    forsPkAdrs = adrs

  forsTreeAdrs.setType(FORS_TREE)
  forsTreeAdrs.setKeyPairAddress(adrs.getKeyPairAddress)

  forsPkAdrs.setType(FORS_ROOTS)
  forsPkAdrs.setKeyPairAddress(adrs.getKeyPairAddress)

  for i in 0..<forsTrees:
    let idxOff = i * (1 shl forsHeight)
    forsTreeAdrs.setTreeHeight(0)
    forsTreeAdrs.setTreeIndex(indices[i] + idxOff)

    let leaf = fors_SKtoLeaf(pk, forsTreeAdrs, sig[i].key)
      # derive the leaf from the included secret key part

    roots[i] = computeRoot(leaf, indices[i], idxOff, sig[i].auth, a, pk, forsTreeAdrs)
      # derive the corresponding root node of this tree

  T_k(pk, forsPkAdrs, roots)
    # Hash horizontally across all tree roots to derive the public key.

const
  spxTreeHeight = h div d
  signatureSize* = n + k*(n+a*n) + d*(wotsLen*n+(h div d)*n)
    ## Size of SPHINCS⁺ signture minus the message. The message
    ## is appended during signing so it should be no longer than a
    ## hash digest.

type
  SpxSignature = object {.packed.}
    R: Nbytes
    FORS: ForsSignature
    HT: HtSignature

  RandomBytes* = proc(buf: pointer; size: int)
    ## Procedure type for collecting entropy during
    ## key generation and signing. Please supply
    ## a procedure that writes `size` random bytes to `buf`.

{.pop.} # allow runtime checks

proc sign*(pair: KeyPair; M: string|openArray[byte]; optRand: Nbytes): string {.noSideEffect.} =
  ## Generate a SPHINCS⁺ signature.
  ## The signature will be deterministic unless `optRand` is randomized.
  let msgOff = sizeof(SpxSignature)
  result = newString(msgOff+M.len)

  let sig = cast[ptr SpxSignature](result[0].addr)
  sig.R = PRFmsg(pair.sk, optRand, M)
    # generate randomizer

  let (md, mTree, mLeaf) = Hmsg(sig.R, pair.pk, M)
  var
    root: Nbytes
    treeAdrs = initAdrs(TREE)
    wotsAdrs = initAdrs(WOTS_HASH)

  wotsAdrs.setTreeAddress(mTree)
  wotsAdrs.setKeyPairAddress(mLeaf)
  forsSign(sig.FORS, root, md, pair.sk, pair.pk, wotsAdrs)
    # FORS sign
  block:
    var
      idxTree = mTree
      idxLeaf = mLeaf
    for i in 0..<d:
      treeAdrs.setLayerAddress(i)
      treeAdrs.setTreeAddress(idxTree)
      wotsAdrs.copySubtree(treeAdrs)
      wotsAdrs.setKeypairAddress(idxLeaf)

      sig.HT[i].sig = wots_sign(root, pair.sk, pair.pk, wotsAdrs)
      var treeStack: array[spxTreeHeight+1, TreeNode]
      treeHash(root, sig.HT[i].auth, treeStack, pair.sk, pair.pk,
        idxLeaf, 0, wotsGenLeaf, treeAdrs)

      idxLeaf = (int32)idxTree and ((1 shl spxTreeHeight) - 1)
      idxTree = idxTree shr spxTreeHeight
        # update the indices for the next layer
  for i in 0..M.len:
    result[msgOff+i] = (char)M[i]
    # append signature with message

proc sign*(pair: KeyPair; M: string|openArray[byte]; rand: RandomBytes): string {.noSideEffect.} =
  ## Generate a SPHINCS⁺ signature. The passed `rand` procedure is used to
  ## create non-deterministic signatures which are generally recommended.
  var optRand: Nbytes
  rand(optRand.addr, n)
  pair.sign(M, optRand)

proc verify(pk: PK; sigStr: var string): (bool, string) {.noSideEffect.} =
  assert(sigStr.len > sizeof(SpxSignature))
  let
    sig = cast[ptr SpxSignature](sigStr[0].addr)
    M = sigStr[sizeof(SpxSignature)..sigStr.high]
  var
    root, leaf: Nbytes
    wotsAdrs = initAdrs(WOTS_HASH)
    treeAdrs = initAdrs(TREE)
    wotsPkAdrs = initAdrs(WOTS_PK)
    (md, idxTree, idxLeaf) = Hmsg(sig.R, pk, M)

  wotsAdrs.setTreeAddress(idxTree)
  wotsAdrs.setKeyPairAddress(idxLeaf)

  root = fors_pkFromSig(sig.FORS, md, pk, wotsAdrs)

  for i in 0..<d:
    # for each subtree
    treeAdrs.setLayerAddress(i)
    treeAdrs.setTreeAddress(idxTree)

    wotsAdrs.copySubtree(treeAdrs)
    wotsAdrs.setKeypairAddress(idxLeaf)
    wotsPkAdrs.copyKeyPair(wotsAdrs)

    let wotsPk = wots_pkFromSig(sig.HT[i].sig, root, pk, wotsAdrs)
    leaf = T_len(pk, wotsPkAdrs, wotsPk)
    root = computeRoot(leaf, idxLeaf, 0, sig.HT[i].auth, h div d, pk, treeAdrs)

    idxLeaf = (int32)idxTree and ((1 shl spxTreeHeight) - 1)
    idxTree = idxTree shr spxTreeHeight
      # update the indices for the next layer

  if root == pk.root:
    (true, M)
  else:
    (false, "")

proc verify*(pk: PK; sig: string): (bool, string) {.noSideEffect.} =
  ## Verify a SPHINCS⁺ signature.
  ## The signed message is assumed to be stored at the
  ## end of the signature string.
  var sig = sig
  pk.verify sig

proc generateKeypair*(seedProc: RandomBytes): KeyPair {.noSideEffect.} =
  ## Generate a SPHINCS⁺ key pair.
  seedProc(result.addr, n*3)
    # Randomize the seeds and PRF
  result.pk.root = ht_PKgen(result.sk, result.pk)
    # Compute root node of top-most subtree

A  => src/sphincs/shake256_128f.nim +13 -0
@@ 1,13 @@
##
## Fast SPHINCS⁺ SHAKE256 scheme with 128-bit security and 16976 byte signatures.
##

const
  n* = 16 ## The security parameter in bytes.
  h* = 60 ## The height of the hypertree.
  d* = 20 ## The number of layers in the hypertree.
  a* = 9 ## The log of the number of leaves of a FORS tree.
  k* = 30 ## The number of trees in FORS.
  w* = 16 ## The Winternitz parameter.

include private/sphincs_shake256

A  => src/sphincs/shake256_128s.nim +13 -0
@@ 1,13 @@
##
## Small SPHINCS⁺ SHAKE256 scheme with 133-bit security and 8080 byte signatures.
##

const
  n* = 16 ## The security parameter in bytes.
  h* = 64 ## The height of the hypertree.
  d* = 8 ## The number of layers in the hypertree.
  a* = 15 ## The log of the number of leaves of a FORS tree.
  k* = 10 ## The number of trees in FORS.
  w* = 16 ## The Winternitz parameter.

include private/sphincs_shake256

A  => src/sphincs/shake256_192f.nim +13 -0
@@ 1,13 @@
##
## Fast SPHINCS⁺ SHAKE256 scheme with 194-bit security and 34664 byte signatures.
##

const
  n* = 24 ## The security parameter in bytes.
  h* = 66 ## The height of the hypertree.
  d* = 22 ## The number of layers in the hypertree.
  a* = 8 ## The log of the number of leaves of a FORS tree.
  k* = 33 ## The number of trees in FORS.
  w* = 16 ## The Winternitz parameter.

include ./private/sphincs_shake256

A  => src/sphincs/shake256_192s.nim +13 -0
@@ 1,13 @@
##
## Small SPHINCS⁺ SHAKE256 scheme with 196-bit security and 17064 byte signatures.
##

const
  n* = 24 ## The security parameter in bytes.
  h* = 64 ## The height of the hypertree.
  d* = 8 ## The number of layers in the hypertree.
  a* = 16 ## The log of the number of leaves of a FORS tree.
  k* = 14 ## The number of trees in FORS.
  w* = 16 ## The Winternitz parameter.

include ./private/sphincs_shake256

A  => src/sphincs/shake256_256f.nim +13 -0
@@ 1,13 @@
##
## Fast SPHINCS⁺ SHAKE256 scheme with 254-bit security and 49216 byte signatures.
##

const
  n* = 32 ## The security parameter in bytes.
  h* = 68 ## The height of the hypertree.
  d* = 17 ## The number of layers in the hypertree.
  a* = 10 ## The log of the number of leaves of a FORS tree.
  k* = 30 ## The number of trees in FORS.
  w* = 16 ## The Winternitz parameter.

include ./private/sphincs_shake256

A  => src/sphincs/shake256_256s.nim +13 -0
@@ 1,13 @@
##
## Small SPHINCS⁺ SHAKE256 scheme with 255-bit security and 29792 byte signatures.
##

const
  n* = 32 ## The security parameter in bytes.
  h* = 64 ## The height of the hypertree.
  d* = 8 ## The number of layers in the hypertree.
  a* = 14 ## The log of the number of leaves of a FORS tree.
  k* = 22 ## The number of trees in FORS.
  w* = 16 ## The Winternitz parameter.

include ./private/sphincs_shake256

A  => tests/hex.nim +74 -0
@@ 1,74 @@
#[
The MIT License (MIT)

Copyright (c) 2014 Eric S. Bullington

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
]#

proc nibbleFromChar(c: char): int =
  case c
  of '0'..'9': result = (ord(c) - ord('0'))
  of 'a'..'f': result = (ord(c) - ord('a') + 10)
  of 'A'..'F': result = (ord(c) - ord('A') + 10)
  else:
    raise newException(ValueError, "invalid hexadecimal encoding")

proc decode*[T: char|int8|uint8](str: string; result: var openArray[T]) =
  assert(result.len == str.len div 2)
  for i in 0..<result.len:
    result[i] = T((nibbleFromChar(str[2 * i]) shl 4) or nibbleFromChar(str[2 * i + 1]))

proc decode*(str: string): string =
  result = newString(len(str) div 2)
  decode(str, result)

proc nibbleToChar(nibble: int): char =
  const byteMap = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f']
  const byteMapLen = len(byteMap)
  if nibble < byteMapLen:
    return byteMap[nibble];

template encodeTmpl(str: untyped): typed =
  let length = (len(str))
  result = newString(length * 2)
  for i in str.low..str.high:
    let a = ord(str[i]) shr 4
    let b = ord(str[i]) and ord(0x0f)
    result[i * 2] = nibbleToChar(a)
    result[i * 2 + 1] = nibbleToChar(b)

proc encode*(bin: string): string =
  encodeTmpl(bin)

proc encode*(bin: openarray[char|int8|uint8]): string =
  encodeTmpl(bin)

when isMainModule:
  assert encode("The sun so bright it leaves no shadows") == "5468652073756e20736f20627269676874206974206c6561766573206e6f20736861646f7773"
  const longText = """Man is distinguished, not only by his reason, but by this
    singular passion from other animals, which is a lust of the mind,
    that by a perseverance of delight in the continued and indefatigable
    generation of knowledge, exceeds the short vehemence of any carnal
    pleasure."""
  assert encode(longText) == "4d616e2069732064697374696e677569736865642c206e6f74206f6e6c792062792068697320726561736f6e2c2062757420627920746869730a2020202073696e67756c61722070617373696f6e2066726f6d206f7468657220616e696d616c732c2077686963682069732061206c757374206f6620746865206d696e642c0a20202020746861742062792061207065727365766572616e6365206f662064656c6967687420696e2074686520636f6e74696e75656420616e6420696e6465666174696761626c650a2020202067656e65726174696f6e206f66206b6e6f776c656467652c2065786365656473207468652073686f727420766568656d656e6365206f6620616e79206361726e616c0a20202020706c6561737572652e"
  const tests = ["", "abc", "xyz", "man", "leisure.", "sure.", "erasure.",
                 "asure.", longText]
  for t in items(tests):
    assert decode(encode(t)) == t

A  => tests/nim.cfg +1 -0
@@ 1,1 @@
-d:release --path:"../src/" --cincludes:"."

A  => tests/rng.c +222 -0
@@ 1,222 @@
//
//  rng.c
//
//  Created by Bassham, Lawrence E (Fed) on 8/29/17.
//  Copyright © 2017 Bassham, Lawrence E (Fed). All rights reserved.
//

#include <string.h>
#include "rng.h"
#include <openssl/conf.h>
#include <openssl/evp.h>
#include <openssl/err.h>

AES256_CTR_DRBG_struct  DRBG_ctx;

void    AES256_ECB(unsigned char *key, unsigned char *ctr, unsigned char *buffer);

/*
 seedexpander_init()
 ctx            - stores the current state of an instance of the seed expander
 seed           - a 32 byte random value
 diversifier    - an 8 byte diversifier
 maxlen         - maximum number of bytes (less than 2**32) generated under this seed and diversifier
 */
int
seedexpander_init(AES_XOF_struct *ctx,
                  unsigned char *seed,
                  unsigned char *diversifier,
                  unsigned long maxlen)
{
    if ( maxlen >= 0x100000000 )
        return RNG_BAD_MAXLEN;
    
    ctx->length_remaining = maxlen;
    
    memcpy(ctx->key, seed, 32);
    
    memcpy(ctx->ctr, diversifier, 8);
    ctx->ctr[11] = maxlen % 256;
    maxlen >>= 8;
    ctx->ctr[10] = maxlen % 256;
    maxlen >>= 8;
    ctx->ctr[9] = maxlen % 256;
    maxlen >>= 8;
    ctx->ctr[8] = maxlen % 256;
    memset(ctx->ctr+12, 0x00, 4);
    
    ctx->buffer_pos = 16;
    memset(ctx->buffer, 0x00, 16);
    
    return RNG_SUCCESS;
}

/*
 seedexpander()
    ctx  - stores the current state of an instance of the seed expander
    x    - returns the XOF data
    xlen - number of bytes to return
 */
int
seedexpander(AES_XOF_struct *ctx, unsigned char *x, unsigned long xlen)
{
    unsigned long   offset;
    
    if ( x == NULL )
        return RNG_BAD_OUTBUF;
    if ( xlen >= ctx->length_remaining )
        return RNG_BAD_REQ_LEN;
    
    ctx->length_remaining -= xlen;
    
    offset = 0;
    while ( xlen > 0 ) {
        if ( xlen <= (16-ctx->buffer_pos) ) { // buffer has what we need
            memcpy(x+offset, ctx->buffer+ctx->buffer_pos, xlen);
            ctx->buffer_pos += xlen;
            
            return RNG_SUCCESS;
        }
        
        // take what's in the buffer
        memcpy(x+offset, ctx->buffer+ctx->buffer_pos, 16-ctx->buffer_pos);
        xlen -= 16-ctx->buffer_pos;
        offset += 16-ctx->buffer_pos;
        
        AES256_ECB(ctx->key, ctx->ctr, ctx->buffer);
        ctx->buffer_pos = 0;
        
        //increment the counter
        for (int i=15; i>=12; i--) {
            if ( ctx->ctr[i] == 0xff )
                ctx->ctr[i] = 0x00;
            else {
                ctx->ctr[i]++;
                break;
            }
        }
        
    }
    
    return RNG_SUCCESS;
}


void handleErrors(void)
{
    ERR_print_errors_fp(stderr);
    abort();
}

// Use whatever AES implementation you have. This uses AES from openSSL library
//    key - 256-bit AES key
//    ctr - a 128-bit plaintext value
//    buffer - a 128-bit ciphertext value
void
AES256_ECB(unsigned char *key, unsigned char *ctr, unsigned char *buffer)
{
    EVP_CIPHER_CTX *ctx;
    
    int len;
    
    int ciphertext_len;
    
    /* Create and initialise the context */
    if(!(ctx = EVP_CIPHER_CTX_new())) handleErrors();
    
    if(1 != EVP_EncryptInit_ex(ctx, EVP_aes_256_ecb(), NULL, key, NULL))
        handleErrors();
    
    if(1 != EVP_EncryptUpdate(ctx, buffer, &len, ctr, 16))
        handleErrors();
    ciphertext_len = len;
    
    /* Clean up */
    EVP_CIPHER_CTX_free(ctx);
}

void
randombytes_init(unsigned char *entropy_input,
                 unsigned char *personalization_string,
                 int security_strength)
{
    unsigned char   seed_material[48];
    
    memcpy(seed_material, entropy_input, 48);
    if (personalization_string)
        for (int i=0; i<48; i++)
            seed_material[i] ^= personalization_string[i];
    memset(DRBG_ctx.Key, 0x00, 32);
    memset(DRBG_ctx.V, 0x00, 16);
    AES256_CTR_DRBG_Update(seed_material, DRBG_ctx.Key, DRBG_ctx.V);
    DRBG_ctx.reseed_counter = 1;
}

int
randombytes(unsigned char *x, unsigned long long xlen)
{
    unsigned char   block[16];
    int             i = 0;
    
    while ( xlen > 0 ) {
        //increment V
        for (int j=15; j>=0; j--) {
            if ( DRBG_ctx.V[j] == 0xff )
                DRBG_ctx.V[j] = 0x00;
            else {
                DRBG_ctx.V[j]++;
                break;
            }
        }
        AES256_ECB(DRBG_ctx.Key, DRBG_ctx.V, block);
        if ( xlen > 15 ) {
            memcpy(x+i, block, 16);
            i += 16;
            xlen -= 16;
        }
        else {
            memcpy(x+i, block, xlen);
            xlen = 0;
        }
    }
    AES256_CTR_DRBG_Update(NULL, DRBG_ctx.Key, DRBG_ctx.V);
    DRBG_ctx.reseed_counter++;
    
    return RNG_SUCCESS;
}

void
AES256_CTR_DRBG_Update(unsigned char *provided_data,
                       unsigned char *Key,
                       unsigned char *V)
{
    unsigned char   temp[48];
    
    for (int i=0; i<3; i++) {
        //increment V
        for (int j=15; j>=0; j--) {
            if ( V[j] == 0xff )
                V[j] = 0x00;
            else {
                V[j]++;
                break;
            }
        }
        
        AES256_ECB(Key, V, temp+16*i);
    }
    if ( provided_data != NULL )
        for (int i=0; i<48; i++)
            temp[i] ^= provided_data[i];
    memcpy(Key, temp, 32);
    memcpy(V, temp+32, 16);
}










A  => tests/rng.h +55 -0
@@ 1,55 @@
//
//  rng.h
//
//  Created by Bassham, Lawrence E (Fed) on 8/29/17.
//  Copyright © 2017 Bassham, Lawrence E (Fed). All rights reserved.
//

#ifndef rng_h
#define rng_h

#include <stdio.h>

#define RNG_SUCCESS      0
#define RNG_BAD_MAXLEN  -1
#define RNG_BAD_OUTBUF  -2
#define RNG_BAD_REQ_LEN -3

typedef struct {
    unsigned char   buffer[16];
    int             buffer_pos;
    unsigned long   length_remaining;
    unsigned char   key[32];
    unsigned char   ctr[16];
} AES_XOF_struct;

typedef struct {
    unsigned char   Key[32];
    unsigned char   V[16];
    int             reseed_counter;
} AES256_CTR_DRBG_struct;


void
AES256_CTR_DRBG_Update(unsigned char *provided_data,
                       unsigned char *Key,
                       unsigned char *V);

int
seedexpander_init(AES_XOF_struct *ctx,
                  unsigned char *seed,
                  unsigned char *diversifier,
                  unsigned long maxlen);

int
seedexpander(AES_XOF_struct *ctx, unsigned char *x, unsigned long xlen);

void
randombytes_init(unsigned char *entropy_input,
                 unsigned char *personalization_string,
                 int security_strength);

int
randombytes(unsigned char *x, unsigned long long xlen);

#endif /* rng_h */

A  => tests/test.nim +99 -0
@@ 1,99 @@
import sphincs/shake256_128f
import sphincs/shake256_128s
import sphincs/shake256_192f
import sphincs/shake256_192s
import sphincs/shake256_256f
import sphincs/shake256_256s

import parseutils, strutils, strtabs, unittest
import ./hex

{.compile: "rng.c".}
{.passL: "-lcrypto".}

proc randombytes_init(entropy_input, personalization_string: ptr cuchar;
                      security_strength: cint) {.importc, header:"rng.h".}
  ## Initialize the reference RNG.

proc randombytes(x: ptr cuchar; xlen: culonglong): cint {.importc, header:"rng.h".}
  ## Collect entropy from the reference RNG.

proc randomBytes(p: pointer; size: Natural) =
  let r = randombytes(cast[ptr cuchar](p), (culonglong)size)
  doAssert(r == 0, "reference randombytes failed")

proc zeroBytes(p: pointer; size: Natural) =
  zeroMem(p, size)

proc parseHex(result: var string; val: string) =
  result.setLen(val.len div 2)
  hex.decode(val, result)

proc parseHex(result: var seq[byte]; val: string) =
  result.setLen(val.len div 2)
  hex.decode(val, result)

template katTest(path: string; keyGen: untyped) =
  ## Use the keyGen procedure to select the scheme implementation.
  suite path:
    var
      count, mlen, smlen: int
      key = ""
      val = ""
      msg = ""
      sm = ""
      buf = newSeq[byte]()
      pair = keyGen(zeroBytes)
    for line in lines(path):
      if line == "":
        key.setLen(0)
        val.setLen(0)
        buf.setLen(0)
        pair = keyGen(zeroBytes)
      else:
        key.setLen(0)
        let off = line.parseUntil(key, " = ")
        if not key.validIdentifier:
          continue
        discard line.parseWhile(val, HexDigits, off+3)
        case key
        of "count":
          count = parseInt val
        of "seed":
          buf.parseHex(val)
          doAssert(buf.len == 48)
          randombytes_init(cast[ptr cuchar](buf[0].addr), nil, 256)
          pair = keyGen(randombytes)
        of "pk":
          buf.parseHex(val)
          doAssert(equalMem(pair.pk.addr, buf[0].addr, buf.len))
        of "sk":
          buf.parseHex(val)
          doAssert(equalMem(pair.sk.addr, buf[0].addr, buf.len))
        of "mlen":
          mlen = parseInt val
        of "msg":
          msg.parseHex(val)
          doAssert(msg.len == mlen)
        of "smlen":
          smlen = parseInt val
        of "sm":
          # verification is faster than signing, so do that first
          sm.parseHex(val)
          doAssert(sm.len == smlen)
          test "verify " & $count:
            let (valid, M) = pair.pk.verify(sm)
            doAssert(valid)
            doAssert(M == msg)
          test "sign " & $count:
            let sig = pair.sign(msg, randombytes)
            doAssert(sig == sm)
        else:
          discard

katTest("tests/shake256-128f/PQCsignKAT_64.rsp", shake256_128f.generateKeypair)
katTest("tests/shake256-128s/PQCsignKAT_64.rsp", shake256_128s.generateKeypair)
katTest("tests/shake256-192f/PQCsignKAT_96.rsp", shake256_192f.generateKeypair)
katTest("tests/shake256-192s/PQCsignKAT_96.rsp", shake256_192s.generateKeypair)
katTest("tests/shake256-256f/PQCsignKAT_128.rsp", shake256_256f.generateKeypair)
katTest("tests/shake256-256s/PQCsignKAT_128.rsp", shake256_256s.generateKeypair)