From 03ec262ca5a3afb310e101b7e88ee1701d4fbdd6 Mon Sep 17 00:00:00 2001 From: bourdon Date: Sat, 11 Oct 2025 13:19:26 +0800 Subject: [PATCH] feat: updates to btree/index/query/sstable/table --- btree.go | 339 +++++++++++++++++++++++++++++++----- btree_test.go | 463 +++++++++++++++++++++++++++++++++++++++++++++++++ index.go | 40 +++++ index_btree.go | 49 ++++++ query.go | 338 +++++++++++++++++++++++++++++++++++- sstable.go | 12 ++ table.go | 42 ++++- 7 files changed, 1229 insertions(+), 54 deletions(-) diff --git a/btree.go b/btree.go index 8a02c43..c3b9508 100644 --- a/btree.go +++ b/btree.go @@ -52,7 +52,7 @@ B+Tree 用于索引 SSTable 和 Index 文件,提供 O(log n) 查询性能。 [Header: 32B] [Keys: Key0(8B), Key1(8B), Key2(8B)] [Children: Child0(8B), Child1(8B), Child2(8B), Child3(8B)] - + 查询规则: - key < Key0 → Child0 - Key0 ≤ key < Key1 → Child1 @@ -144,27 +144,29 @@ func NewLeafNode() *BTreeNode { // Marshal 序列化节点到 4 KB // // 布局: -// [Header: 32B] -// [Keys: KeyCount * 8B] -// [Values: 取决于节点类型] -// - Internal: Children (KeyCount+1) * 8B -// - Leaf: 交错存储 (Offset, Size) 对,每对 12B,共 KeyCount * 12B +// +// [Header: 32B] +// [Keys: KeyCount * 8B] +// [Values: 取决于节点类型] +// - Internal: Children (KeyCount+1) * 8B +// - Leaf: 交错存储 (Offset, Size) 对,每对 12B,共 KeyCount * 12B // // 示例(叶子节点,KeyCount=3): -// Offset | Size | Content -// -------|------|---------------------------------- -// 0 | 1 | NodeType = 1 (Leaf) -// 1 | 2 | KeyCount = 3 -// 3 | 1 | Level = 0 -// 4 | 28 | Reserved -// 32 | 24 | Keys [100, 200, 300] -// 56 | 8 | DataOffset0 = 1000 -// 64 | 4 | DataSize0 = 50 -// 68 | 8 | DataOffset1 = 2000 -// 76 | 4 | DataSize1 = 60 -// 80 | 8 | DataOffset2 = 3000 -// 88 | 4 | DataSize2 = 70 -// 92 | 4004 | Padding (unused) +// +// Offset | Size | Content +// -------|------|---------------------------------- +// 0 | 1 | NodeType = 1 (Leaf) +// 1 | 2 | KeyCount = 3 +// 3 | 1 | Level = 0 +// 4 | 28 | Reserved +// 32 | 24 | Keys [100, 200, 300] +// 56 | 8 | DataOffset0 = 1000 +// 64 | 4 | DataSize0 = 50 +// 68 | 8 | DataOffset1 = 2000 +// 76 | 4 | DataSize1 = 60 +// 80 | 8 | DataOffset2 = 3000 +// 88 | 4 | DataSize2 = 70 +// 92 | 4004 | Padding (unused) func (n *BTreeNode) Marshal() []byte { buf := make([]byte, BTreeNodeSize) @@ -213,10 +215,12 @@ func (n *BTreeNode) Marshal() []byte { // UnmarshalBTree 从字节数组反序列化节点 // // 参数: -// data: 4KB 节点数据(通常来自 mmap) +// +// data: 4KB 节点数据(通常来自 mmap) // // 返回: -// *BTreeNode: 反序列化后的节点 +// +// *BTreeNode: 反序列化后的节点 // // 零拷贝优化: // - 直接从 mmap 数据读取,不复制整个节点 @@ -310,15 +314,16 @@ func (n *BTreeNode) AddData(key int64, offset int64, size int32) error { // BTreeBuilder 从下往上构建 B+Tree // // 构建流程: -// 1. Add(): 添加所有 (key, offset, size) 到叶子节点 -// - 当叶子节点满时,创建新的叶子节点 -// - 所有叶子节点按 key 有序 // -// 2. Build(): 从叶子层向上构建 -// - Level 0: 叶子节点(已创建) -// - Level 1: 为叶子节点创建父节点(内部节点) -// - Level 2+: 递归创建更高层级 -// - 最终返回根节点偏移量 +// 1. Add(): 添加所有 (key, offset, size) 到叶子节点 +// - 当叶子节点满时,创建新的叶子节点 +// - 所有叶子节点按 key 有序 +// +// 2. Build(): 从叶子层向上构建 +// - Level 0: 叶子节点(已创建) +// - Level 1: 为叶子节点创建父节点(内部节点) +// - Level 2+: 递归创建更高层级 +// - 最终返回根节点偏移量 // // 示例(100 个 key,Order=200): // - 叶子层: 1 个叶子节点(100 个 key) @@ -453,13 +458,13 @@ func (b *BTreeBuilder) buildLevel(children []*BTreeNode, childOffsets []int64, l // BTreeReader 用于查询 B+Tree (mmap) // // 查询流程: -// 1. 从根节点开始 -// 2. 如果是内部节点: -// - 二分查找确定子节点 -// - 跳转到子节点继续查找 -// 3. 如果是叶子节点: -// - 二分查找 key -// - 返回 (dataOffset, dataSize) +// 1. 从根节点开始 +// 2. 如果是内部节点: +// - 二分查找确定子节点 +// - 跳转到子节点继续查找 +// 3. 如果是叶子节点: +// - 二分查找 key +// - 返回 (dataOffset, dataSize) // // 性能优化: // - mmap 零拷贝:直接从内存映射读取节点 @@ -485,17 +490,19 @@ func NewBTreeReader(mmap mmap.MMap, rootOffset int64) *BTreeReader { // Get 查询 key,返回数据位置 // // 参数: -// key: 要查询的 key +// +// key: 要查询的 key // // 返回: -// dataOffset: 数据块的文件偏移量 -// dataSize: 数据块的大小 -// found: 是否找到 +// +// dataOffset: 数据块的文件偏移量 +// dataSize: 数据块的大小 +// found: 是否找到 // // 查询流程: -// 1. 从根节点开始遍历 -// 2. 内部节点:二分查找确定子节点,跳转 -// 3. 叶子节点:二分查找 key,返回数据位置 +// 1. 从根节点开始遍历 +// 2. 内部节点:二分查找确定子节点,跳转 +// 3. 叶子节点:二分查找 key,返回数据位置 func (r *BTreeReader) Get(key int64) (dataOffset int64, dataSize int32, found bool) { if r.rootOffset == 0 { return 0, 0, false @@ -543,7 +550,7 @@ func (r *BTreeReader) Get(key int64) (dataOffset int64, dataSize int32, found bo } } -// GetAllKeys 获取 B+Tree 中所有的 key(按顺序) +// GetAllKeys 获取 B+Tree 中所有的 key(按升序) func (r *BTreeReader) GetAllKeys() []int64 { if r.rootOffset == 0 { return nil @@ -562,7 +569,217 @@ func (r *BTreeReader) GetAllKeys() []int64 { return keys } -// traverseLeafNodes 遍历所有叶子节点 +// GetAllKeysDesc 获取 B+Tree 中所有的 key(按降序) +// +// 性能优化: +// - 从右到左遍历叶子节点 +// - 每个叶子节点内从后往前读取 keys +// - 避免额外的排序操作 +func (r *BTreeReader) GetAllKeysDesc() []int64 { + if r.rootOffset == 0 { + return nil + } + + var keys []int64 + r.traverseLeafNodesReverse(r.rootOffset, func(node *BTreeNode) { + // 从后往前添加 keys + for i := len(node.Keys) - 1; i >= 0; i-- { + keys = append(keys, node.Keys[i]) + } + }) + + return keys +} + +// KeyCallback 迭代回调函数 +// +// 参数: +// - key: 当前的 key(序列号) +// - dataOffset: 数据块的文件偏移量 +// - dataSize: 数据块的大小 +// +// 返回: +// - true: 继续迭代 +// - false: 停止迭代 +type KeyCallback func(key int64, dataOffset int64, dataSize int32) bool + +// ForEach 升序迭代所有 key(支持提前终止) +// +// 使用场景: +// - 需要遍历数据但不想一次性加载所有 keys(节省内存) +// - 支持条件过滤,找到目标后提前终止 +// - 支持外部自定义处理逻辑 +// +// 示例: +// +// // 找到第一个 > 100 的 key +// reader.ForEach(func(key int64, offset int64, size int32) bool { +// if key > 100 { +// fmt.Printf("Found: %d\n", key) +// return false // 停止迭代 +// } +// return true // 继续 +// }) +func (r *BTreeReader) ForEach(callback KeyCallback) { + if r.rootOffset == 0 { + return + } + + r.forEachInternal(r.rootOffset, callback, false) +} + +// ForEachDesc 降序迭代所有 key(支持提前终止) +// +// 使用场景: +// - 从最新数据开始遍历(时序数据库常见需求) +// - 查找最近的 N 条记录 +// - 支持条件过滤和提前终止 +// +// 示例: +// +// // 获取最新的 10 条记录 +// count := 0 +// reader.ForEachDesc(func(key int64, offset int64, size int32) bool { +// fmt.Printf("Key: %d\n", key) +// count++ +// return count < 10 // 找到 10 条后停止 +// }) +func (r *BTreeReader) ForEachDesc(callback KeyCallback) { + if r.rootOffset == 0 { + return + } + + r.forEachInternal(r.rootOffset, callback, true) +} + +// forEachInternal 内部迭代实现(支持升序和降序) +// +// 性能优化(真正的按需读取): +// - 只读取节点 header(32 bytes)确定节点类型和 key 数量 +// - 对于叶子节点,逐个读取 key、offset、size,避免一次性读取所有数据 +// - 对于内部节点,逐个读取 child offset,支持提前终止 +// - 如果回调在第 N 个 key 返回 false,只会读取前 N 个 key +// +// 参数: +// - nodeOffset: 当前节点的文件偏移量 +// - callback: 回调函数 +// - reverse: true=降序, false=升序 +// +// 返回: +// - true: 继续迭代 +// - false: 停止迭代(外部请求或遍历完成) +func (r *BTreeReader) forEachInternal(nodeOffset int64, callback KeyCallback, reverse bool) bool { + if nodeOffset+BTreeNodeSize > int64(len(r.mmap)) { + return true // 无效节点,继续其他分支 + } + + nodeData := r.mmap[nodeOffset : nodeOffset+BTreeNodeSize] + + // 只读取 header(32 bytes) + if len(nodeData) < BTreeHeaderSize { + return true + } + + nodeType := nodeData[0] + keyCount := int(binary.LittleEndian.Uint16(nodeData[1:3])) + + if nodeType == BTreeNodeTypeLeaf { + // 叶子节点:按需逐个读取 key 和 data + // 布局:[Header: 32B][Keys: keyCount*8B][Data: (offset,size) pairs] + + keysStartOffset := BTreeHeaderSize + dataStartOffset := keysStartOffset + keyCount*8 + + if reverse { + // 降序:从后往前读取 + for i := keyCount - 1; i >= 0; i-- { + // 读取 key + keyOffset := keysStartOffset + i*8 + if keyOffset+8 > len(nodeData) { + break + } + key := int64(binary.LittleEndian.Uint64(nodeData[keyOffset : keyOffset+8])) + + // 读取 dataOffset 和 dataSize(交错存储,每对 12 bytes) + dataOffset := dataStartOffset + i*12 + if dataOffset+12 > len(nodeData) { + break + } + offset := int64(binary.LittleEndian.Uint64(nodeData[dataOffset : dataOffset+8])) + size := int32(binary.LittleEndian.Uint32(nodeData[dataOffset+8 : dataOffset+12])) + + // 调用回调,如果返回 false 则立即停止(真正的按需读取) + if !callback(key, offset, size) { + return false + } + } + } else { + // 升序:从前往后读取 + for i := range keyCount { + // 读取 key + keyOffset := keysStartOffset + i*8 + if keyOffset+8 > len(nodeData) { + break + } + key := int64(binary.LittleEndian.Uint64(nodeData[keyOffset : keyOffset+8])) + + // 读取 dataOffset 和 dataSize + dataOffset := dataStartOffset + i*12 + if dataOffset+12 > len(nodeData) { + break + } + offset := int64(binary.LittleEndian.Uint64(nodeData[dataOffset : dataOffset+8])) + size := int32(binary.LittleEndian.Uint32(nodeData[dataOffset+8 : dataOffset+12])) + + // 调用回调,如果返回 false 则立即停止 + if !callback(key, offset, size) { + return false + } + } + } + return true + } + + // 内部节点:按需逐个读取 child offset + // 布局:[Header: 32B][Keys: keyCount*8B][Children: (keyCount+1)*8B] + + childCount := keyCount + 1 + childrenStartOffset := BTreeHeaderSize + keyCount*8 + + if reverse { + // 降序:从右到左遍历子节点 + for i := childCount - 1; i >= 0; i-- { + childOffset := childrenStartOffset + i*8 + if childOffset+8 > len(nodeData) { + break + } + childPtr := int64(binary.LittleEndian.Uint64(nodeData[childOffset : childOffset+8])) + + // 递归遍历子树,如果子树请求停止则立即返回 + if !r.forEachInternal(childPtr, callback, reverse) { + return false + } + } + } else { + // 升序:从左到右遍历子节点 + for i := range childCount { + childOffset := childrenStartOffset + i*8 + if childOffset+8 > len(nodeData) { + break + } + childPtr := int64(binary.LittleEndian.Uint64(nodeData[childOffset : childOffset+8])) + + // 递归遍历子树 + if !r.forEachInternal(childPtr, callback, reverse) { + return false + } + } + } + + return true +} + +// traverseLeafNodes 遍历所有叶子节点(从左到右) func (r *BTreeReader) traverseLeafNodes(nodeOffset int64, callback func(*BTreeNode)) { if nodeOffset+BTreeNodeSize > int64(len(r.mmap)) { return @@ -579,9 +796,37 @@ func (r *BTreeReader) traverseLeafNodes(nodeOffset int64, callback func(*BTreeNo // 叶子节点,执行回调 callback(node) } else { - // 内部节点,递归遍历所有子节点 + // 内部节点,递归遍历所有子节点(从左到右) for _, childOffset := range node.Children { r.traverseLeafNodes(childOffset, callback) } } } + +// traverseLeafNodesReverse 倒序遍历所有叶子节点(从右到左) +// +// 用于支持倒序查询,性能优化: +// - 避免先获取所有 keys 再反转 +// - 直接从最右侧的叶子节点开始遍历 +func (r *BTreeReader) traverseLeafNodesReverse(nodeOffset int64, callback func(*BTreeNode)) { + if nodeOffset+BTreeNodeSize > int64(len(r.mmap)) { + return + } + + nodeData := r.mmap[nodeOffset : nodeOffset+BTreeNodeSize] + node := UnmarshalBTree(nodeData) + + if node == nil { + return + } + + if node.NodeType == BTreeNodeTypeLeaf { + // 叶子节点,执行回调 + callback(node) + } else { + // 内部节点,递归遍历所有子节点(从右到左) + for i := len(node.Children) - 1; i >= 0; i-- { + r.traverseLeafNodesReverse(node.Children[i], callback) + } + } +} diff --git a/btree_test.go b/btree_test.go index 7f9b44d..587e483 100644 --- a/btree_test.go +++ b/btree_test.go @@ -132,6 +132,386 @@ func TestBTreeSerialization(t *testing.T) { t.Log("Serialization test passed!") } +// TestBTreeForEach 测试升序迭代 +func TestBTreeForEach(t *testing.T) { + // 创建测试文件 + file, err := os.Create("test_foreach.sst") + if err != nil { + t.Fatal(err) + } + defer os.Remove("test_foreach.sst") + + // 构建 B+Tree + builder := NewBTreeBuilder(file, 256) + for i := int64(1); i <= 100; i++ { + err := builder.Add(i, i*100, int32(i*10)) + if err != nil { + t.Fatal(err) + } + } + + rootOffset, err := builder.Build() + if err != nil { + t.Fatal(err) + } + file.Close() + + // 打开并 mmap + file, _ = os.Open("test_foreach.sst") + defer file.Close() + mmapData, _ := mmap.Map(file, mmap.RDONLY, 0) + defer mmapData.Unmap() + + reader := NewBTreeReader(mmapData, rootOffset) + + // 测试 1: 完整升序迭代 + t.Run("Complete", func(t *testing.T) { + var keys []int64 + var offsets []int64 + var sizes []int32 + + reader.ForEach(func(key int64, offset int64, size int32) bool { + keys = append(keys, key) + offsets = append(offsets, offset) + sizes = append(sizes, size) + return true + }) + + // 验证数量 + if len(keys) != 100 { + t.Errorf("Expected 100 keys, got %d", len(keys)) + } + + // 验证顺序(升序) + for i := 0; i < len(keys)-1; i++ { + if keys[i] >= keys[i+1] { + t.Errorf("Keys not in ascending order: keys[%d]=%d, keys[%d]=%d", + i, keys[i], i+1, keys[i+1]) + } + } + + // 验证第一个和最后一个 + if keys[0] != 1 { + t.Errorf("Expected first key=1, got %d", keys[0]) + } + if keys[99] != 100 { + t.Errorf("Expected last key=100, got %d", keys[99]) + } + + // 验证 offset 和 size + for i, key := range keys { + expectedOffset := key * 100 + expectedSize := int32(key * 10) + if offsets[i] != expectedOffset { + t.Errorf("Key %d: expected offset %d, got %d", key, expectedOffset, offsets[i]) + } + if sizes[i] != expectedSize { + t.Errorf("Key %d: expected size %d, got %d", key, expectedSize, sizes[i]) + } + } + }) + + // 测试 2: 提前终止 + t.Run("EarlyTermination", func(t *testing.T) { + var keys []int64 + reader.ForEach(func(key int64, offset int64, size int32) bool { + keys = append(keys, key) + return len(keys) < 5 // 只收集 5 个 + }) + + if len(keys) != 5 { + t.Errorf("Expected 5 keys, got %d", len(keys)) + } + if keys[0] != 1 || keys[4] != 5 { + t.Errorf("Expected keys [1,2,3,4,5], got %v", keys) + } + }) + + // 测试 3: 条件过滤 + t.Run("ConditionalFilter", func(t *testing.T) { + var evenKeys []int64 + reader.ForEach(func(key int64, offset int64, size int32) bool { + if key%2 == 0 { + evenKeys = append(evenKeys, key) + } + return true + }) + + if len(evenKeys) != 50 { + t.Errorf("Expected 50 even keys, got %d", len(evenKeys)) + } + + // 验证都是偶数 + for _, key := range evenKeys { + if key%2 != 0 { + t.Errorf("Key %d is not even", key) + } + } + }) + + // 测试 4: 查找第一个满足条件的 + t.Run("FindFirst", func(t *testing.T) { + var foundKey int64 + count := 0 + reader.ForEach(func(key int64, offset int64, size int32) bool { + count++ + if key > 50 { + foundKey = key + return false // 找到后停止 + } + return true + }) + + if foundKey != 51 { + t.Errorf("Expected to find key 51, got %d", foundKey) + } + if count != 51 { + t.Errorf("Expected to iterate 51 times, got %d", count) + } + }) + + // 测试 5: 与 GetAllKeys 结果一致性 + t.Run("ConsistencyWithGetAllKeys", func(t *testing.T) { + var iterKeys []int64 + reader.ForEach(func(key int64, offset int64, size int32) bool { + iterKeys = append(iterKeys, key) + return true + }) + + allKeys := reader.GetAllKeys() + + if len(iterKeys) != len(allKeys) { + t.Errorf("Length mismatch: ForEach=%d, GetAllKeys=%d", len(iterKeys), len(allKeys)) + } + + for i := range iterKeys { + if iterKeys[i] != allKeys[i] { + t.Errorf("Key mismatch at index %d: ForEach=%d, GetAllKeys=%d", + i, iterKeys[i], allKeys[i]) + } + } + }) +} + +// TestBTreeForEachDesc 测试降序迭代 +func TestBTreeForEachDesc(t *testing.T) { + // 创建测试文件 + file, err := os.Create("test_foreach_desc.sst") + if err != nil { + t.Fatal(err) + } + defer os.Remove("test_foreach_desc.sst") + + // 构建 B+Tree + builder := NewBTreeBuilder(file, 256) + for i := int64(1); i <= 100; i++ { + err := builder.Add(i, i*100, int32(i*10)) + if err != nil { + t.Fatal(err) + } + } + + rootOffset, err := builder.Build() + if err != nil { + t.Fatal(err) + } + file.Close() + + // 打开并 mmap + file, _ = os.Open("test_foreach_desc.sst") + defer file.Close() + mmapData, _ := mmap.Map(file, mmap.RDONLY, 0) + defer mmapData.Unmap() + + reader := NewBTreeReader(mmapData, rootOffset) + + // 测试 1: 完整降序迭代 + t.Run("Complete", func(t *testing.T) { + var keys []int64 + reader.ForEachDesc(func(key int64, offset int64, size int32) bool { + keys = append(keys, key) + return true + }) + + // 验证数量 + if len(keys) != 100 { + t.Errorf("Expected 100 keys, got %d", len(keys)) + } + + // 验证顺序(降序) + for i := 0; i < len(keys)-1; i++ { + if keys[i] <= keys[i+1] { + t.Errorf("Keys not in descending order: keys[%d]=%d, keys[%d]=%d", + i, keys[i], i+1, keys[i+1]) + } + } + + // 验证第一个和最后一个 + if keys[0] != 100 { + t.Errorf("Expected first key=100, got %d", keys[0]) + } + if keys[99] != 1 { + t.Errorf("Expected last key=1, got %d", keys[99]) + } + }) + + // 测试 2: 获取最新的 N 条记录(时序数据库常见需求) + t.Run("GetLatestN", func(t *testing.T) { + var latestKeys []int64 + reader.ForEachDesc(func(key int64, offset int64, size int32) bool { + latestKeys = append(latestKeys, key) + return len(latestKeys) < 10 // 只取最新的 10 条 + }) + + if len(latestKeys) != 10 { + t.Errorf("Expected 10 keys, got %d", len(latestKeys)) + } + + // 验证是最新的 10 条(100, 99, 98, ..., 91) + for i, key := range latestKeys { + expected := int64(100 - i) + if key != expected { + t.Errorf("latestKeys[%d]: expected %d, got %d", i, expected, key) + } + } + }) + + // 测试 3: 与 GetAllKeysDesc 结果一致性 + t.Run("ConsistencyWithGetAllKeysDesc", func(t *testing.T) { + var iterKeys []int64 + reader.ForEachDesc(func(key int64, offset int64, size int32) bool { + iterKeys = append(iterKeys, key) + return true + }) + + allKeys := reader.GetAllKeysDesc() + + if len(iterKeys) != len(allKeys) { + t.Errorf("Length mismatch: ForEachDesc=%d, GetAllKeysDesc=%d", len(iterKeys), len(allKeys)) + } + + for i := range iterKeys { + if iterKeys[i] != allKeys[i] { + t.Errorf("Key mismatch at index %d: ForEachDesc=%d, GetAllKeysDesc=%d", + i, iterKeys[i], allKeys[i]) + } + } + }) + + // 测试 4: 降序查找第一个满足条件的 + t.Run("FindFirstDesc", func(t *testing.T) { + var foundKey int64 + count := 0 + reader.ForEachDesc(func(key int64, offset int64, size int32) bool { + count++ + if key < 50 { + foundKey = key + return false // 找到后停止 + } + return true + }) + + if foundKey != 49 { + t.Errorf("Expected to find key 49, got %d", foundKey) + } + if count != 52 { // 100, 99, ..., 50, 49 + t.Errorf("Expected to iterate 52 times, got %d", count) + } + }) +} + +// TestBTreeForEachEmpty 测试空树的迭代 +func TestBTreeForEachEmpty(t *testing.T) { + // 创建空的 B+Tree + file, _ := os.Create("test_empty.sst") + defer os.Remove("test_empty.sst") + + builder := NewBTreeBuilder(file, 256) + rootOffset, _ := builder.Build() + file.Close() + + file, _ = os.Open("test_empty.sst") + defer file.Close() + mmapData, _ := mmap.Map(file, mmap.RDONLY, 0) + defer mmapData.Unmap() + + reader := NewBTreeReader(mmapData, rootOffset) + + // 测试升序迭代 + t.Run("ForEach", func(t *testing.T) { + called := false + reader.ForEach(func(key int64, offset int64, size int32) bool { + called = true + return true + }) + + if called { + t.Error("Callback should not be called on empty tree") + } + }) + + // 测试降序迭代 + t.Run("ForEachDesc", func(t *testing.T) { + called := false + reader.ForEachDesc(func(key int64, offset int64, size int32) bool { + called = true + return true + }) + + if called { + t.Error("Callback should not be called on empty tree") + } + }) +} + +// TestBTreeForEachSingle 测试单个元素的迭代 +func TestBTreeForEachSingle(t *testing.T) { + // 创建只有一个元素的 B+Tree + file, _ := os.Create("test_single.sst") + defer os.Remove("test_single.sst") + + builder := NewBTreeBuilder(file, 256) + builder.Add(42, 4200, 420) + rootOffset, _ := builder.Build() + file.Close() + + file, _ = os.Open("test_single.sst") + defer file.Close() + mmapData, _ := mmap.Map(file, mmap.RDONLY, 0) + defer mmapData.Unmap() + + reader := NewBTreeReader(mmapData, rootOffset) + + // 测试升序迭代 + t.Run("ForEach", func(t *testing.T) { + var keys []int64 + reader.ForEach(func(key int64, offset int64, size int32) bool { + keys = append(keys, key) + if offset != 4200 || size != 420 { + t.Errorf("Unexpected data: offset=%d, size=%d", offset, size) + } + return true + }) + + if len(keys) != 1 || keys[0] != 42 { + t.Errorf("Expected single key 42, got %v", keys) + } + }) + + // 测试降序迭代 + t.Run("ForEachDesc", func(t *testing.T) { + var keys []int64 + reader.ForEachDesc(func(key int64, offset int64, size int32) bool { + keys = append(keys, key) + return true + }) + + if len(keys) != 1 || keys[0] != 42 { + t.Errorf("Expected single key 42, got %v", keys) + } + }) +} + func BenchmarkBTreeGet(b *testing.B) { // 构建测试数据 file, _ := os.Create("bench.sst") @@ -159,3 +539,86 @@ func BenchmarkBTreeGet(b *testing.B) { reader.Get(key) } } + +// BenchmarkBTreeForEach 性能测试:完整迭代 +func BenchmarkBTreeForEach(b *testing.B) { + file, _ := os.Create("bench_foreach.sst") + defer os.Remove("bench_foreach.sst") + + builder := NewBTreeBuilder(file, 256) + for i := int64(1); i <= 10000; i++ { + builder.Add(i, i*100, 100) + } + rootOffset, _ := builder.Build() + file.Close() + + file, _ = os.Open("bench_foreach.sst") + defer file.Close() + mmapData, _ := mmap.Map(file, mmap.RDONLY, 0) + defer mmapData.Unmap() + + reader := NewBTreeReader(mmapData, rootOffset) + + b.ResetTimer() + for b.Loop() { + count := 0 + reader.ForEach(func(key int64, offset int64, size int32) bool { + count++ + return true + }) + } +} + +// BenchmarkBTreeForEachEarlyTermination 性能测试:提前终止 +func BenchmarkBTreeForEachEarlyTermination(b *testing.B) { + file, _ := os.Create("bench_foreach_early.sst") + defer os.Remove("bench_foreach_early.sst") + + builder := NewBTreeBuilder(file, 256) + for i := int64(1); i <= 100000; i++ { + builder.Add(i, i*100, 100) + } + rootOffset, _ := builder.Build() + file.Close() + + file, _ = os.Open("bench_foreach_early.sst") + defer file.Close() + mmapData, _ := mmap.Map(file, mmap.RDONLY, 0) + defer mmapData.Unmap() + + reader := NewBTreeReader(mmapData, rootOffset) + + b.ResetTimer() + for b.Loop() { + count := 0 + reader.ForEach(func(key int64, offset int64, size int32) bool { + count++ + return count < 10 // 只读取前 10 个 + }) + } +} + +// BenchmarkBTreeGetAllKeys vs ForEach 对比 +func BenchmarkBTreeGetAllKeys(b *testing.B) { + file, _ := os.Create("bench_getall.sst") + defer os.Remove("bench_getall.sst") + + builder := NewBTreeBuilder(file, 256) + for i := int64(1); i <= 10000; i++ { + builder.Add(i, i*100, 100) + } + rootOffset, _ := builder.Build() + file.Close() + + file, _ = os.Open("bench_getall.sst") + defer file.Close() + mmapData, _ := mmap.Map(file, mmap.RDONLY, 0) + defer mmapData.Unmap() + + reader := NewBTreeReader(mmapData, rootOffset) + + b.ResetTimer() + for b.Loop() { + _ = reader.GetAllKeys() + } +} diff --git a/index.go b/index.go index 5d414a8..0ca5daf 100644 --- a/index.go +++ b/index.go @@ -274,6 +274,46 @@ func (idx *SecondaryIndex) GetMetadata() IndexMetadata { return idx.metadata } +// ForEach 升序迭代所有索引条目 +// callback 返回 false 时停止迭代,支持提前终止 +// 注意:只能迭代已持久化的数据(B+Tree),不包括内存中未持久化的数据 +func (idx *SecondaryIndex) ForEach(callback IndexEntryCallback) error { + idx.mu.RLock() + defer idx.mu.RUnlock() + + if !idx.ready { + return fmt.Errorf("index not ready") + } + + // 只支持 B+Tree 格式的索引 + if !idx.useBTree || idx.btreeReader == nil { + return fmt.Errorf("ForEach only supports B+Tree format indexes") + } + + idx.btreeReader.ForEach(callback) + return nil +} + +// ForEachDesc 降序迭代所有索引条目 +// callback 返回 false 时停止迭代,支持提前终止 +// 注意:只能迭代已持久化的数据(B+Tree),不包括内存中未持久化的数据 +func (idx *SecondaryIndex) ForEachDesc(callback IndexEntryCallback) error { + idx.mu.RLock() + defer idx.mu.RUnlock() + + if !idx.ready { + return fmt.Errorf("index not ready") + } + + // 只支持 B+Tree 格式的索引 + if !idx.useBTree || idx.btreeReader == nil { + return fmt.Errorf("ForEachDesc only supports B+Tree format indexes") + } + + idx.btreeReader.ForEachDesc(callback) + return nil +} + // NeedsUpdate 检查是否需要更新 func (idx *SecondaryIndex) NeedsUpdate(currentMaxSeq int64) bool { idx.mu.RLock() diff --git a/index_btree.go b/index_btree.go index 3253f89..f2cceab 100644 --- a/index_btree.go +++ b/index_btree.go @@ -530,6 +530,55 @@ func (r *IndexBTreeReader) GetMetadata() IndexMetadata { } } +// IndexEntryCallback 索引条目回调函数 +// 参数:value 字段值,seqs 对应的 seq 列表 +// 返回:true 继续迭代,false 停止迭代 +type IndexEntryCallback func(value string, seqs []int64) bool + +// ForEach 升序迭代所有索引条目 +// callback 返回 false 时停止迭代,支持提前终止 +func (r *IndexBTreeReader) ForEach(callback IndexEntryCallback) { + r.btree.ForEach(func(key int64, dataOffset int64, dataSize int32) bool { + // 读取数据块(零拷贝) + if dataOffset+int64(dataSize) > int64(len(r.mmap)) { + return false // 数据越界,停止迭代 + } + + binaryData := r.mmap[dataOffset : dataOffset+int64(dataSize)] + + // 解码二进制数据 + value, seqs, err := decodeIndexEntry(binaryData) + if err != nil { + return false // 解码失败,停止迭代 + } + + // 调用用户回调 + return callback(value, seqs) + }) +} + +// ForEachDesc 降序迭代所有索引条目 +// callback 返回 false 时停止迭代,支持提前终止 +func (r *IndexBTreeReader) ForEachDesc(callback IndexEntryCallback) { + r.btree.ForEachDesc(func(key int64, dataOffset int64, dataSize int32) bool { + // 读取数据块(零拷贝) + if dataOffset+int64(dataSize) > int64(len(r.mmap)) { + return false // 数据越界,停止迭代 + } + + binaryData := r.mmap[dataOffset : dataOffset+int64(dataSize)] + + // 解码二进制数据 + value, seqs, err := decodeIndexEntry(binaryData) + if err != nil { + return false // 解码失败,停止迭代 + } + + // 调用用户回调 + return callback(value, seqs) + }) +} + // Close 关闭读取器 func (r *IndexBTreeReader) Close() error { if r.mmap != nil { diff --git a/query.go b/query.go index 668d3d3..9ed419f 100644 --- a/query.go +++ b/query.go @@ -5,6 +5,8 @@ import ( "fmt" "maps" "reflect" + "slices" + "sort" "strings" ) @@ -357,9 +359,13 @@ func Or(exprs ...Expr) Expr { } type QueryBuilder struct { - conds []Expr - fields []string // 要选择的字段,nil 表示选择所有字段 - table *Table + conds []Expr + fields []string // 要选择的字段,nil 表示选择所有字段 + table *Table + orderBy string // 排序字段,仅支持 "_seq" 或索引字段 + orderDesc bool // 是否降序排序 + offset int // 跳过的记录数 + limit int // 返回的最大记录数,0 表示无限制 } func newQueryBuilder(table *Table) *QueryBuilder { @@ -470,12 +476,123 @@ func (qb *QueryBuilder) NotNull(field string) *QueryBuilder { return qb.where(NotNull(field)) } +// OrderBy 设置排序字段(升序) +// 仅支持 "_seq" 或有索引的字段,使用其他字段会返回错误 +func (qb *QueryBuilder) OrderBy(field string) *QueryBuilder { + qb.orderBy = field + qb.orderDesc = false + return qb +} + +// OrderByDesc 设置排序字段(降序) +// 仅支持 "_seq" 或有索引的字段,使用其他字段会返回错误 +func (qb *QueryBuilder) OrderByDesc(field string) *QueryBuilder { + qb.orderBy = field + qb.orderDesc = true + return qb +} + +// Offset 设置跳过的记录数 +// 用于分页查询,跳过前 n 条记录 +func (qb *QueryBuilder) Offset(n int) *QueryBuilder { + if n < 0 { + n = 0 + } + qb.offset = n + return qb +} + +// Limit 设置返回的最大记录数 +// 用于分页查询,最多返回 n 条记录 +// n = 0 表示无限制 +func (qb *QueryBuilder) Limit(n int) *QueryBuilder { + if n < 0 { + n = 0 + } + qb.limit = n + return qb +} + +// Paginate 执行分页查询并返回结果、总记录数和错误 +// page: 页码,从 1 开始 +// pageSize: 每页记录数 +// 返回值: +// - rows: 当前页的数据 +// - total: 满足条件的总记录数(用于计算总页数) +// - err: 错误信息 +// +// 注意:此方法会执行两次查询,第一次获取总数,第二次获取分页数据 +func (qb *QueryBuilder) Paginate(page, pageSize int) (rows *Rows, total int, err error) { + if page < 1 { + page = 1 + } + if pageSize < 0 { + pageSize = 0 + } + + // 1. 先获取总记录数(不应用分页) + // 创建一个新的 QueryBuilder 副本用于计数 + countQb := &QueryBuilder{ + conds: qb.conds, + fields: qb.fields, + table: qb.table, + orderBy: "", // 计数不需要排序 + offset: 0, // 计数不应用分页 + limit: 0, + } + + countRows, err := countQb.Rows() + if err != nil { + return nil, 0, err + } + defer countRows.Close() + + // 计算总数 + total = countRows.Len() + + // 2. 执行分页查询 + qb.offset = (page - 1) * pageSize + qb.limit = pageSize + + rows, err = qb.Rows() + if err != nil { + return nil, total, err + } + + return rows, total, nil +} + +// validateOrderBy 验证排序字段是否有效 +func (qb *QueryBuilder) validateOrderBy() error { + if qb.orderBy == "" { + return nil // 没有设置排序,无需验证 + } + + // 允许使用 _seq + if qb.orderBy == "_seq" { + return nil + } + + // 检查该字段是否有索引 + if _, exists := qb.table.indexManager.GetIndex(qb.orderBy); exists { + return nil + } + + // 不支持的字段 + return fmt.Errorf("OrderBy only supports '_seq' or indexed fields, field '%s' is not indexed", qb.orderBy) +} + // Rows 返回所有匹配的数据(游标模式 - 惰性加载) func (qb *QueryBuilder) Rows() (*Rows, error) { if qb.table == nil { return nil, fmt.Errorf("table is nil") } + // 验证排序字段 + if err := qb.validateOrderBy(); err != nil { + return nil, err + } + rows := &Rows{ schema: qb.table.schema, fields: qb.fields, @@ -484,6 +601,11 @@ func (qb *QueryBuilder) Rows() (*Rows, error) { visited: make(map[int64]bool), } + // 如果设置了排序,使用排序后的结果集 + if qb.orderBy != "" { + return qb.rowsWithOrder(rows) + } + // 尝试使用索引优化查询 // 检查是否有可以使用索引的 Eq 条件 indexField, indexValue := qb.findIndexableCondition() @@ -570,6 +692,9 @@ func (qb *QueryBuilder) rowsWithIndex(rows *Rows, indexField string, indexValue } } + // 应用 offset 和 limit + rows.cachedRows = qb.applyOffsetLimit(rows.cachedRows) + // 使用缓存模式 rows.cached = true rows.cachedIndex = -1 @@ -577,6 +702,194 @@ func (qb *QueryBuilder) rowsWithIndex(rows *Rows, indexField string, indexValue return rows, nil } +// rowsWithOrder 使用排序返回数据 +func (qb *QueryBuilder) rowsWithOrder(rows *Rows) (*Rows, error) { + if qb.orderBy == "_seq" { + // 按 _seq 排序 + return qb.rowsOrderBySeq(rows) + } + + // 按索引字段排序 + return qb.rowsOrderByIndex(rows, qb.orderBy) +} + +// rowsOrderBySeq 按 _seq 排序返回数据 +func (qb *QueryBuilder) rowsOrderBySeq(rows *Rows) (*Rows, error) { + // 收集所有 seq(从所有数据源) + seqList := []int64{} + + // 1. 从 Active MemTable 收集 + activeMemTable := qb.table.memtableManager.GetActive() + if activeMemTable != nil { + seqList = append(seqList, activeMemTable.Keys()...) + } + + // 2. 从 Immutable MemTables 收集 + immutables := qb.table.memtableManager.GetImmutables() + for _, immutable := range immutables { + seqList = append(seqList, immutable.MemTable.Keys()...) + } + + // 3. 从 SST 文件收集 + sstReaders := qb.table.sstManager.GetReaders() + for _, reader := range sstReaders { + seqList = append(seqList, reader.GetAllKeys()...) + } + + // 去重(使用 map) + seqMap := make(map[int64]bool) + uniqueSeqs := []int64{} + for _, seq := range seqList { + if !seqMap[seq] { + seqMap[seq] = true + uniqueSeqs = append(uniqueSeqs, seq) + } + } + + // 排序 + if qb.orderDesc { + // 降序 + sort.Slice(uniqueSeqs, func(i, j int) bool { + return uniqueSeqs[i] > uniqueSeqs[j] + }) + } else { + // 升序 + slices.Sort(uniqueSeqs) + } + + // 按排序后的 seq 获取数据 + rows.cachedRows = make([]*SSTableRow, 0, len(uniqueSeqs)) + for _, seq := range uniqueSeqs { + row, err := qb.table.Get(seq) + if err != nil { + continue // 跳过获取失败的记录 + } + + // 检查是否匹配过滤条件 + if qb.Match(row.Data) { + rows.cachedRows = append(rows.cachedRows, row) + } + } + + // 应用 offset 和 limit + rows.cachedRows = qb.applyOffsetLimit(rows.cachedRows) + + // 使用缓存模式 + rows.cached = true + rows.cachedIndex = -1 + + return rows, nil +} + +// rowsOrderByIndex 按索引字段排序返回数据 +// +// 实现策略: +// 1. 使用 ForEach/ForEachDesc 从索引收集所有 (value, seqs) 对 +// 2. 按字段值(而非哈希)对这些对进行排序 +// 3. 按排序后的顺序获取数据 +// +// 注意:虽然使用了索引,但需要在内存中排序所有索引条目。 +// 对于大量唯一值的字段,内存开销可能较大。 +func (qb *QueryBuilder) rowsOrderByIndex(rows *Rows, indexField string) (*Rows, error) { + // 获取索引 + idx, exists := qb.table.indexManager.GetIndex(indexField) + if !exists { + return nil, fmt.Errorf("index on field %s not found", indexField) + } + + // 检查索引是否准备就绪 + if !idx.IsReady() { + return nil, fmt.Errorf("index on field %s is not ready", indexField) + } + + // 用于收集索引条目的结构 + type indexEntry struct { + value string + seqs []int64 + } + + // 收集所有索引条目 + entries := []indexEntry{} + err := idx.ForEach(func(value string, seqs []int64) bool { + // 复制 seqs 避免引用问题 + seqsCopy := make([]int64, len(seqs)) + copy(seqsCopy, seqs) + entries = append(entries, indexEntry{ + value: value, + seqs: seqsCopy, + }) + return true + }) + if err != nil { + return nil, fmt.Errorf("failed to iterate index: %w", err) + } + + // 按字段值排序(而非哈希) + if qb.orderDesc { + // 降序 + sort.Slice(entries, func(i, j int) bool { + return entries[i].value > entries[j].value + }) + } else { + // 升序 + sort.Slice(entries, func(i, j int) bool { + return entries[i].value < entries[j].value + }) + } + + // 按排序后的顺序收集所有 seq + allSeqs := []int64{} + for _, entry := range entries { + allSeqs = append(allSeqs, entry.seqs...) + } + + // 根据 seq 列表获取数据 + rows.cachedRows = make([]*SSTableRow, 0, len(allSeqs)) + for _, seq := range allSeqs { + row, err := qb.table.Get(seq) + if err != nil { + continue // 跳过获取失败的记录 + } + + // 检查是否匹配所有其他条件 + if qb.Match(row.Data) { + rows.cachedRows = append(rows.cachedRows, row) + } + } + + // 应用 offset 和 limit + rows.cachedRows = qb.applyOffsetLimit(rows.cachedRows) + + // 使用缓存模式 + rows.cached = true + rows.cachedIndex = -1 + + return rows, nil +} + +// applyOffsetLimit 应用 offset 和 limit 到结果集 +func (qb *QueryBuilder) applyOffsetLimit(rows []*SSTableRow) []*SSTableRow { + // 如果没有设置 offset 和 limit,直接返回 + if qb.offset == 0 && qb.limit == 0 { + return rows + } + + // 应用 offset + if qb.offset > 0 { + if qb.offset >= len(rows) { + return []*SSTableRow{} + } + rows = rows[qb.offset:] + } + + // 应用 limit + if qb.limit > 0 && qb.limit < len(rows) { + rows = rows[:qb.limit] + } + + return rows +} + // First 返回第一个匹配的数据 func (qb *QueryBuilder) First() (*Row, error) { rows, err := qb.Rows() @@ -697,6 +1010,10 @@ type Rows struct { cached bool cachedRows []*SSTableRow cachedIndex int // 缓存模式下的迭代位置 + + // 分页状态(惰性模式) + skippedCount int // 已跳过的记录数(用于 offset) + returnedCount int // 已返回的记录数(用于 limit) } // memtableIterator 包装 MemTable 的迭代器 @@ -843,8 +1160,21 @@ func (r *Rows) next() bool { continue } + // 应用 offset:跳过前 N 条记录 + if r.qb.offset > 0 && r.skippedCount < r.qb.offset { + r.skippedCount++ + r.visited[minSeq] = true + continue + } + + // 应用 limit:达到返回上限后停止 + if r.qb.limit > 0 && r.returnedCount >= r.qb.limit { + return false + } + // 找到匹配的记录 r.visited[minSeq] = true + r.returnedCount++ r.currentRow = &Row{schema: r.schema, fields: r.fields, inner: row} return true } @@ -927,7 +1257,7 @@ func (r *Rows) Data() []map[string]any { // - 如果目标是结构体/指针:只扫描第一行 func (r *Rows) Scan(value any) error { rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { + if rv.Kind() != reflect.Pointer { return fmt.Errorf("scan target must be a pointer") } diff --git a/sstable.go b/sstable.go index 3543090..5ffab3c 100644 --- a/sstable.go +++ b/sstable.go @@ -1140,6 +1140,18 @@ func (r *SSTableReader) GetAllKeys() []int64 { return r.btReader.GetAllKeys() } +// ForEach 升序迭代所有 key-offset-size 对 +// callback 返回 false 时停止迭代,支持提前终止 +func (r *SSTableReader) ForEach(callback KeyCallback) { + r.btReader.ForEach(callback) +} + +// ForEachDesc 降序迭代所有 key-offset-size 对 +// callback 返回 false 时停止迭代,支持提前终止 +func (r *SSTableReader) ForEachDesc(callback KeyCallback) { + r.btReader.ForEachDesc(callback) +} + // Close 关闭读取器 func (r *SSTableReader) Close() error { if r.mmap != nil { diff --git a/table.go b/table.go index f305e88..ac66404 100644 --- a/table.go +++ b/table.go @@ -410,14 +410,40 @@ func (t *Table) insertSingle(data map[string]any) error { return NewError(ErrCodeSchemaValidationFailed, err) } - // 2. 生成 _seq + // 2. 类型转换:将数据转换为 Schema 定义的类型 + // 这样可以确保写入时的类型与 Schema 一致(例如将 int64 转换为 time.Time) + convertedData := make(map[string]any, len(data)) + for key, value := range data { + // 跳过 nil 值 + if value == nil { + convertedData[key] = nil + continue + } + + // 获取字段定义 + field, err := t.schema.GetField(key) + if err != nil { + // 字段不在 Schema 中,保持原值 + convertedData[key] = value + continue + } + + // 使用 Schema 的类型转换 + converted, err := convertValue(value, field.Type) + if err != nil { + return NewErrorf(ErrCodeSchemaValidationFailed, "convert field %s: %v", key, err) + } + convertedData[key] = converted + } + + // 3. 生成 _seq seq := t.seq.Add(1) - // 3. 添加系统字段 + // 4. 添加系统字段 row := &SSTableRow{ Seq: seq, Time: time.Now().UnixNano(), - Data: data, + Data: convertedData, } // 3. 序列化(使用二进制格式,保留类型信息) @@ -952,6 +978,16 @@ func (t *Table) ListIndexes() []string { return t.indexManager.ListIndexes() } +// GetIndex 获取指定字段的索引 +func (t *Table) GetIndex(field string) (*SecondaryIndex, bool) { + return t.indexManager.GetIndex(field) +} + +// BuildIndexes 构建所有索引 +func (t *Table) BuildIndexes() error { + return t.indexManager.BuildAll() +} + // GetIndexMetadata 获取索引元数据 func (t *Table) GetIndexMetadata() map[string]IndexMetadata { return t.indexManager.GetIndexMetadata()