~ehmry/dhall-nim

69f135ca13191d9c19e7e54a0348294c819acc1a — Emery Hemingway 1 year, 1 month ago 3cf4f2b
Binary encoding and decoding
4 files changed, 515 insertions(+), 174 deletions(-)

M dhall.nimble
A src/dhall/binary.nim
M src/dhall/term.nim
A tests/test1.nim
M dhall.nimble => dhall.nimble +1 -1
@@ 10,4 10,4 @@ srcDir        = "src"

# Dependencies

requires "nim >= 1.2.0", "bigints"
requires "nim >= 1.2.0", "bigints", "cbor"

A src/dhall/binary.nim => src/dhall/binary.nim +481 -0
@@ 0,0 1,481 @@
import ./term

import cbor, cbor/bignum, bigints

import std/algorithm, std/math, std/options, std/streams, std/tables

proc writeCbor*(s: Stream; t: Term)

proc `%`*(k: TermKind): CborNode = %BiggestUint(k)

proc `%`*(t: Term): CborNode = initCborOther t

proc writeCbor*(s: Stream; table: Table[string, Term]) =
  var keys = newSeqOfCap[string](table.len)
  for k in table.keys:
    keys.add k
  sort(keys)
  s.writeMapLen(keys.len)
  for k in keys:
    s.writeCbor(k)
    s.writeCbor(table[k])

proc writeCbor*(s: Stream; tk: TermKind) {.inline.} = s.writeCbor(tk.ord)

proc writeCbor*(s: Stream; t: Term) =
  # TODO: macro this stuff
  template wr(x: untyped) = writeCbor(s, x)
  template wrL(n: Natural) = writeArrayLen(s, n)
  if t.isNil:
    wr(%nil.pointer)
    return
  case t.kind
  of tVar:
    if t.varName == "_":
      wr t.varIndex
    else:
      s.writeArrayLen 2
      wr t.varName
      wr t.varIndex
  of tUniverse:
    wr $t.universe
  of tBuiltin:
    wr $t.builtin
  of tApp:
    wrL(2 + t.appArgs.len)
    wr t.kind
    wr t.appFun
    for arg in t.appArgs:
      wr arg
  of tLambda:
    if t.lambdaLabel == "_":
      wrL 3
      wr t.kind
      wr t.lambdaType
      wr t.lambdaBody
    else:
      wrL 4
      wr t.kind
      wr t.lambdaLabel
      wr t.lambdaType
      wr t.lambdaBody
  of tPi:
    if t.piLabel == "_":
      wrL 3
      wr t.kind
      wr t.piType
      wr t.piBody
    else:
      wrL 4
      wr t.kind
      wr t.piLabel
      wr t.piType
      wr t.piBody
  of tOp:
    wrL 4
    wr t.kind
    wr ord(t.op)
    wr t.opL
    wr t.opR
  of tList:
    wrL(2+t.list.len)
    wr t.kind
    if t.list.len == 0:
      wr t.listType
    else:
      wr %nil.pointer
      for x in t.list:
        wr x
  of tSome:
    wrL 3
    wr t.kind
    if t.someType.isNil:
      wr %nil.pointer
    else:
      wr t.someType
    wr t.someVal
  of tMerge:
    if t.mergeAnn.isNil:
      wrL 3
      wr t.kind
      wr t.mergeHandler
      wr t.mergeUnion
    else:
      wrL 4
      wr t.kind
      wr t.mergeHandler
      wr t.mergeUnion
      wr t.mergeAnn
  of tRecordType:
    wrL 2
    wr t.kind
    wr t.recordType
  of tRecordLiteral:
    wrL 2
    wr t.kind
    wr t.recordLiteral
  of tField:
    wrL 3
    wr t.kind
    wr t.fieldRecord
    wr t.fieldName
  of tProject:
    wrL(2 + t.projectNames.len)
    wr t.kind
    wr t.projectRecord
    for pn in t.projectNames:
      wr pn
  of tProjectType:
    wrL 3
    wr tProject
    wr t.projectTypeRecord
    wrL 1
    wr t.projectTypeSelector
  of tUnionType:
    wrL 2
    wr t.kind
    wr t.union
  of tBoolLiteral:
    wr t.bool
  of tIf:
    wrL 4
    wr t.kind
    wr t.ifCond
    wr t.ifTrue
    wr t.ifFalse
  of tNaturalLiteral:
    wrL 2
    wr t.kind
    wr t.natural
  of tIntegerLiteral:
    wrL 2
    wr t.kind
    wr t.integer
  of tDoubleLiteral:
    wr t.double
  of tTextLiteral:
    wrL(1+t.textChunks.len*2+1)
    wr t.kind
    for ch in t.textChunks:
      wr ch.textPrefix
      wr ch.textExpr
    wr t.textSuffix
  of tAssert:
    wrL 2
    wr t.kind
    wr t.assertAnn
  of tImport:
    let check = if t.importCheck == @[]: %nil.pointer
      else: %t.importCheck
    var tmp = %[%t.kind, check, %t.importKind.uint8, %t.importScheme]
    if t.importScheme in {0, 1}:
      if t.importHeaders.isNil:
        tmp.seq.add(%nil.pointer)
      else:
        tmp.seq.add(%t.importHeaders)
    for i, elem in t.importElements:
      tmp.seq.add(%elem)
    if t.importScheme in {0, 1}:
      if t.importQuery.isSome:
        tmp.seq.add(%t.importQuery.get)
      else:
        tmp.seq.add(%nil.pointer)
    wr tmp
  of tLet:
    wrL(2+t.letBinds.len*3)
    wr t.kind
    for b in t.letBinds:
      wr b.key
      if b.ann.isNil:
        wr %nil.pointer
      else:
        wr b.ann
      wr b.val
    wr t.letBody
  of tAnnotation:
    wrL 3
    wr t.kind
    wr t.annExpr
    wr t.annAnn
  of tToMap:
    if t.toMapAnn.isNil:
      wrL 2
      wr t.kind
      wr t.toMapBody
    else:
      wrL 3
      wr t.kind
      wr t.toMapBody
      wr t.toMapAnn
  of tEmptyList:
    wrL 2
    wr t.kind
    wr t.emptyListType
  else:
    assert(false, $t.kind & " escaped from parser")

