Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package io.github.dfa1.vortex.reader.decode;

import io.github.dfa1.vortex.core.DType;
import io.github.dfa1.vortex.core.PType;
import io.github.dfa1.vortex.encoding.EncodingId;
import io.github.dfa1.vortex.proto.ALPMetadata;
import io.github.dfa1.vortex.proto.PatchesMetadata;
import io.github.dfa1.vortex.reader.ReadRegistry;
import io.github.dfa1.vortex.reader.array.DoubleArray;
import io.github.dfa1.vortex.reader.array.FloatArray;
import org.junit.jupiter.api.Test;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.within;

class AlpEncodingDecoderTest {

private static final AlpEncodingDecoder SUT = new AlpEncodingDecoder();
private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(SUT, new PrimitiveEncodingDecoder());

private static final DType F64 = new DType.Primitive(PType.F64, false);
private static final DType F32 = new DType.Primitive(PType.F32, false);

private static MemorySegment leLongs(long... vs) {
byte[] b = new byte[vs.length * 8];
ByteBuffer bb = ByteBuffer.wrap(b).order(ByteOrder.LITTLE_ENDIAN);
for (long v : vs) {
bb.putLong(v);
}
return MemorySegment.ofArray(b);
}

private static MemorySegment leInts(int... vs) {
byte[] b = new byte[vs.length * 4];
ByteBuffer bb = ByteBuffer.wrap(b).order(ByteOrder.LITTLE_ENDIAN);
for (int v : vs) {
bb.putInt(v);
}
return MemorySegment.ofArray(b);
}

private static MemorySegment leDoubles(double... vs) {
byte[] b = new byte[vs.length * 8];
ByteBuffer bb = ByteBuffer.wrap(b).order(ByteOrder.LITTLE_ENDIAN);
for (double v : vs) {
bb.putDouble(v);
}
return MemorySegment.ofArray(b);
}

@Test
void accepts_floatsTrue_otherFalse() {
// Given / When / Then
assertThat(SUT.accepts(F64)).isTrue();
assertThat(SUT.accepts(F32)).isTrue();
assertThat(SUT.accepts(new DType.Primitive(PType.I64, false))).isFalse();
assertThat(SUT.accepts(new DType.Utf8(false))).isFalse();
}

@Test
void decode_nonPrimitiveDtype_throws() {
// Given a Utf8 dtype on an ALP node
ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ALP, ByteBuffer.wrap(new ALPMetadata(0, 0, null).encode()),
new ArrayNode[0], new int[0]);
DecodeContext ctx = new DecodeContext(node, new DType.Utf8(false), 1,
new MemorySegment[0], REGISTRY, Arena.ofAuto());

// When / Then
assertThatThrownBy(() -> SUT.decode(ctx)).hasMessageContaining("expected primitive dtype");
}

@Test
void decode_missingMetadata_defaultsToZeroExponents() {
// Given no metadata — decoder falls back to exp_e=0, exp_f=0 (scale 1.0)
ArrayNode enc = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0});
ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ALP, null, new ArrayNode[]{enc}, new int[0]);
DecodeContext ctx = new DecodeContext(node, F64, 2, new MemorySegment[]{leLongs(5L, 7L)}, REGISTRY, Arena.ofAuto());

// When
DoubleArray result = (DoubleArray) SUT.decode(ctx);

// Then
assertThat(result.getDouble(0)).isCloseTo(5.0, within(1e-9));
assertThat(result.getDouble(1)).isCloseTo(7.0, within(1e-9));
}

@Test
void decode_f64_broadcastNoPatches_returnsConstant() {
// Given a single encoded value but 4 logical rows (capacity < n) and no patches:
// the decoder broadcasts it into a constant array
ArrayNode enc = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0});
byte[] meta = new ALPMetadata(2, 0, null).encode(); // exp_e=2 -> *0.01
ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ALP, ByteBuffer.wrap(meta), new ArrayNode[]{enc}, new int[0]);
DecodeContext ctx = new DecodeContext(node, F64, 4, new MemorySegment[]{leLongs(123L)}, REGISTRY, Arena.ofAuto());

// When
DoubleArray result = (DoubleArray) SUT.decode(ctx);

// Then
assertThat(result.length()).isEqualTo(4);
for (int i = 0; i < 4; i++) {
assertThat(result.getDouble(i)).as("index %d", i).isCloseTo(1.23, within(1e-9));
}
}

