Files
srdb/query.go

1317 lines
30 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package srdb
import (
"encoding/json"
"fmt"
"maps"
"reflect"
"slices"
"sort"
"strings"
)
// Fieldset is the record abstraction that query expressions are matched
// against. Get looks up key and returns the field definition, the stored
// value and an error; the predicates in this file treat a non-nil error as
// "the field is absent".
type Fieldset interface {
Get(key string) (field Field, value any, err error)
}
// mapFieldset implements Fieldset on top of a raw map[string]any record
// paired with the Schema that describes its fields.
type mapFieldset struct {
	data   map[string]any
	schema *Schema
}

// newMapFieldset wraps data and schema in a mapFieldset. Neither argument is
// copied.
func newMapFieldset(data map[string]any, schema *Schema) *mapFieldset {
	fs := new(mapFieldset)
	fs.data = data
	fs.schema = schema
	return fs
}

// Get returns the definition and value of key.
//
// A key that is missing from the record yields an error. A key that is present
// in the record but unknown to the schema is still returned, described by a
// default Field that carries only its name.
func (m *mapFieldset) Get(key string) (Field, any, error) {
	val, present := m.data[key]
	if !present {
		return Field{}, nil, fmt.Errorf("field %s not found", key)
	}
	// Prefer the schema's definition when it has one.
	if def, err := m.schema.GetField(key); err == nil {
		return *def, val, nil
	}
	return Field{Name: key}, val, nil
}
// Expr is a query predicate. Match reports whether the record exposed through
// fs satisfies it.
type Expr interface {
Match(fs Fieldset) bool
}
// Neginative is the logical negation of another expression. (The exported
// name is kept as-is for compatibility.)
type Neginative struct {
	expr Expr
}

// Match returns the opposite of the wrapped expression's result. A negation
// that wraps nothing matches every record.
func (n Neginative) Match(fs Fieldset) bool {
	inner := n.expr
	return inner == nil || !inner.Match(fs)
}

// Not wraps expr so that it matches exactly the records expr rejects.
func Not(expr Expr) Expr {
	return Neginative{expr: expr}
}
// compare is a single-field predicate: op is applied to the value stored under
// field, with right as the operand (a scalar, a []any list for the IN and
// BETWEEN families, or nil for the NULL checks).
type compare struct {
	field string
	op    string
	right any
}