proc encode*(expr: Term): string =
  let str = newStringStream()
  str.writeCbor(expr)
  str.data

proc parseAssert(check: bool; msg = "invalid Dhall encoding") {.inline.} =
  if not check: raise newException(CborParseError, msg)

proc nextBytesOrNil(c: var CborParser): seq[byte] =
  case c.kind
  of CborEventKind.cborBytes:
    result = c.nextBytes()
  of CborEventKind.cborSimple:
    let node = c.nextNode()
    parseAssert(node.isNull)
  else:
    parseAssert(false)

proc nextTextOrNil(c: var CborParser): string =
  case c.kind
  of CborEventKind.cborText:
    result = c.nextText()
  of CborEventKind.cborSimple:
    let node = c.nextNode()
    parseAssert(node.isNull)
  else:
    parseAssert(false)

proc nextTerm(parser: var CborParser): Term

proc nextTable(parser: var CborParser): Table[string, Term] =
  let tableLen = parser.mapLen()
  result = initTable[string, Term](tableLen.nextPowerOfTwo)
  parser.next()
  for n in 1 .. tableLen:
    let key = parser.nextText()
    result[key] = parser.nextTerm()

proc nextTerm(parser: var CborParser): Term =
  if parser.kind == CborEventKind.cborTag:
    parser.next()
  case parser.kind
  of CborEventKind.cborEof:
    raise newException(EOFError, "end of CBOR data")
  of CborEventKind.cborArray:
    let arrayLen = parser.arrayLen
    parser.next()
    if parser.kind == CborEventKind.cborTag:
      parser.next()
    case parser.kind
    of cborPositive:
      let kind = TermKind(parser.nextUInt())
      case kind

      of tApp:
        parseAssert(arrayLen > 2)
        let argsLen = arrayLen - 2
        result = Term(kind: kind, appFun: parser.nextTerm(),
            appArgs: newSeq[Term](argsLen))
        for m in result.appArgs.mitems:
          m = parser.nextTerm()

      of tLambda:
        case arrayLen:
        of 3:
          result = Term(kind: kind, lambdaLabel: "_",
              lambdaType: parser.nextTerm(), lambdaBody: parser.nextTerm())
        of 4:
          result = Term(kind: kind, lambdaLabel: parser.nextText(),
              lambdaType: parser.nextTerm(), lambdaBody: parser.nextTerm())
        else:
          parseAssert(false)

      of tPi:
        if arrayLen == 3:
          result = Term(kind: kind, piLabel: "_", piType: parser.nextTerm(),
              piBody: parser.nextTerm())
        else:
          result = Term(kind: kind, piLabel: parser.nextText(),
              piType: parser.nextTerm(), piBody: parser.nextTerm())

      of tOp:
        parseAssert(arrayLen == 4)
        let op = parser.nextInt()
        parseAssert(op <= high(OpKind).BiggestInt)
        result = Term(kind: kind, op: op.OpKind, opL: parser.nextTerm(), opR: parser.nextTerm())

      of tList:
        parseAssert(arrayLen >= 2)
        result = Term(kind: kind, list: newSeq[Term](arrayLen-2), listType: parser.nextTerm())
        for m in result.list.mitems:
          m = parser.nextTerm()
        parseAssert(
          (result.listType.isNil and result.list.len > 0) or
          (not result.listType.isNil and result.list.len == 0))

      of tSome:
        parseAssert(arrayLen == 3)
        result = Term(kind: kind, someType: parser.nextTerm(),
            someVal: parser.nextTerm())

      of tMerge:
        case arrayLen
        of 3:
          result = Term(kind: kind, mergeHandler: parser.nextTerm(),
              mergeUnion: parser.nextTerm())
        of 4:
          result = Term(kind: kind, mergeHandler: parser.nextTerm(),
              mergeUnion: parser.nextTerm(), mergeAnn: parser.nextTerm())
        else:
          parseAssert(false)

      of tRecordType:
        parseAssert(arrayLen == 2)
        result = Term(kind: kind, recordType: parser.nextTable())

      of tRecordLiteral:
        parseAssert(arrayLen == 2)
        result = Term(kind: kind, recordLiteral: parser.nextTable())

      of tField:
        parseAssert(arrayLen == 3)
        result = Term(kind: kind, fieldRecord: parser.nextTerm(),
            fieldname: parser.nextText())

      of tProject:
        parseAssert(arrayLen >= 3)
        let record = parser.nextTerm()
        if arrayLen == 3 and parser.kind == CborEventKind.cborArray:
          parser.next()
          result = Term(kind: tProjectType,
              projectTypeRecord: record,
              projectTypeSelector: parser.nextTerm())
        else:
          let namesLen = arrayLen - 2
          result = Term(kind: tProject,
              projectRecord: record,
              projectNames: newSeq[string](namesLen))
          for m in result.projectNames.mitems:
            m = parser.nextTextOrNil()

      of tUnionType:
        parseAssert(arrayLen == 2)
        result = Term(kind: kind, union: parser.nextTable())

      of tIf:
        parseAssert(arrayLen == 4)
        result = Term(kind: kind, ifCond: parser.nextTerm(),
            ifTrue: parser.nextTerm(), ifFalse: parser.nextTerm())

      of tNaturalLiteral:
        parseAssert(arrayLen == 2)
        result = Term(kind: kind, natural: parser.nextBigNum())
        parseAssert(Negative notin result.natural.flags)

      of tIntegerLiteral:
        parseAssert(arrayLen == 2)
        result = Term(kind: kind, integer: parser.nextBigNum())

      of tTextLiteral:
        parseAssert(arrayLen >= 2)
        let chunksLen = (arrayLen - 2) div 2
        result = Term(kind: kind, textChunks: newSeq[Term](chunksLen))
        for i in 0..result.textChunks.high:
          result.textChunks[i] = Term(kind: tTextChunk,
              textPrefix: parser.nextText(),
              textExpr: parser.nextTerm())
        result.textSuffix = parser.nextText()

      of tAssert:
        parseAssert(arrayLen == 2)
        result = Term(kind: kind, assertAnn: parser.nextTerm())

      of tImport:
        parseAssert(arrayLen >= 3)
        result = Term(kind: kind, importCheck: parser.nextBytesOrNil(),
            importKind: parser.nextInt().ImportKind,
            importScheme: parser.nextInt().uint8,
            importQuery: none(string))
        if result.importScheme in {0, 1}:
          result.importHeaders = parser.nextTerm()
          result.importElements = newSeq[string](arrayLen - 6)
          for m in result.importElements.mitems:
            m = parser.nextText()
          if parser.kind == CborEventKind.cborText:
            result.importQuery = some parser.nextText()
          else:
            doAssert(isNull parser.nextNode())
        elif result.importScheme == 7:
          discard
        else:
          result.importElements = newSeq[string](arrayLen - 4)
          for m in result.importElements.mitems:
            m = parser.nextText()

      of tLet:
        parseAssert(arrayLen >= 3)
        let bindsLen = (arrayLen - 2) div 3
        result = Term(kind: kind, letBinds: newSeq[Term](bindsLen))
        for m in result.letBinds.mitems:
          m = Term(kind: tBinding,
              key: parser.nextText(),
              ann: parser.nextTerm(),
              val: parser.nextTerm())
        result.letBody = parser.nextTerm()

      of tAnnotation:
        parseAssert(arrayLen == 3)
        result = Term(kind: kind, annExpr: parser.nextTerm(),
            annAnn: parser.nextTerm())

      of tToMap:
        case arrayLen
        of 2:
          result = Term(kind: kind, toMapBody: parser.nextTerm())
        of 3:
          result = Term(kind: kind, toMapBody: parser.nextTerm(),
              toMapAnn: parser.nextTerm())
        else:
          parseAssert(false)

      of tEmptyList:
        parseAssert(arrayLen == 2)
        result = Term(kind: kind, emptyListType: parser.nextTerm())

      else:
        parseAssert(false)

    of CborEventKind.cborText:
      result = Term(kind: tVar, varName: parser.nextText(),
          varIndex: parser.nextNode().uint.int)
      parseAssert(result.varname != "_" or result.varIndex != 0)

    else:
      parseAssert(false)

  of CborEventKind.cborPositive:
    result = Term(kind: tVar, varName: "_", varIndex: parser.nextUInt().int)

  of CborEventKind.cborText:
    result = Term(kind: tBuiltin, builtin: parser.nextText.parseBuiltin)

  of CborEventKind.cborSimple:
    let node = parser.nextNode()
    if node.isBool:
      result = Term(kind: tBoolLiteral, bool: node.getBool)
    elif node.isNull:
      result = nil
    else:
      parseAssert(false)

  of CborEventKind.cborFloat:
    result = Term(kind: tDoubleLiteral, double: parser.nextFloat())

  else:
    parseAssert(false)