@Test
void decode_f32_broadcastNoPatches_returnsConstant() {
// Given single value, 3 rows, no patches
ArrayNode enc = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0});
byte[] meta = new ALPMetadata(1, 0, null).encode(); // exp_e=1 -> *0.1
ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ALP, ByteBuffer.wrap(meta), new ArrayNode[]{enc}, new int[0]);
DecodeContext ctx = new DecodeContext(node, F32, 3, new MemorySegment[]{leInts(25)}, REGISTRY, Arena.ofAuto());

// When
FloatArray result = (FloatArray) SUT.decode(ctx);

// Then
assertThat(result.length()).isEqualTo(3);
for (int i = 0; i < 3; i++) {
assertThat(result.getFloat(i)).as("index %d", i).isCloseTo(2.5f, within(1e-6f));
}
}

@Test
void decode_f64_patches_withU8Indices() {
// Given patches whose index child uses U8 storage — exercises the U8 arm of
// readUnsigned (the encoder always emits U32 indices)
PatchesMetadata pm = new PatchesMetadata(1L, 0L, io.github.dfa1.vortex.proto.PType.U8, null, null, null);
byte[] meta = new ALPMetadata(2, 0, pm).encode(); // *0.01

ArrayNode enc = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0});
ArrayNode idx = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{1});
ArrayNode val = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{2});
ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ALP, ByteBuffer.wrap(meta),
new ArrayNode[]{enc, idx, val}, new int[0]);

MemorySegment idxSeg = MemorySegment.ofArray(new byte[]{1}); // patch row 1
MemorySegment[] segs = {leLongs(100L, 0L, 300L), idxSeg, leDoubles(9.0)};
DecodeContext ctx = new DecodeContext(node, F64, 3, segs, REGISTRY, Arena.ofAuto());

// When
DoubleArray result = (DoubleArray) SUT.decode(ctx);

// Then
assertThat(result.getDouble(0)).isCloseTo(1.0, within(1e-9));
assertThat(result.getDouble(1)).isCloseTo(9.0, within(1e-9)); // patched
assertThat(result.getDouble(2)).isCloseTo(3.0, within(1e-9));
}

@Test
void decode_patches_nonUnsignedIndexPtype_throws() {
// Given a signed (I32) patch-index ptype — readUnsigned rejects it
PatchesMetadata pm = new PatchesMetadata(1L, 0L, io.github.dfa1.vortex.proto.PType.I32, null, null, null);
byte[] meta = new ALPMetadata(2, 0, pm).encode();

ArrayNode enc = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0});
ArrayNode idx = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{1});
ArrayNode val = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{2});
ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ALP, ByteBuffer.wrap(meta),
new ArrayNode[]{enc, idx, val}, new int[0]);

MemorySegment[] segs = {leLongs(100L, 0L), leInts(1), leDoubles(9.0)};
DecodeContext ctx = new DecodeContext(node, F64, 2, segs, REGISTRY, Arena.ofAuto());

// When / Then
assertThatThrownBy(() -> SUT.decode(ctx)).hasMessageContaining("non-unsigned patch index ptype");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package io.github.dfa1.vortex.reader.decode;

import io.github.dfa1.vortex.core.DType;
import io.github.dfa1.vortex.core.PType;
import io.github.dfa1.vortex.encoding.EncodingId;
import io.github.dfa1.vortex.encoding.TestSegments;
import io.github.dfa1.vortex.encoding.TimeUnit;
import io.github.dfa1.vortex.proto.DateTimePartsMetadata;
import io.github.dfa1.vortex.reader.ReadRegistry;
import io.github.dfa1.vortex.reader.array.LongArray;
import org.junit.jupiter.api.Test;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class DateTimePartsEncodingDecoderTest {

private static final DateTimePartsEncodingDecoder SUT = new DateTimePartsEncodingDecoder();
private static final ReadRegistry REGISTRY = TestRegistry.ofDecoders(SUT, new PrimitiveEncodingDecoder());

private static final long SECONDS_PER_DAY = 86_400L;

private static ByteBuffer i64Meta() {
return ByteBuffer.wrap(new DateTimePartsMetadata(
io.github.dfa1.vortex.proto.PType.I64,
io.github.dfa1.vortex.proto.PType.I64,
io.github.dfa1.vortex.proto.PType.I64).encode());
}

private static DType timestampDType(TimeUnit unit, boolean nullable) {
ByteBuffer meta = ByteBuffer.allocate(3).order(ByteOrder.LITTLE_ENDIAN);
meta.put((byte) unit.ordinal());
meta.putShort((short) 0);
meta.flip();
return new DType.Extension("vortex.timestamp",
new DType.Primitive(PType.I64, nullable), meta, nullable);
}

/// Builds a context with three I64 part-children backed by the given segments.
private static DecodeContext ctx(ByteBuffer meta, DType dtype, long n,
MemorySegment days, MemorySegment seconds, MemorySegment subseconds) {
ArrayNode d = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0});
ArrayNode s = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{1});
ArrayNode ss = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{2});
ArrayNode node = ArrayNode.of(EncodingId.VORTEX_DATETIMEPARTS, meta, new ArrayNode[]{d, s, ss}, new int[0]);
return new DecodeContext(node, dtype, n, new MemorySegment[]{days, seconds, subseconds}, REGISTRY, Arena.ofAuto());
}