// Match evaluates the predicate against fs.
//
//   - When the field lookup fails, only "IS NULL" matches.
//   - "IS NULL" / "IS NOT NULL" test the stored value against nil.
//   - Every other operator is false for a nil value, for an unknown operator,
//     and for operands of the wrong shape (a non-list for IN / NOT IN, a list
//     that is not exactly two elements for BETWEEN / NOT BETWEEN, or a
//     non-string on either side of the text operators).
func (c compare) Match(fs Fieldset) bool {
	_, val, err := fs.Get(c.field)
	if err != nil {
		// The field is absent from the record.
		return c.op == "IS NULL"
	}

	switch c.op {
	case "IS NULL":
		return val == nil
	case "IS NOT NULL":
		return val != nil
	}
	if val == nil {
		return false
	}

	// Inclusive bounds, shared by <=, >= and the BETWEEN family.
	atLeast := func(bound any) bool {
		return compareGreater(val, bound) || compareEqual(val, bound)
	}
	atMost := func(bound any) bool {
		return compareLess(val, bound) || compareEqual(val, bound)
	}

	switch c.op {
	case "=":
		return compareEqual(val, c.right)
	case "!=":
		return !compareEqual(val, c.right)
	case "<":
		return compareLess(val, c.right)
	case ">":
		return compareGreater(val, c.right)
	case "<=":
		return atMost(c.right)
	case ">=":
		return atLeast(c.right)

	case "IN", "NOT IN":
		candidates, ok := c.right.([]any)
		if !ok {
			return false
		}
		found := false
		for _, candidate := range candidates {
			if compareEqual(val, candidate) {
				found = true
				break
			}
		}
		return found == (c.op == "IN")

	case "BETWEEN", "NOT BETWEEN":
		bounds, ok := c.right.([]any)
		if !ok || len(bounds) != 2 {
			return false
		}
		inside := atLeast(bounds[0]) && atMost(bounds[1])
		return inside == (c.op == "BETWEEN")

	case "CONTAINS", "NOT CONTAINS",
		"STARTS WITH", "NOT STARTS WITH",
		"ENDS WITH", "NOT ENDS WITH":
		text, ok := val.(string)
		if !ok {
			return false
		}
		pattern, ok := c.right.(string)
		if !ok {
			return false
		}
		switch c.op {
		case "CONTAINS":
			return strings.Contains(text, pattern)
		case "NOT CONTAINS":
			return !strings.Contains(text, pattern)
		case "STARTS WITH":
			return strings.HasPrefix(text, pattern)
		case "NOT STARTS WITH":
			return !strings.HasPrefix(text, pattern)
		case "ENDS WITH":
			return strings.HasSuffix(text, pattern)
		case "NOT ENDS WITH":
			return !strings.HasSuffix(text, pattern)
		}
	}
	return false
}
// compareEqual reports whether left and right are equal.
//
// When both operands are one of the numeric types recognised by toFloat64 they
// are compared by value as float64, so int(1) equals float64(1). Note that
// integers too large to be represented exactly in a float64 may therefore
// compare equal to a neighbouring value.
//
// All other operands are compared with ==. Operands that == cannot handle
// (slices, maps, or values containing them) are compared structurally with
// reflect.DeepEqual instead: a plain == on two interface values holding the
// same uncomparable dynamic type would panic at run time.
func compareEqual(left, right any) bool {
	if l, ok := toFloat64(left); ok {
		if r, ok := toFloat64(right); ok {
			return l == r
		}
	}
	if reflect.ValueOf(left).Comparable() && reflect.ValueOf(right).Comparable() {
		return left == right
	}
	return reflect.DeepEqual(left, right)
}
// compareLess reports whether left < right.
//
// Two numeric operands (as recognised by toFloat64) are ordered by value; two
// strings are ordered with Go's string comparison. Any other combination is
// not ordered and yields false.
func compareLess(left, right any) bool {
	if l, ok := toFloat64(left); ok {
		if r, ok := toFloat64(right); ok {
			return l < r
		}
	}
	ls, leftIsStr := left.(string)
	rs, rightIsStr := right.(string)
	return leftIsStr && rightIsStr && ls < rs
}
// compareGreater reports whether left > right.
//
// Two numeric operands (as recognised by toFloat64) are ordered by value; two
// strings are ordered with Go's string comparison. Any other combination is
// not ordered and yields false.
func compareGreater(left, right any) bool {
	if l, ok := toFloat64(left); ok {
		if r, ok := toFloat64(right); ok {
			return l > r
		}
	}
	ls, leftIsStr := left.(string)
	rs, rightIsStr := right.(string)
	return leftIsStr && rightIsStr && ls > rs
}
// toFloat64 converts v to float64 when its dynamic type is one of Go's
// built-in integer or floating-point types. The second result is false (and
// the first is 0) for every other type, including nil and named numeric types.
func toFloat64(v any) (float64, bool) {
	var f float64
	switch n := v.(type) {
	case int:
		f = float64(n)
	case int8:
		f = float64(n)
	case int16:
		f = float64(n)
	case int32:
		f = float64(n)
	case int64:
		f = float64(n)
	case uint:
		f = float64(n)
	case uint8:
		f = float64(n)
	case uint16:
		f = float64(n)
	case uint32:
		f = float64(n)
	case uint64:
		f = float64(n)
	case float32:
		f = float64(n)
	case float64:
		f = n
	default:
		return 0, false
	}
	return f, true
}
// Eq builds the predicate "field = value".
func Eq(field string, value any) Expr {
	return compare{field: field, op: "=", right: value}
}
// NotEq builds the predicate "field != value".
func NotEq(field string, value any) Expr {
	return compare{field: field, op: "!=", right: value}
}
// Lt builds the predicate "field < value".
func Lt(field string, value any) Expr {
	return compare{field: field, op: "<", right: value}
}
// Gt builds the predicate "field > value".
func Gt(field string, value any) Expr {
	return compare{field: field, op: ">", right: value}
}
// Lte builds the predicate "field <= value".
func Lte(field string, value any) Expr {
	return compare{field: field, op: "<=", right: value}
}
// Gte builds the predicate "field >= value".
func Gte(field string, value any) Expr {
	return compare{field: field, op: ">=", right: value}
}
// In builds the predicate "field IN values": it matches when the field equals
// at least one element of values.
func In(field string, values []any) Expr {
	return compare{field: field, op: "IN", right: values}
}
// NotIn builds the predicate "field NOT IN values": it matches when the field
// equals none of the elements of values.
func NotIn(field string, values []any) Expr {
	return compare{field: field, op: "NOT IN", right: values}
}
// Between builds the predicate "field BETWEEN min AND max"; both bounds are
// inclusive.
func Between(field string, min, max any) Expr {
	bounds := []any{min, max}
	return compare{field: field, op: "BETWEEN", right: bounds}
}
// NotBetween builds the predicate "field NOT BETWEEN min AND max", i.e. the
// field lies outside the inclusive range [min, max].
func NotBetween(field string, min, max any) Expr {
	bounds := []any{min, max}
	return compare{field: field, op: "NOT BETWEEN", right: bounds}
}
// Contains builds a predicate that matches string fields containing pattern
// as a substring.
func Contains(field string, pattern string) Expr {
	return compare{field: field, op: "CONTAINS", right: pattern}
}
// NotContains builds a predicate that matches string fields which do not
// contain pattern as a substring.
func NotContains(field string, pattern string) Expr {
	return compare{field: field, op: "NOT CONTAINS", right: pattern}
}
// StartsWith builds a predicate that matches string fields beginning with
// prefix.
func StartsWith(field string, prefix string) Expr {
	return compare{field: field, op: "STARTS WITH", right: prefix}
}
// NotStartsWith builds a predicate that matches string fields which do not
// begin with prefix.
func NotStartsWith(field string, prefix string) Expr {
	return compare{field: field, op: "NOT STARTS WITH", right: prefix}
}
// EndsWith builds a predicate that matches string fields ending with suffix.
func EndsWith(field string, suffix string) Expr {
	return compare{field: field, op: "ENDS WITH", right: suffix}
}
// NotEndsWith builds a predicate that matches string fields which do not end
// with suffix.
func NotEndsWith(field string, suffix string) Expr {
	return compare{field: field, op: "NOT ENDS WITH", right: suffix}
}
// IsNull builds a predicate that matches records whose field is absent or
// holds nil.
func IsNull(field string) Expr {
	return compare{field: field, op: "IS NULL"}
}
// NotNull builds a predicate that matches records whose field is present and
// holds a non-nil value.
func NotNull(field string) Expr {
	return compare{field: field, op: "IS NOT NULL"}
}
// group combines several expressions with a single logical operator: AND when
// and is true, OR otherwise.
type group struct {
	exprs []Expr
	and   bool
}