proc decode*(str: Stream): Term =
  var parser: CborParser
  parser.open(str)
  parser.next()
  parser.nextTerm()

proc decode*(buf: string): Term =
  buf.newStringStream.decode

M src/dhall/term.nim => src/dhall/term.nim +0 -173
@@ 186,179 186,6 @@ type
      key*: string
      val*, ann* : Term

func `$`*(t: Term): string =
  if t.isNil: return "null"
  case t.kind:
  of tUniverse:
    result = $t.universe
  of tBuiltin:
    result = $t.builtin
  of tBoolLiteral:
    result = if t.bool: "True" else: "False"
  of tVar:
    if t.varIndex > 0:
      result = t.varName & "@" & $t.varIndex
    else:
      result = t.varName
  of tLambda:
    result = "(λ(" & t.lambdaLabel & " : " & $t.lambdaType & ") → " &
        $t.lambdaBody & ")"
  of tPi:
    if t.piLabel == "_":
      result = $t.piType & " → " & $t.piBody
    else:
      result = "∀(" & t.piLabel & " : " & $t.piType & ") → " & $t.piBody
  of tApp:
    result = "(" & $t.appFun
    for arg in t.appArgs:
      result.add(" ")
      result.add($arg)
    result.add(" )")
  of tOp:
    result = "(" & $t.opL & " " & $t.op & " " & $t.opR & ")"
  of tLet:
    result = ""
    for b in t.letBinds:
      result.add("let ")
      result.add(b.key)
      if not b.ann.isNil:
        result.add(": ")
        result.add($b.ann)
      result.add(" = ")
      result.add($b.val)
    result.add(" in ")
    result.add($t.letBody)
  of tAnnotation:
    result = $t.annExpr & " : " & $t.annAnn
  of tNaturalLiteral:
    result = $t.natural
  of tList:
    if t.list.len == 0:
      result = "[] : " & $t.listType
    else:
      result = "[ " & join(t.list, ", ") & " ]"
  of tTextLiteral:
    var hasNewline = t.textSuffix.contains({'\n'})
    for tc in t.textChunks:
      if hasNewline: break
      hasNewLine = tc.textPrefix.contains({'\n'})
    if hasNewline:
      result = "''\n"
      for tc in t.textChunks:
        result.add $tc
      result.add t.textSuffix
      result.add "''"
    else:
      result = "\""
      for tc in t.textChunks:
        result.add $tc
      result.add t.textSuffix
      result.add "\""
  of tAssert:
    result = "assert : " & $t.assertAnn
  of tImport:
    result = "import "
    result.add case t.importKind
      of iCode: ""
      of iText: "as Text "
      of iLocation: "as Location "
    result.add case t.importScheme
      of 0: "http://"
      of 1: "https://"
      of 2: "/"
      of 3: "./"
      of 4: "../"
      of 5: "~/"
      of 6: "env:"
      of 7: "missing"
      else: raiseAssert("unhandled import scheme " & $t.importScheme)
    result.add(join(t.importElements, "/"))
    if t.importQuery.isSome:
      result.add("?")
      result.add(t.importQuery.get)
    if not t.importCheck.len == 0:
      result.add " sha256:"
      for i in 2..31:
        result.add t.importCheck[i].toHex
  of tIf:
    result = "if " & $t.ifCond & " then " & $t.ifTrue & " else " & $t.ifFalse
  of tDoubleLiteral:
    result = $t.double
  of tIntegerLiteral:
    result = if t.integer > 0:
        "+" & $t.integer
      else: $t.integer
  of tSome:
    result = "Some " & $t.someVal
  of tRecordType:
    if t.recordType.len == 0:
      result = "{}"
    else:
      result = "{ "
      for key, val in t.recordType.pairs:
        if key == "" or key.startsWith(' ') or key.endsWith(' '):
          # TODO: builtin labels
          result.add "`"
          result.add(key)
          result.add "`"
        else:
          result.add(key)
        result.add(" : ")
        result.add($val)
        result.add(", ")
      result[^2] = ' '
      result[^1] = '}'
  of tRecordLiteral:
    if t.recordLiteral.len == 0:
      result = "{=}"
    else:
      result = "{ "
      for key, val in t.recordLiteral.pairs:
        result.add(key)
        result.add(" = ")
        result.add($val)
        result.add(", ")
      result[^2] = ' '
      result[^1] = '}'
  of tToMap:
    if t.toMapAnn.isNil:
      result = "toMap " & $t.toMapBody
    else:
      result = "toMap " & $t.toMapBody & " : " & $t.toMapAnn
  of tEmptyList:
    result = "[] : " & $t.emptyListType
  of tField:
    result = "(" & $t.fieldRecord & ")." & t.fieldName
  of tProject:
    result = $t.projectRecord & ".{" & join(t.projectNames, ", ") & "}"
  of tProjectType:
    result = $t.projectTypeRecord & ".(" & $t.projectTypeSelector & ")"
  of tUnionType:
    if t.union.len == 0:
      result = "<>"
    else:
      result = "< "
      for key, val in t.union.pairs:
        result.add(key)
        if not val.isNil:
          result.add(" : ")
          result.add($val)
        result.add(" | ")
      result[^2] = ' '
      result[^1] = '>'
  of tMerge:
    if t.mergeAnn.isNil:
      result = "merge (" & $t.mergeHandler & ") (" & $t.mergeUnion & ")"
    else:
      result = "merge (" & $t.mergeHandler & ") (" & $t.mergeUnion &
          ") : " & $t.mergeAnn
  of tEntry:
    result = "« " & $t.entryKey & " = " & $t.entryVal & " »"
  of tTextChunk:
    result = t.textPrefix & "${" & $t.textExpr & "}"
  of tBinding:
    result = "« " & $t.key & " = " & $t.val & " »"