@Test
void encodingId_isVortexDateTimeParts() {
// Given / When / Then
assertThat(SUT.encodingId()).isEqualTo(EncodingId.VORTEX_DATETIMEPARTS);
}

@Test
void decode_missingMetadata_throws() {
// Given a node with no metadata
ArrayNode node = ArrayNode.of(EncodingId.VORTEX_DATETIMEPARTS, null, new ArrayNode[0], new int[0]);
DecodeContext c = new DecodeContext(node, timestampDType(TimeUnit.Milliseconds, false), 1,
new MemorySegment[0], REGISTRY, Arena.ofAuto());

// When / Then
assertThatThrownBy(() -> SUT.decode(c)).hasMessageContaining("missing metadata");
}

@Test
void decode_milliseconds_reassemblesParts() {
// Given 1 day + 1h2m3s + 456ms split across the three parts
long ts = 86_400_000L + 3723L * 1000L + 456L;
DecodeContext c = ctx(i64Meta(), timestampDType(TimeUnit.Milliseconds, false), 1,
TestSegments.leLongs(1L), TestSegments.leLongs(3723L), TestSegments.leLongs(456L));

// When
LongArray result = (LongArray) SUT.decode(c);

// Then
assertThat(result.getLong(0)).isEqualTo(ts);
}

@Test
void decode_daysUnit_usesUnitsPerSecondOne() {
// Given a Days-unit extension: divisor() throws for Days, so the decoder must
// special-case it to unitsPerSecond=1 (days only, sub-day parts zero)
DecodeContext c = ctx(i64Meta(), timestampDType(TimeUnit.Days, false), 1,
TestSegments.leLongs(2L), TestSegments.leLongs(0L), TestSegments.leLongs(0L));

// When
LongArray result = (LongArray) SUT.decode(c);

// Then 2 days * 86400 s/day * 1 unit/s
assertThat(result.getLong(0)).isEqualTo(2L * SECONDS_PER_DAY);
}

@Test
void decode_nullableExtension_decodesNullableDaysChild() {
// Given a nullable extension dtype — the days child is decoded as nullable
DecodeContext c = ctx(i64Meta(), timestampDType(TimeUnit.Milliseconds, true), 1,
TestSegments.leLongs(0L), TestSegments.leLongs(0L), TestSegments.leLongs(0L));

// When
LongArray result = (LongArray) SUT.decode(c);

// Then
assertThat(result.getLong(0)).isZero();
}

@Test
void decode_extensionMissingTimeUnitMetadata_throws() {
// Given an extension whose metadata byte is absent
DType noUnit = new DType.Extension("vortex.timestamp",
new DType.Primitive(PType.I64, false), null, false);
DecodeContext c = ctx(i64Meta(), noUnit, 1,
TestSegments.leLongs(0L), TestSegments.leLongs(0L), TestSegments.leLongs(0L));

// When / Then
assertThatThrownBy(() -> SUT.decode(c)).hasMessageContaining("missing TimeUnit metadata");
}

@Test
void decode_nonExtensionDtype_throws() {
// Given a primitive (non-extension) logical type
DecodeContext c = ctx(i64Meta(), new DType.Primitive(PType.I64, false), 1,
TestSegments.leLongs(0L), TestSegments.leLongs(0L), TestSegments.leLongs(0L));

// When / Then
assertThatThrownBy(() -> SUT.decode(c)).hasMessageContaining("expected Extension dtype");
}
}
Loading