From 448201aba66654f37dd1db51b0d0038a36ed09b3 Mon Sep 17 00:00:00 2001 From: Paul Chiusano Date: Fri, 11 May 2018 17:14:06 -0400 Subject: [PATCH] fix at least one the bugs with Source/Sink buffer over/underflow, still have 1 failing test --- runtime-jvm/main/src/main/scala/Codecs.scala | 3 +- .../main/src/main/scala/util/Sink.scala | 22 ++++++----- .../main/src/main/scala/util/Source.scala | 24 +++++++++--- .../main/src/test/scala/CodecsTests.scala | 1 - .../src/test/scala/util/SourceSinkTests.scala | 37 ++++++++++++++++--- 5 files changed, 62 insertions(+), 25 deletions(-) diff --git a/runtime-jvm/main/src/main/scala/Codecs.scala b/runtime-jvm/main/src/main/scala/Codecs.scala index 73ae16823..7a36b6fea 100644 --- a/runtime-jvm/main/src/main/scala/Codecs.scala +++ b/runtime-jvm/main/src/main/scala/Codecs.scala @@ -38,8 +38,7 @@ object Codecs { def encodeNode(n: Node): Sequence[Array[Byte]] = { val fmt = nodeEncoder(n) - println(prettyFormat(fmt)) - Sink.toChunks(1024 * 1024 * 4) { sink => encodeSink(sink, fmt)(emitter) } + Sink.toChunks(4096) { sink => encodeSink(sink, fmt)(emitter) } } def encodeTerm(t: Term): Sequence[Array[Byte]] = encodeNode(Node.Term(t)) diff --git a/runtime-jvm/main/src/main/scala/util/Sink.scala b/runtime-jvm/main/src/main/scala/util/Sink.scala index 5021558ef..fd15ce05b 100644 --- a/runtime-jvm/main/src/main/scala/util/Sink.scala +++ b/runtime-jvm/main/src/main/scala/util/Sink.scala @@ -59,9 +59,8 @@ object Sink { bb.order(java.nio.ByteOrder.BIG_ENDIAN) - private final def fill = { - println("fill getting called " + position) - bb.flip() // + private final def empty = { + bb.flip() // reset position back to 0, set limit to position val buf = new Array[Byte](bb.limit()) pos += buf.length bb.get(buf) // this fills the array @@ -78,26 +77,29 @@ object Sink { // todo: more direct implementation putString(Text.toString(txt)) - // todo: this needs to split the array if buffer capacity is less than array length def put(bs: Array[Byte]) = - try { bb.put(bs); () } - catch { case e: BufferOverflowException => fill; bb.put(bs); () } + if (bs.length < bb.capacity()) + try { bb.put(bs); () } + catch { case e: BufferOverflowException => empty; bb.put(bs); () } + else bs.splitAt(bs.length / 2) match { + case (bs1, bs2) => put(bs1); put(bs2) + } def putByte(b: Byte) = try { bb.put(b); () } - catch { case e: BufferOverflowException => fill; bb.put(b); () } + catch { case e: BufferOverflowException => empty; bb.put(b); () } def putInt(n: Int) = try { bb.putInt(n); () } - catch { case e: BufferOverflowException => fill; bb.putInt(n); () } + catch { case e: BufferOverflowException => empty; bb.putInt(n); () } def putLong(n: Long) = try { bb.putLong(n); () } - catch { case e: BufferOverflowException => fill; bb.putLong(n); () } + catch { case e: BufferOverflowException => empty; bb.putLong(n); () } def putDouble(n: Double) = try { bb.putDouble(n); () } - catch { case e: BufferOverflowException => fill; bb.putDouble(n); () } + catch { case e: BufferOverflowException => empty; bb.putDouble(n); () } } def writeLong(n: Long): Array[Byte] = { diff --git a/runtime-jvm/main/src/main/scala/util/Source.scala b/runtime-jvm/main/src/main/scala/util/Source.scala index 504ccbf09..4379d2bfc 100644 --- a/runtime-jvm/main/src/main/scala/util/Source.scala +++ b/runtime-jvm/main/src/main/scala/util/Source.scala @@ -116,15 +116,16 @@ object Source { val bb = java.nio.ByteBuffer.allocate(bufferSize) var rem = chunks bb.limit(0) - Source.fromByteBuffer(bb, bb => rem.uncons match { + Source.fromByteBuffer(bb, (unread, bb) => rem.uncons match { case None => throw Underflow() case Some((chunk,chunks)) => - if (bb.limit() >= chunk.length) { + bb.put(unread) + if (chunk.length <= bb.remaining()) { bb.put(chunk) rem = chunks } else { // need to split up chunk - val (c1,c2) = chunk.splitAt(bb.limit()) + val (c1,c2) = chunk.splitAt(bb.remaining()) bb.put(c1) rem = c2 +: chunks } @@ -140,16 +141,24 @@ object Source { } } - def fromByteBuffer(bb: ByteBuffer, onEmpty: ByteBuffer => Unit): Source = new Source { + def fromByteBuffer(bb: ByteBuffer, onEmpty: (Array[Byte], ByteBuffer) => Unit): Source = new Source { bb.order(java.nio.ByteOrder.BIG_ENDIAN) var pos = 0L def position: Long = pos + bb.position().toLong def refill = { + // todo: gotta save the unread elements before calling onEmpty + val unread = + if (bb.remaining() > 0) { + val unread = new Array[Byte](bb.limit() - bb.position()) + bb.put(unread) + unread + } + else Array.empty[Byte] pos += bb.position() bb.clear() - onEmpty(bb) + onEmpty(unread, bb) bb.flip() } @@ -159,7 +168,10 @@ object Source { bb.get(arr) arr } - catch { case BufferUnderflow() => refill; get(n) } + catch { case BufferUnderflow() => + if (n <= bb.capacity()) { refill; get(n) } + else get(n/2) ++ get(n - n/2) + } def getByte: Byte = try bb.get diff --git a/runtime-jvm/main/src/test/scala/CodecsTests.scala b/runtime-jvm/main/src/test/scala/CodecsTests.scala index cc1ad30ec..53c389660 100644 --- a/runtime-jvm/main/src/test/scala/CodecsTests.scala +++ b/runtime-jvm/main/src/test/scala/CodecsTests.scala @@ -15,7 +15,6 @@ object CodecsTests { def roundTrip(p: Value) = { val bytes = Codecs.encodeValue(p) - // println(bytes.toList.flatten) Codecs.decodeValue(bytes) } diff --git a/runtime-jvm/main/src/test/scala/util/SourceSinkTests.scala b/runtime-jvm/main/src/test/scala/util/SourceSinkTests.scala index a45ed8845..06090e3d6 100644 --- a/runtime-jvm/main/src/test/scala/util/SourceSinkTests.scala +++ b/runtime-jvm/main/src/test/scala/util/SourceSinkTests.scala @@ -5,16 +5,41 @@ import org.unisonweb.EasyTest._ object SourceSinkTests { val tests = test("source/sink") { implicit T => - 1 until 100 foreach { n => - val input = byteArray(intIn(1,n+1*2)).toList - val bytes = Sink.toChunks(5) { sink => - input foreach { sink putByte _ } + // the sink should record all bytes, and the source + fail("this test hangs") + 1 until 10 foreach { n => + val input = byteArray(n).toVector + val A = intIn(1,n+1*2) + val bytes = Sink.toChunks(A) { sink => + var rem = input; while (rem.nonEmpty) { + println("1: " + math.random) + val n = intIn(1, rem.size + 1) + val (rem1,rem2) = rem.splitAt(n) + sink.put(rem1.toArray) + rem = rem2 + } + println("woot") } + println("bytes: " + bytes) + val B = intIn(1,n+1*2) + note("A: " + A, true) + note("B: " + B, true) val bytes2 = { - val src = Source.fromChunks(intIn(1,n+1*2))(bytes) - List.fill(input.size)(src.getByte) + val src = Source.fromChunks(B)(bytes) + var acc = Vector.empty[Byte] + var rem = input.size + while (acc.length != input.length) { + println(math.random) + println(acc.length + " " + input.length) + println + val n = intIn(1, rem + 1) + acc = acc ++ src.get(n).toVector + rem -= n + } + acc } equal1(bytes2, bytes.toList.flatten) + equal(bytes2, input) } ok }