func parseBuiltin*(s: string): BuiltinKind =
  case s
  of "Natural/build": bNaturalBuild

A tests/test1.nim => tests/test1.nim +33 -0
@@ 0,0 1,33 @@
import dhall/binary, dhall/render

import os, strutils, unittest

iterator dhallTests(testDir, suffix: string): string =
  for testPath in walkDirRec(testDir, relative = true):
    if testPath.endsWith suffix:
      var testBase = testPath
      testBase.setLen(testBase.len - suffix.len)
      yield testBase

suite "binary-decode":
  block success:
    let testDir = "dhall-lang/tests/binary-decode/success"
    for testBase in dhallTests(testDir, "A.dhallb"):
      test "success" / testBase:
        checkpoint(readFile(testDir / testBase & "A.diag"))
        let
          data = readFile(testDir / testBase & "A.dhallb")
          term = decode data
          cbor = encode term
          roundtrip = decode cbor
        check(not roundtrip.isNil)
  block failure:
    let testDir = "dhall-lang/tests/binary-decode/failure"
    for testBase in dhallTests(testDir, ".dhallb"):
      test "failure" / testBase:
        let data = readFile(testDir / testBase & ".dhallb")
        expect ValueError:
          let term = decode(data)
          let diag = readFile(testDir / testBase & ".diag")
          checkpoint(diag)
          checkpoint($term)