feat: updates to btree/index/query/sstable/table

This commit is contained in:
2025-10-11 13:19:26 +08:00
parent c8cbe4178f
commit 03ec262ca5
7 changed files with 1229 additions and 54 deletions

338
query.go
View File

@@ -5,6 +5,8 @@ import (
"fmt"
"maps"
"reflect"
"slices"
"sort"
"strings"
)
@@ -357,9 +359,13 @@ func Or(exprs ...Expr) Expr {
}
type QueryBuilder struct {
conds []Expr
fields []string // 要选择的字段nil 表示选择所有字段
table *Table
conds []Expr
fields []string // 要选择的字段nil 表示选择所有字段
table *Table
orderBy string // 排序字段,仅支持 "_seq" 或索引字段
orderDesc bool // 是否降序排序
offset int // 跳过的记录数
limit int // 返回的最大记录数0 表示无限制
}
func newQueryBuilder(table *Table) *QueryBuilder {
@@ -470,12 +476,123 @@ func (qb *QueryBuilder) NotNull(field string) *QueryBuilder {
return qb.where(NotNull(field))
}
// OrderBy 设置排序字段(升序)
// 仅支持 "_seq" 或有索引的字段,使用其他字段会返回错误
func (qb *QueryBuilder) OrderBy(field string) *QueryBuilder {
qb.orderBy = field
qb.orderDesc = false
return qb
}
// OrderByDesc 设置排序字段(降序)
// 仅支持 "_seq" 或有索引的字段,使用其他字段会返回错误
func (qb *QueryBuilder) OrderByDesc(field string) *QueryBuilder {
qb.orderBy = field
qb.orderDesc = true
return qb
}
// Offset 设置跳过的记录数
// 用于分页查询,跳过前 n 条记录
func (qb *QueryBuilder) Offset(n int) *QueryBuilder {
if n < 0 {
n = 0
}
qb.offset = n
return qb
}
// Limit 设置返回的最大记录数
// 用于分页查询,最多返回 n 条记录
// n = 0 表示无限制
func (qb *QueryBuilder) Limit(n int) *QueryBuilder {
if n < 0 {
n = 0
}
qb.limit = n
return qb
}
// Paginate 执行分页查询并返回结果、总记录数和错误
// page: 页码,从 1 开始
// pageSize: 每页记录数
// 返回值:
// - rows: 当前页的数据
// - total: 满足条件的总记录数(用于计算总页数)
// - err: 错误信息
//
// 注意:此方法会执行两次查询,第一次获取总数,第二次获取分页数据
func (qb *QueryBuilder) Paginate(page, pageSize int) (rows *Rows, total int, err error) {
if page < 1 {
page = 1
}
if pageSize < 0 {
pageSize = 0
}
// 1. 先获取总记录数(不应用分页)
// 创建一个新的 QueryBuilder 副本用于计数
countQb := &QueryBuilder{
conds: qb.conds,
fields: qb.fields,
table: qb.table,
orderBy: "", // 计数不需要排序
offset: 0, // 计数不应用分页
limit: 0,
}
countRows, err := countQb.Rows()
if err != nil {
return nil, 0, err
}
defer countRows.Close()
// 计算总数
total = countRows.Len()
// 2. 执行分页查询
qb.offset = (page - 1) * pageSize
qb.limit = pageSize
rows, err = qb.Rows()
if err != nil {
return nil, total, err
}
return rows, total, nil
}
// validateOrderBy 验证排序字段是否有效
func (qb *QueryBuilder) validateOrderBy() error {
if qb.orderBy == "" {
return nil // 没有设置排序,无需验证
}
// 允许使用 _seq
if qb.orderBy == "_seq" {
return nil
}
// 检查该字段是否有索引
if _, exists := qb.table.indexManager.GetIndex(qb.orderBy); exists {
return nil
}
// 不支持的字段
return fmt.Errorf("OrderBy only supports '_seq' or indexed fields, field '%s' is not indexed", qb.orderBy)
}
// Rows 返回所有匹配的数据(游标模式 - 惰性加载)
func (qb *QueryBuilder) Rows() (*Rows, error) {
if qb.table == nil {
return nil, fmt.Errorf("table is nil")
}
// 验证排序字段
if err := qb.validateOrderBy(); err != nil {
return nil, err
}
rows := &Rows{
schema: qb.table.schema,
fields: qb.fields,
@@ -484,6 +601,11 @@ func (qb *QueryBuilder) Rows() (*Rows, error) {
visited: make(map[int64]bool),
}
// 如果设置了排序,使用排序后的结果集
if qb.orderBy != "" {
return qb.rowsWithOrder(rows)
}
// 尝试使用索引优化查询
// 检查是否有可以使用索引的 Eq 条件
indexField, indexValue := qb.findIndexableCondition()
@@ -570,6 +692,9 @@ func (qb *QueryBuilder) rowsWithIndex(rows *Rows, indexField string, indexValue
}
}
// 应用 offset 和 limit
rows.cachedRows = qb.applyOffsetLimit(rows.cachedRows)
// 使用缓存模式
rows.cached = true
rows.cachedIndex = -1
@@ -577,6 +702,194 @@ func (qb *QueryBuilder) rowsWithIndex(rows *Rows, indexField string, indexValue
return rows, nil
}
// rowsWithOrder 使用排序返回数据
func (qb *QueryBuilder) rowsWithOrder(rows *Rows) (*Rows, error) {
if qb.orderBy == "_seq" {
// 按 _seq 排序
return qb.rowsOrderBySeq(rows)
}
// 按索引字段排序
return qb.rowsOrderByIndex(rows, qb.orderBy)
}
// rowsOrderBySeq 按 _seq 排序返回数据
func (qb *QueryBuilder) rowsOrderBySeq(rows *Rows) (*Rows, error) {
// 收集所有 seq从所有数据源
seqList := []int64{}
// 1. 从 Active MemTable 收集
activeMemTable := qb.table.memtableManager.GetActive()
if activeMemTable != nil {
seqList = append(seqList, activeMemTable.Keys()...)
}
// 2. 从 Immutable MemTables 收集
immutables := qb.table.memtableManager.GetImmutables()
for _, immutable := range immutables {
seqList = append(seqList, immutable.MemTable.Keys()...)
}
// 3. 从 SST 文件收集
sstReaders := qb.table.sstManager.GetReaders()
for _, reader := range sstReaders {
seqList = append(seqList, reader.GetAllKeys()...)
}
// 去重(使用 map
seqMap := make(map[int64]bool)
uniqueSeqs := []int64{}
for _, seq := range seqList {
if !seqMap[seq] {
seqMap[seq] = true
uniqueSeqs = append(uniqueSeqs, seq)
}
}
// 排序
if qb.orderDesc {
// 降序
sort.Slice(uniqueSeqs, func(i, j int) bool {
return uniqueSeqs[i] > uniqueSeqs[j]
})
} else {
// 升序
slices.Sort(uniqueSeqs)
}
// 按排序后的 seq 获取数据
rows.cachedRows = make([]*SSTableRow, 0, len(uniqueSeqs))
for _, seq := range uniqueSeqs {
row, err := qb.table.Get(seq)
if err != nil {
continue // 跳过获取失败的记录
}
// 检查是否匹配过滤条件
if qb.Match(row.Data) {
rows.cachedRows = append(rows.cachedRows, row)
}
}
// 应用 offset 和 limit
rows.cachedRows = qb.applyOffsetLimit(rows.cachedRows)
// 使用缓存模式
rows.cached = true
rows.cachedIndex = -1
return rows, nil
}
// rowsOrderByIndex 按索引字段排序返回数据
//
// 实现策略:
// 1. 使用 ForEach/ForEachDesc 从索引收集所有 (value, seqs) 对
// 2. 按字段值(而非哈希)对这些对进行排序
// 3. 按排序后的顺序获取数据
//
// 注意:虽然使用了索引,但需要在内存中排序所有索引条目。
// 对于大量唯一值的字段,内存开销可能较大。
func (qb *QueryBuilder) rowsOrderByIndex(rows *Rows, indexField string) (*Rows, error) {
// 获取索引
idx, exists := qb.table.indexManager.GetIndex(indexField)
if !exists {
return nil, fmt.Errorf("index on field %s not found", indexField)
}
// 检查索引是否准备就绪
if !idx.IsReady() {
return nil, fmt.Errorf("index on field %s is not ready", indexField)
}
// 用于收集索引条目的结构
type indexEntry struct {
value string
seqs []int64
}
// 收集所有索引条目
entries := []indexEntry{}
err := idx.ForEach(func(value string, seqs []int64) bool {
// 复制 seqs 避免引用问题
seqsCopy := make([]int64, len(seqs))
copy(seqsCopy, seqs)
entries = append(entries, indexEntry{
value: value,
seqs: seqsCopy,
})
return true
})
if err != nil {
return nil, fmt.Errorf("failed to iterate index: %w", err)
}
// 按字段值排序(而非哈希)
if qb.orderDesc {
// 降序
sort.Slice(entries, func(i, j int) bool {
return entries[i].value > entries[j].value
})
} else {
// 升序
sort.Slice(entries, func(i, j int) bool {
return entries[i].value < entries[j].value
})
}
// 按排序后的顺序收集所有 seq
allSeqs := []int64{}
for _, entry := range entries {
allSeqs = append(allSeqs, entry.seqs...)
}
// 根据 seq 列表获取数据
rows.cachedRows = make([]*SSTableRow, 0, len(allSeqs))
for _, seq := range allSeqs {
row, err := qb.table.Get(seq)
if err != nil {
continue // 跳过获取失败的记录
}
// 检查是否匹配所有其他条件
if qb.Match(row.Data) {
rows.cachedRows = append(rows.cachedRows, row)
}
}
// 应用 offset 和 limit
rows.cachedRows = qb.applyOffsetLimit(rows.cachedRows)
// 使用缓存模式
rows.cached = true
rows.cachedIndex = -1
return rows, nil
}
// applyOffsetLimit 应用 offset 和 limit 到结果集
func (qb *QueryBuilder) applyOffsetLimit(rows []*SSTableRow) []*SSTableRow {
// 如果没有设置 offset 和 limit直接返回
if qb.offset == 0 && qb.limit == 0 {
return rows
}
// 应用 offset
if qb.offset > 0 {
if qb.offset >= len(rows) {
return []*SSTableRow{}
}
rows = rows[qb.offset:]
}
// 应用 limit
if qb.limit > 0 && qb.limit < len(rows) {
rows = rows[:qb.limit]
}
return rows
}
// First 返回第一个匹配的数据
func (qb *QueryBuilder) First() (*Row, error) {
rows, err := qb.Rows()
@@ -697,6 +1010,10 @@ type Rows struct {
cached bool
cachedRows []*SSTableRow
cachedIndex int // 缓存模式下的迭代位置
// 分页状态(惰性模式)
skippedCount int // 已跳过的记录数(用于 offset
returnedCount int // 已返回的记录数(用于 limit
}
// memtableIterator 包装 MemTable 的迭代器
@@ -843,8 +1160,21 @@ func (r *Rows) next() bool {
continue
}
// 应用 offset跳过前 N 条记录
if r.qb.offset > 0 && r.skippedCount < r.qb.offset {
r.skippedCount++
r.visited[minSeq] = true
continue
}
// 应用 limit达到返回上限后停止
if r.qb.limit > 0 && r.returnedCount >= r.qb.limit {
return false
}
// 找到匹配的记录
r.visited[minSeq] = true
r.returnedCount++
r.currentRow = &Row{schema: r.schema, fields: r.fields, inner: row}
return true
}
@@ -927,7 +1257,7 @@ func (r *Rows) Data() []map[string]any {
// - 如果目标是结构体/指针:只扫描第一行
func (r *Rows) Scan(value any) error {
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
if rv.Kind() != reflect.Pointer {
return fmt.Errorf("scan target must be a pointer")
}