// Match evaluates the group against fs, short-circuiting on the first operand
// that decides the result.
//
//   - AND: true only if every operand matches.
//   - OR: true only if at least one operand matches.
//   - An empty group (no operands) places no constraint and matches every
//     record, for both operators.
func (g group) Match(fs Fieldset) bool {
	if len(g.exprs) == 0 {
		return true
	}
	for _, expr := range g.exprs {
		if expr.Match(fs) != g.and {
			// AND met a non-matching operand (result: false), or OR met a
			// matching one (result: true).
			return !g.and
		}
	}
	// No operand decided early: every AND operand matched, or no OR operand did.
	return g.and
}

// And returns an expression that matches when all of exprs match.
func And(exprs ...Expr) Expr {
	return group{exprs: exprs, and: true}
}

// Or returns an expression that matches when at least one of exprs matches.
func Or(exprs ...Expr) Expr {
	return group{exprs: exprs, and: false}
}
// QueryBuilder accumulates the filter conditions, field selection, ordering
// and paging settings of a query against one table. Its chaining methods
// modify the builder in place and return it.
type QueryBuilder struct {
conds []Expr
fields []string // fields to select; nil means all fields
table *Table
orderBy string // sort field; only "_seq" or an indexed field is supported
orderDesc bool // whether to sort in descending order
offset int // number of records to skip
limit int // maximum number of records to return; 0 means no limit
}
// newQueryBuilder returns an empty QueryBuilder bound to table: no
// conditions, all fields selected, no ordering, no offset and no limit.
func newQueryBuilder(table *Table) *QueryBuilder {
	qb := new(QueryBuilder)
	qb.table = table
	return qb
}
// where appends expr to the builder's condition list and returns qb for
// chaining. Conditions in the list are ANDed together by Match.
func (qb *QueryBuilder) where(expr Expr) *QueryBuilder {
	qb.conds = append(qb.conds, expr)
	return qb
}
// Match reports whether data satisfies every condition registered on the
// builder. A builder without conditions matches any record (and does not
// touch the table in that case).
func (qb *QueryBuilder) Match(data map[string]any) bool {
	if len(qb.conds) == 0 {
		return true
	}
	record := newMapFieldset(data, qb.table.schema)
	rejects := func(cond Expr) bool { return !cond.Match(record) }
	return !slices.ContainsFunc(qb.conds, rejects)
}
// Select restricts the fields returned for each row to fields, replacing any
// previous selection. When Select is never called (or is called with no
// arguments) every field is returned.
func (qb *QueryBuilder) Select(fields ...string) *QueryBuilder {
	qb.fields = fields
	return qb
}
// Where adds exprs to the query as a single AND group and returns qb for
// chaining.
func (qb *QueryBuilder) Where(exprs ...Expr) *QueryBuilder {
	all := And(exprs...)
	return qb.where(all)
}
// Eq adds the condition "field = value" and returns qb for chaining.
func (qb *QueryBuilder) Eq(field string, value any) *QueryBuilder {
	cond := Eq(field, value)
	return qb.where(cond)
}
// NotEq adds the condition "field != value" and returns qb for chaining.
func (qb *QueryBuilder) NotEq(field string, value any) *QueryBuilder {
	cond := NotEq(field, value)
	return qb.where(cond)
}
// Lt adds the condition "field < value" and returns qb for chaining.
func (qb *QueryBuilder) Lt(field string, value any) *QueryBuilder {
	cond := Lt(field, value)
	return qb.where(cond)
}
// Gt adds the condition "field > value" and returns qb for chaining.
func (qb *QueryBuilder) Gt(field string, value any) *QueryBuilder {
	cond := Gt(field, value)
	return qb.where(cond)
}
// Lte adds the condition "field <= value" and returns qb for chaining.
func (qb *QueryBuilder) Lte(field string, value any) *QueryBuilder {
	cond := Lte(field, value)
	return qb.where(cond)
}
// Gte adds the condition "field >= value" and returns qb for chaining.
func (qb *QueryBuilder) Gte(field string, value any) *QueryBuilder {
	cond := Gte(field, value)
	return qb.where(cond)
}
// In adds the condition "field IN values" and returns qb for chaining.
func (qb *QueryBuilder) In(field string, values []any) *QueryBuilder {
	cond := In(field, values)
	return qb.where(cond)
}
// NotIn adds the condition "field NOT IN values" and returns qb for chaining.
func (qb *QueryBuilder) NotIn(field string, values []any) *QueryBuilder {
	cond := NotIn(field, values)
	return qb.where(cond)
}
// Between adds the condition "field BETWEEN start AND end" (both bounds
// inclusive) and returns qb for chaining.
func (qb *QueryBuilder) Between(field string, start, end any) *QueryBuilder {
	cond := Between(field, start, end)
	return qb.where(cond)
}
// NotBetween adds the condition "field NOT BETWEEN start AND end" and returns
// qb for chaining.
//
// It uses the dedicated NOT BETWEEN predicate, like the package-level
// NotBetween and like the other negated builder methods (NotIn, NotContains,
// ...). Wrapping Between in Not instead would also match records whose field
// is missing or nil, because Between is false for those and Not inverts it.
func (qb *QueryBuilder) NotBetween(field string, start, end any) *QueryBuilder {
	return qb.where(NotBetween(field, start, end))
}
// Contains adds the condition "field CONTAINS pattern" and returns qb for
// chaining.
func (qb *QueryBuilder) Contains(field string, pattern string) *QueryBuilder {
	cond := Contains(field, pattern)
	return qb.where(cond)
}
// NotContains adds the condition "field NOT CONTAINS pattern" and returns qb
// for chaining.
func (qb *QueryBuilder) NotContains(field string, pattern string) *QueryBuilder {
	cond := NotContains(field, pattern)
	return qb.where(cond)
}
// StartsWith adds the condition "field STARTS WITH pattern" and returns qb for
// chaining.
func (qb *QueryBuilder) StartsWith(field string, pattern string) *QueryBuilder {
	cond := StartsWith(field, pattern)
	return qb.where(cond)
}
// NotStartsWith adds the condition "field NOT STARTS WITH pattern" and returns
// qb for chaining.
func (qb *QueryBuilder) NotStartsWith(field string, pattern string) *QueryBuilder {
	cond := NotStartsWith(field, pattern)
	return qb.where(cond)
}
// EndsWith adds the condition "field ENDS WITH pattern" and returns qb for
// chaining.
func (qb *QueryBuilder) EndsWith(field string, pattern string) *QueryBuilder {
	cond := EndsWith(field, pattern)
	return qb.where(cond)
}
// NotEndsWith adds the condition "field NOT ENDS WITH pattern" and returns qb
// for chaining.
func (qb *QueryBuilder) NotEndsWith(field string, pattern string) *QueryBuilder {
	cond := NotEndsWith(field, pattern)
	return qb.where(cond)
}
// IsNull adds the condition "field IS NULL" and returns qb for chaining.
func (qb *QueryBuilder) IsNull(field string) *QueryBuilder {
	cond := IsNull(field)
	return qb.where(cond)
}
// NotNull adds the condition "field IS NOT NULL" and returns qb for chaining.
func (qb *QueryBuilder) NotNull(field string) *QueryBuilder {
	cond := NotNull(field)
	return qb.where(cond)
}
// OrderBy sorts the result by field in ascending order.
//
// Only "_seq" and indexed fields are supported. The field is not checked
// here; Rows rejects an unsupported field with an error when the query runs.
func (qb *QueryBuilder) OrderBy(field string) *QueryBuilder {
	qb.orderBy, qb.orderDesc = field, false
	return qb
}
// OrderByDesc sorts the result by field in descending order.
//
// Only "_seq" and indexed fields are supported. The field is not checked
// here; Rows rejects an unsupported field with an error when the query runs.
func (qb *QueryBuilder) OrderByDesc(field string) *QueryBuilder {
	qb.orderBy, qb.orderDesc = field, true
	return qb
}
// Offset sets how many matching records are skipped before the first one is
// returned (used for paging). A negative n is treated as 0.
func (qb *QueryBuilder) Offset(n int) *QueryBuilder {
	qb.offset = 0
	if n > 0 {
		qb.offset = n
	}
	return qb
}
// Limit sets the maximum number of records returned (used for paging).
// n = 0 means no limit; a negative n is treated as 0.
func (qb *QueryBuilder) Limit(n int) *QueryBuilder {
	qb.limit = 0
	if n > 0 {
		qb.limit = n
	}
	return qb
}
// Paginate runs the query for one page and also reports how many records
// match in total.
//
// page is 1-based (values below 1 are treated as 1) and pageSize is the number
// of records per page (negative values are treated as 0, which means no
// limit).
//
// Returns:
//   - rows: the records of the requested page
//   - total: the number of records matching the conditions, ignoring paging
//     (useful for computing the page count)
//   - err: the error of whichever query failed
//
// The query is executed twice: once without ordering or paging to count the
// matches, and once for the page itself. The second run stores the computed
// offset and limit on qb, so the builder keeps those paging settings
// afterwards.
func (qb *QueryBuilder) Paginate(page, pageSize int) (rows *Rows, total int, err error) {
	if page < 1 {
		page = 1
	}
	if pageSize < 0 {
		pageSize = 0
	}

	// Count on a separate builder that shares the conditions but leaves
	// ordering, offset and limit at their zero values.
	counter := &QueryBuilder{
		conds:  qb.conds,
		fields: qb.fields,
		table:  qb.table,
	}
	counted, err := counter.Rows()
	if err != nil {
		return nil, 0, err
	}
	defer counted.Close()
	total = counted.Len()

	// Fetch the requested page.
	qb.offset, qb.limit = (page-1)*pageSize, pageSize
	if rows, err = qb.Rows(); err != nil {
		return nil, total, err
	}
	return rows, total, nil
}
// validateOrderBy checks that the configured sort field can be used: no sort
// field at all, the built-in "_seq", or a field for which the table has an
// index. Any other field yields an error.
func (qb *QueryBuilder) validateOrderBy() error {
	field := qb.orderBy
	switch field {
	case "", "_seq":
		// Nothing to validate, or the always-available sequence number.
		return nil
	}
	if _, indexed := qb.table.indexManager.GetIndex(field); indexed {
		return nil
	}
	return fmt.Errorf("OrderBy only supports '_seq' or indexed fields, field '%s' is not indexed", field)
}
// Rows executes the query and returns a cursor over the matching records.
//
// Depending on the query one of three strategies is used:
//   - with an ordering, the sorted result is built up front (cached mode);
//   - otherwise, if a top-level "=" condition targets a ready index, the index
//     is used and the result is again built up front (cached mode);
//   - otherwise the cursor is lazy: only the iterators are set up here and
//     records are read as Next is called.
//
// It returns an error when the builder has no table or when the sort field is
// not supported.
func (qb *QueryBuilder) Rows() (*Rows, error) {
	if qb.table == nil {
		return nil, fmt.Errorf("table is nil")
	}
	if err := qb.validateOrderBy(); err != nil {
		return nil, err
	}

	tbl := qb.table
	cursor := &Rows{
		schema:  tbl.schema,
		fields:  qb.fields,
		qb:      qb,
		table:   tbl,
		visited: make(map[int64]bool),
	}

	// Ordered queries need the whole result before the first row is known.
	if qb.orderBy != "" {
		return qb.rowsWithOrder(cursor)
	}

	// Index-assisted lookup: the seq list comes from the index, so the rows
	// are loaded immediately.
	if field, value := qb.findIndexableCondition(); field != "" {
		return qb.rowsWithIndex(cursor, field, value)
	}

	// Lazy mode: prepare the iterators without reading any record.
	// 1. Active MemTable.
	if active := tbl.memtableManager.GetActive(); active != nil {
		cursor.memIterator = newMemtableIterator(active.Keys())
	}
	// 2. Immutable MemTables are opened one by one later, while iterating.
	cursor.immutableIndex = 0
	cursor.immutableIterator = nil
	// 3. SST files.
	readers := tbl.sstManager.GetReaders()
	cursor.sstReaders = make([]*sstReader, len(readers))
	for i, reader := range readers {
		cursor.sstReaders[i] = &sstReader{keys: reader.GetAllKeys(), index: 0}
	}
	cursor.sstIndex = 0
	// Leave cached unset so that Next takes the lazy path.
	cursor.cached = false
	return cursor, nil
}
// findIndexableCondition looks for the first top-level condition that an index
// can answer: a plain "=" comparison on a field whose index exists and is
// ready. It returns that field and the compared value, or "" and nil when no
// such condition exists. Conditions nested inside groups are not inspected.
func (qb *QueryBuilder) findIndexableCondition() (string, any) {
	for _, cond := range qb.conds {
		cmp, isCompare := cond.(compare)
		if !isCompare || cmp.op != "=" {
			continue
		}
		idx, exists := qb.table.indexManager.GetIndex(cmp.field)
		if exists && idx.IsReady() {
			return cmp.field, cmp.right
		}
	}
	return "", nil
}
// rowsWithIndex answers the query through the index on indexField: it asks the
// index for the seqs stored under indexValue, loads those records, keeps the
// ones that satisfy all conditions, applies offset/limit and returns rows in
// cached mode.
//
// Records that cannot be loaded are skipped. An error is returned when the
// index does not exist or the index lookup fails.
func (qb *QueryBuilder) rowsWithIndex(rows *Rows, indexField string, indexValue any) (*Rows, error) {
	idx, found := qb.table.indexManager.GetIndex(indexField)
	if !found {
		return nil, fmt.Errorf("index on field %s not found", indexField)
	}

	seqs, err := idx.Get(indexValue)
	if err != nil {
		return nil, fmt.Errorf("index lookup failed: %w", err)
	}

	if len(seqs) == 0 {
		// Nothing indexed under this value: empty cached result.
		rows.cachedRows = []*SSTableRow{}
	} else {
		matched := make([]*SSTableRow, 0, len(seqs))
		for _, seq := range seqs {
			row, err := qb.table.Get(seq)
			if err != nil {
				continue // best effort: skip records that cannot be loaded
			}
			// The index covers one condition only; re-check all of them.
			if qb.Match(row.Data) {
				matched = append(matched, row)
			}
		}
		rows.cachedRows = qb.applyOffsetLimit(matched)
	}

	// Serve the result from the cache.
	rows.cached = true
	rows.cachedIndex = -1
	return rows, nil
}
// rowsWithOrder builds the sorted result set: by sequence number when the sort
// field is "_seq", through the field's index otherwise.
func (qb *QueryBuilder) rowsWithOrder(rows *Rows) (*Rows, error) {
	if field := qb.orderBy; field != "_seq" {
		return qb.rowsOrderByIndex(rows, field)
	}
	return qb.rowsOrderBySeq(rows)
}
// rowsOrderBySeq returns the matching records ordered by _seq.
//
// The seqs of every data source (active MemTable, immutable MemTables, SST
// files) are gathered, de-duplicated and sorted ascending or descending. The
// records are then loaded in that order, filtered by the query conditions and
// trimmed by offset/limit. Records that cannot be loaded are skipped. The
// result is served in cached mode.
func (qb *QueryBuilder) rowsOrderBySeq(rows *Rows) (*Rows, error) {
	var seqs []int64

	// 1. Active MemTable.
	if active := qb.table.memtableManager.GetActive(); active != nil {
		seqs = append(seqs, active.Keys()...)
	}
	// 2. Immutable MemTables.
	for _, imm := range qb.table.memtableManager.GetImmutables() {
		seqs = append(seqs, imm.MemTable.Keys()...)
	}
	// 3. SST files.
	for _, reader := range qb.table.sstManager.GetReaders() {
		seqs = append(seqs, reader.GetAllKeys()...)
	}

	// Sort first so that duplicates become adjacent and can be compacted away.
	slices.Sort(seqs)
	seqs = slices.Compact(seqs)
	if qb.orderDesc {
		slices.Reverse(seqs)
	}

	matched := make([]*SSTableRow, 0, len(seqs))
	for _, seq := range seqs {
		row, err := qb.table.Get(seq)
		if err != nil {
			continue // best effort: skip records that cannot be loaded
		}
		if qb.Match(row.Data) {
			matched = append(matched, row)
		}
	}

	rows.cachedRows = qb.applyOffsetLimit(matched)
	rows.cached = true
	rows.cachedIndex = -1
	return rows, nil
}
// rowsOrderByIndex returns the matching records ordered by an indexed field.
//
// Strategy:
//  1. Collect every (value, seqs) pair from the index with ForEach.
//  2. Sort the pairs by the value string the index reports (ascending or
//     descending).
//  3. Load the records in that order, keep those matching the conditions, and
//     apply offset/limit.
//
// All index entries are held and sorted in memory, so the cost grows with the
// number of distinct values of the field. Records that cannot be loaded are
// skipped. An error is returned when the index is missing, not ready, or
// fails to iterate.
func (qb *QueryBuilder) rowsOrderByIndex(rows *Rows, indexField string) (*Rows, error) {
	idx, found := qb.table.indexManager.GetIndex(indexField)
	if !found {
		return nil, fmt.Errorf("index on field %s not found", indexField)
	}
	if !idx.IsReady() {
		return nil, fmt.Errorf("index on field %s is not ready", indexField)
	}

	// bucket is one index entry: a field value and the seqs stored under it.
	type bucket struct {
		value string
		seqs  []int64
	}
	buckets := []bucket{}
	collect := func(value string, seqs []int64) bool {
		// Copy seqs so the bucket does not alias the index's slice.
		owned := make([]int64, len(seqs))
		copy(owned, seqs)
		buckets = append(buckets, bucket{value: value, seqs: owned})
		return true
	}
	if err := idx.ForEach(collect); err != nil {
		return nil, fmt.Errorf("failed to iterate index: %w", err)
	}

	// Order by field value.
	descending := qb.orderDesc
	sort.Slice(buckets, func(i, j int) bool {
		if descending {
			return buckets[i].value > buckets[j].value
		}
		return buckets[i].value < buckets[j].value
	})

	total := 0
	for i := range buckets {
		total += len(buckets[i].seqs)
	}

	// Load the records bucket by bucket, in sorted order.
	matched := make([]*SSTableRow, 0, total)
	for i := range buckets {
		for _, seq := range buckets[i].seqs {
			row, err := qb.table.Get(seq)
			if err != nil {
				continue // best effort: skip records that cannot be loaded
			}
			if qb.Match(row.Data) {
				matched = append(matched, row)
			}
		}
	}

	rows.cachedRows = qb.applyOffsetLimit(matched)
	rows.cached = true
	rows.cachedIndex = -1
	return rows, nil
}
// applyOffsetLimit trims an already materialised result set to the builder's
// paging window: the first offset rows are dropped, then at most limit rows
// are kept (limit 0 means no limit). An offset at or beyond the end yields an
// empty, non-nil slice. The returned slice shares storage with rows.
func (qb *QueryBuilder) applyOffsetLimit(rows []*SSTableRow) []*SSTableRow {
	if skip := qb.offset; skip > 0 {
		if skip >= len(rows) {
			return []*SSTableRow{}
		}
		rows = rows[skip:]
	}
	if keep := qb.limit; keep > 0 && keep < len(rows) {
		rows = rows[:keep]
	}
	return rows
}
// First runs the query and returns the first matching record. It returns the
// error of the query itself, or the cursor's error when nothing matches.
func (qb *QueryBuilder) First() (*Row, error) {
	cursor, err := qb.Rows()
	if err != nil {
		return nil, err
	}
	defer cursor.Close()

	return cursor.First()
}
// Last runs the query and returns the last matching record. It returns the
// error of the query itself, or the cursor's error when nothing matches.
func (qb *QueryBuilder) Last() (*Row, error) {
	cursor, err := qb.Rows()
	if err != nil {
		return nil, err
	}
	defer cursor.Close()

	return cursor.Last()
}
// Scan runs the query and decodes the result into value, following the rules
// of Rows.Scan.
func (qb *QueryBuilder) Scan(value any) error {
	cursor, err := qb.Rows()
	if err != nil {
		return err
	}
	defer cursor.Close()

	return cursor.Scan(value)
}
// Row is a single query result: a stored record together with the field
// selection that applies to it.
type Row struct {
schema *Schema
fields []string // fields to select; nil means all fields
inner *SSTableRow
}
// Data returns the row as a freshly built map, honouring the Select field
// list.
//
// Without a selection the map contains "_seq", "_time" and every stored
// field. With a selection it contains only the listed fields that exist
// ("_seq" and "_time" are available by name). A row without an underlying
// record yields nil.
func (r *Row) Data() map[string]any {
	rec := r.inner
	if rec == nil {
		return nil
	}

	out := make(map[string]any)
	if len(r.fields) == 0 {
		// No selection: system columns first, then the stored data.
		out["_seq"] = rec.Seq
		out["_time"] = rec.Time
		maps.Copy(out, rec.Data)
		return out
	}

	for _, name := range r.fields {
		switch name {
		case "_seq":
			out["_seq"] = rec.Seq
		case "_time":
			out["_time"] = rec.Time
		default:
			if v, ok := rec.Data[name]; ok {
				out[name] = v
			}
		}
	}
	return out
}
// Seq returns the sequence number of the row, or 0 when the row has no
// underlying record.
func (r *Row) Seq() int64 {
	if rec := r.inner; rec != nil {
		return rec.Seq
	}
	return 0
}
// Scan decodes the row into value by round-tripping it through JSON, so value
// must be something json.Unmarshal accepts (typically a pointer to a struct
// or map). The data comes from Data, which means the Select field list is
// applied. It fails when the row has no underlying record or when encoding or
// decoding fails.
func (r *Row) Scan(value any) error {
	if r.inner == nil {
		return fmt.Errorf("row is nil")
	}

	encoded, err := json.Marshal(r.Data())
	if err != nil {
		return fmt.Errorf("marshal row data: %w", err)
	}
	if err := json.Unmarshal(encoded, value); err != nil {
		return fmt.Errorf("unmarshal to target: %w", err)
	}
	return nil
}
// Rows is a cursor over a query result. It works either lazily (records are
// read from the data sources as the cursor advances) or from a cache of
// already loaded records.
type Rows struct {
schema *Schema
fields []string // fields to select; nil means all fields
qb *QueryBuilder
table *Table
// Iteration state
currentRow *Row
err error
closed bool
visited map[int64]bool // seqs already visited, used for de-duplication
// Data source iterators
memIterator *memtableIterator
immutableIndex int
immutableIterator *memtableIterator
sstIndex int
sstReaders []*sstReader
// Cached mode (used by Collect/Data and similar methods)
cached bool
cachedRows []*SSTableRow
cachedIndex int // iteration position in cached mode
// Paging state (lazy mode)
skippedCount int // number of records skipped so far (for offset)
returnedCount int // number of records returned so far (for limit)
}
// memtableIterator walks a list of MemTable keys (seqs). index is the position
// of the key returned most recently; it starts at -1, before the first key.
type memtableIterator struct {
	keys  []int64
	index int
}

