77434e30587cd8e72d9a6dbbed1a56348667821b — Martin Angers 6 months ago 3094aaa
add peek to stack
1 files changed, 37 insertions(+), 40 deletions(-)

M zerojson.go
M zerojson.go => zerojson.go +37 -40
@@ 2,7 2,7 @@ package zerojson
 
 import (
 	"bytes"
-	"errors"
+	"fmt"
 )
 
 // JSON supports 7 different values:


@@ 36,10 36,10 @@ const (
 
 type stack struct {
 	depth int
+	cur   byte
 
 	// each uint64 can store 64 levels deep (1 bit is sufficient per level)
 	static [staticStackSize]uint64
-
 	// for very deeply-nested JSON, resort to allocation
 	dynamic []uint64
 }


@@ 53,6 53,7 @@ func (s *stack) push(v byte) {
 	wordIndex := s.depth / 64
 	bitIndex := s.depth % 64
 	s.depth++
+	s.cur = v
 
 	// set bit to 1 for Object, 0 for Array, by dividing the byte
 	// by '{'. That is:


@@ 72,10 73,14 @@ func (s *stack) push(v byte) {
 	s.dynamic[wordIndex] |= bit << uint(bitIndex)
 }
 
-func (s *stack) pop() byte {
-	s.depth--
-	wordIndex := s.depth / 64
-	bitIndex := s.depth % 64
+func (s *stack) peek() byte {
+	if s.cur > 0 {
+		return s.cur
+	}
+
+	ix := s.depth - 1
+	wordIndex := ix / 64
+	bitIndex := ix % 64
 
 	var word uint64
 	if wordIndex < staticStackSize {


@@ 85,9 90,18 @@ func (s *stack) pop() byte {
 		word = s.dynamic[wordIndex]
 	}
 	if word&(1<<uint(bitIndex)) == 0 {
-		return '['
+		s.cur = '['
+	} else {
+		s.cur = '{'
 	}
-	return '{'
+	return s.cur
+}
+
+func (s *stack) pop() byte {
+	b := s.peek()
+	s.depth--
+	s.cur = 0 // unset, will need to be read on peek
+	return b
 }
 
 type parser struct {


@@ 128,72 142,55 @@ func (p *parser) parse() error {
 
 		var (
 			typ byte
-			val []byte
 			err error
 		)
 
 		start := p.pos
+		typ = p.cur
 		switch p.cur {
 		case '{':
 			p.stack.push(p.cur)
 			p.advance()
 
-			typ = '{'
-			val = p.input[start:p.pos]
-
 		case '[':
 			p.stack.push(p.cur)
 			p.advance()
 
-			typ = '['
-			val = p.input[start:p.pos]
-
 		case '"':
-			p.scanString()
-			typ = '"'
-			val = p.input[start:p.pos]
+			err = p.scanString()
 
 		case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
-			p.scanNumber()
 			typ = '1'
-			val = p.input[start:p.pos]
+			err = p.scanNumber()
 
 		case 't':
-			if e := p.scanTrue(); e != nil {
-				typ = '!'
-				err = e
-			} else {
-				typ = 't'
-			}
-			val = p.input[start:p.pos]
-
+			err = p.scanLiteral(trueTrail)
 		case 'f':
-			// false
-
+			err = p.scanLiteral(falseTrail)
 		case 'n':
-			// null
+			err = p.scanLiteral(nullTrail)
 
 		default:
-			err = errors.New("invalid character")
-			typ = '!'
-			val = p.input[start:p.pos]
+			err = fmt.Errorf("invalid character: %#U", p.cur)
 			p.advance()
 		}
 
-		if e := p.emit(start, typ, val, err); e != nil {
+		if err != nil {
+			typ = '!'
+		}
+		if e := p.emit(start, typ, p.input[start:p.pos], err); e != nil {
 			return e
 		}
 	}
 	return nil
 }
 
-func (p *parser) scanTrue() error {
-	if !bytes.HasPrefix(p.input[p.pos+1:], []byte(trueTrail)) {
-		p.advance()
-		return errors.New("nvalid character")
+func (p *parser) scanLiteral(trail string) error {
+	if !bytes.HasPrefix(p.input[p.pos:], []byte(trail)) {
+		return ("nvalid character")
 	}
 
-	p.pos += len(trueTrail)
+	p.pos += len(trail)
 	p.advance()
 	return nil
 }