// newMemtableIterator returns an iterator positioned before the first key of
// keys. The slice is used as-is, not copied.
func newMemtableIterator(keys []int64) *memtableIterator {
	it := new(memtableIterator)
	it.keys = keys
	it.index = -1
	return it
}

// next advances the iterator and returns the key it now points at. Once the
// keys are exhausted it returns 0 and false.
func (m *memtableIterator) next() (int64, bool) {
	m.index++
	if m.index < len(m.keys) {
		return m.keys[m.index], true
	}
	return 0, false
}

// peek returns the key that next would return, without advancing. It returns
// -1 when there is no further key.
func (m *memtableIterator) peek() int64 {
	if upcoming := m.index + 1; upcoming < len(m.keys) {
		return m.keys[upcoming]
	}
	return -1
}
// sstReader holds the iteration state for one SST reader.
type sstReader struct {
keys []int64 // keys actually present in the file (sorted)
index int // current iteration position
}
// Next advances the cursor to the next row and reports whether there is one.
// It returns false once the cursor is closed or has recorded an error. In
// cached mode rows come from the cache; otherwise they are read lazily from
// the data sources.
func (r *Rows) Next() bool {
	if r.closed || r.err != nil {
		return false
	}
	if r.cached {
		return r.nextFromCache()
	}
	return r.next()
}
// next reads the next matching record from the data sources; it is the core
// of lazy loading. It works like a merge: on every round it looks at the next
// seq of each source (the active MemTable, the currently opened immutable
// MemTable and every SST reader), takes the smallest one and advances only
// that source. Seqs already handled are skipped, so a record present in
// several sources is considered once.
//
// A candidate is dropped when it cannot be loaded, when it fails the query
// conditions, or while the offset has not been consumed yet. next returns
// false when all sources are exhausted or when the limit has been reached;
// otherwise it stores the record in r.currentRow and returns true.
//
// NOTE(review): an immutable MemTable whose key list is empty appears never to
// be advanced past, because its iterator is only released after a key was
// taken from it; later immutable MemTables would then not be visited. Confirm
// whether empty immutable MemTables can occur.
func (r *Rows) next() bool {
for {
// Open the iterator of the current immutable MemTable if none is open yet.
if r.immutableIterator == nil && r.immutableIndex < len(r.table.memtableManager.GetImmutables()) {
immutables := r.table.memtableManager.GetImmutables()
if r.immutableIndex < len(immutables) {
r.immutableIterator = newMemtableIterator(immutables[r.immutableIndex].MemTable.Keys())
}
}
// Look at the next seq of every source (peeking, without advancing) and
// remember the smallest one and where it came from.
minSeq := int64(-1)
minSource := -1 // 0=mem, 1=immutable, 2+=sst
// 1. Active MemTable
if r.memIterator != nil {
if seq := r.memIterator.peek(); seq != -1 {
if minSeq == -1 || seq < minSeq {
minSeq = seq
minSource = 0
}
}
}
// 2. Immutable MemTables
if r.immutableIterator != nil {
if seq := r.immutableIterator.peek(); seq != -1 {
if minSeq == -1 || seq < minSeq {
minSeq = seq
minSource = 1
}
}
}
// 3. All SST files
for i, sstReader := range r.sstReaders {
if sstReader.index < len(sstReader.keys) {
seq := sstReader.keys[sstReader.index]
if minSeq == -1 || seq < minSeq {
minSeq = seq
minSource = 2 + i
}
}
}
// No source has anything left: iteration is finished.
if minSource == -1 {
return false
}
// Advance the source the smallest seq was taken from.
switch minSource {
case 0: // Active MemTable
r.memIterator.next()
if r.memIterator.peek() == -1 {
// Exhausted: drop the iterator.
r.memIterator = nil
}
case 1: // Immutable MemTable
r.immutableIterator.next()
if r.immutableIterator.peek() == -1 {
// Exhausted: move on to the next immutable MemTable on the next round.
r.immutableIterator = nil
r.immutableIndex++
}
default: // SST file
sstIndex := minSource - 2
r.sstReaders[sstIndex].index++
}
// Skip seqs that were already handled (de-duplication across sources).
if r.visited[minSeq] {
continue
}
// Load the record; records that cannot be loaded are skipped.
row, err := r.table.Get(minSeq)
if err != nil {
r.visited[minSeq] = true
continue
}
// Check the query conditions.
if !r.qb.Match(row.Data) {
r.visited[minSeq] = true
continue
}
// Apply offset: skip the first N matching records.
if r.qb.offset > 0 && r.skippedCount < r.qb.offset {
r.skippedCount++
r.visited[minSeq] = true
continue
}
// Apply limit: stop once the maximum number of records was returned.
if r.qb.limit > 0 && r.returnedCount >= r.qb.limit {
return false
}
// Found a matching record.
r.visited[minSeq] = true
r.returnedCount++
r.currentRow = &Row{schema: r.schema, fields: r.fields, inner: row}
return true
}
}
// nextFromCache advances the cached-mode position by one and makes the record
// there the current row. It returns false when the cache is exhausted, in
// which case the current row is left unchanged.
func (r *Rows) nextFromCache() bool {
	r.cachedIndex++
	pos := r.cachedIndex
	if pos >= len(r.cachedRows) {
		return false
	}
	r.currentRow = &Row{schema: r.schema, fields: r.fields, inner: r.cachedRows[pos]}
	return true
}
// Row returns the row the cursor currently points at: the one produced by the
// last successful advance, or nil if there has been none.
func (r *Rows) Row() *Row {
	return r.currentRow
}
// Err returns the error recorded on the cursor, or nil when there is none.
func (r *Rows) Err() error {
	return r.err
}
// Close marks the cursor as closed, after which Next returns false. It always
// returns nil and may be called more than once.
func (r *Rows) Close() error {
	r.closed = true
	return nil
}
// ensureCached switches a lazy cursor to cached mode by reading all remaining
// records into the cache and resetting the cached iteration position. It does
// nothing when the cursor is already cached.
//
// It drains the private next method directly rather than Next, which would
// dispatch back into cached mode. Records already consumed through Next
// before this call are not part of the cache: only the remainder is cached.
func (r *Rows) ensureCached() {
	if r.cached {
		return
	}
	for r.next() {
		cur := r.currentRow
		if cur == nil || cur.inner == nil {
			continue
		}
		r.cachedRows = append(r.cachedRows, cur.inner)
	}
	r.cached, r.cachedIndex = true, -1
}
// Len returns the number of rows in the cache after making sure the cursor is
// cached, which for a lazy cursor means reading all remaining records.
func (r *Rows) Len() int {
	r.ensureCached()
	return len(r.cachedRows)
}
// Collect loads all (remaining) rows and returns their stored data maps, one
// per row. The maps are the records' own Data maps, not copies, and the
// result is nil when there are no rows.
//
// NOTE(review): unlike Row.Data, this neither applies the Select field list
// nor adds "_seq"/"_time" — confirm that this difference is intended.
func (r *Rows) Collect() []map[string]any {
	r.ensureCached()

	var out []map[string]any
	for i := range r.cachedRows {
		out = append(out, r.cachedRows[i].Data)
	}
	return out
}
// Data returns the data of all rows. It is an alias of Collect, kept for
// backward compatibility.
func (r *Rows) Data() []map[string]any {
	return r.Collect()
}
// Scan decodes the result into value, which must be a pointer. The kind of
// the pointed-to value selects the behaviour:
//   - slice: all rows are collected and decoded (through JSON) into it;
//   - anything else (struct, map, ...): only the first row is decoded, using
//     Row.Scan.
//
// It fails when value is not a pointer, when there is no first row to scan,
// or when encoding or decoding fails.
func (r *Rows) Scan(value any) error {
	target := reflect.ValueOf(value)
	if target.Kind() != reflect.Pointer {
		return fmt.Errorf("scan target must be a pointer")
	}

	if target.Elem().Kind() != reflect.Slice {
		// Single-value target: the first row only.
		first, err := r.First()
		if err != nil {
			return err
		}
		return first.Scan(value)
	}

	// Slice target: every row.
	encoded, err := json.Marshal(r.Collect())
	if err != nil {
		return fmt.Errorf("marshal rows data: %w", err)
	}
	if err := json.Unmarshal(encoded, value); err != nil {
		return fmt.Errorf("unmarshal to target: %w", err)
	}
	return nil
}
// First advances the cursor once with Next and returns the row it lands on.
// It returns a "no rows" error when the cursor cannot advance. Because it
// goes through Next, it yields the first row only on a cursor that has not
// been advanced before.
func (r *Rows) First() (*Row, error) {
	if !r.Next() {
		return nil, fmt.Errorf("no rows")
	}
	return r.currentRow, nil
}
// Last loads all (remaining) rows into the cache and returns the final one as
// a new Row. It returns a "no rows" error when the cache is empty.
func (r *Rows) Last() (*Row, error) {
	r.ensureCached()

	n := len(r.cachedRows)
	if n == 0 {
		return nil, fmt.Errorf("no rows")
	}
	last := &Row{
		schema: r.schema,
		fields: r.fields,
		inner:  r.cachedRows[n-1],
	}
	return last, nil
}
// Count returns the number of rows. It is an alias of Len.
func (r *Rows) Count() int {
	return r.Len()
}