Files
srdb/query.go

1317 lines
30 KiB
Go
Raw Permalink Normal View History

package srdb
import (
"encoding/json"
"fmt"
"maps"
"reflect"
"slices"
"sort"
"strings"
)
// Fieldset exposes the fields of a single record to query expressions.
//
// Get returns the field definition and the stored value for key, or a non-nil
// error when the key cannot be resolved.
type Fieldset interface {
Get(key string) (field Field, value any, err error)
}
// mapFieldset implements Fieldset on top of a map[string]any record and the
// Schema used to look up field definitions.
type mapFieldset struct {
data map[string]any
schema *Schema
}
// newMapFieldset returns a mapFieldset that reads values from data and field
// definitions from schema. Neither argument is copied.
func newMapFieldset(data map[string]any, schema *Schema) *mapFieldset {
	fs := new(mapFieldset)
	fs.data = data
	fs.schema = schema
	return fs
}
// Get looks key up in the underlying record.
//
// It returns an error when the record has no such key. When the key is present
// but the schema has no definition for it, a placeholder Field carrying only
// the name is returned together with the value.
func (m *mapFieldset) Get(key string) (Field, any, error) {
	value, ok := m.data[key]
	if !ok {
		return Field{}, nil, fmt.Errorf("field %s not found", key)
	}
	if def, err := m.schema.GetField(key); err == nil {
		return *def, value, nil
	}
	// Not declared in the schema: fall back to a default Field.
	return Field{Name: key}, value, nil
}
// Expr is a predicate that can be evaluated against a record.
//
// Match reports whether the record exposed by fs satisfies the expression.
type Expr interface {
Match(fs Fieldset) bool
}
// Neginative is an Expr that negates the result of the expression it wraps.
type Neginative struct {
expr Expr
}
// Match reports the negation of the wrapped expression. A Neginative without
// a wrapped expression matches every record.
func (n Neginative) Match(fs Fieldset) bool {
	return n.expr == nil || !n.expr.Match(fs)
}
// Not returns an expression that matches exactly when expr does not.
func Not(expr Expr) Expr {
	return Neginative{expr: expr}
}
// compare is an Expr that applies a single operator to one field.
// field names the record field, op is the operator keyword (for example "=",
// "IN" or "IS NULL"), and right is the operand the field value is tested
// against.
type compare struct {
field string
op string
right any
}
// Match evaluates the comparison against the record exposed by fs.
//
// A field that cannot be resolved satisfies only "IS NULL". A nil value
// satisfies "IS NULL" and fails every operator except the two NULL checks.
// Unknown operators, and operands whose type does not fit the operator
// (a non-list for IN/BETWEEN, a non-string for the string operators), yield
// false.
func (c compare) Match(fs Fieldset) bool {
	_, value, err := fs.Get(c.field)
	if err != nil {
		// The field does not exist.
		return c.op == "IS NULL"
	}

	// NULL checks come first; they are the only operators defined on nil.
	switch c.op {
	case "IS NULL":
		return value == nil
	case "IS NOT NULL":
		return value != nil
	}
	if value == nil {
		return false
	}

	atLeast := func(bound any) bool {
		return compareGreater(value, bound) || compareEqual(value, bound)
	}
	atMost := func(bound any) bool {
		return compareLess(value, bound) || compareEqual(value, bound)
	}
	listed := func(items []any) bool {
		for _, item := range items {
			if compareEqual(value, item) {
				return true
			}
		}
		return false
	}
	// onStrings applies test when both sides are strings; negate flips the
	// outcome for the NOT variants. Non-string operands never match.
	onStrings := func(test func(s, pattern string) bool, negate bool) bool {
		s, isStr := value.(string)
		if !isStr {
			return false
		}
		pattern, isStr := c.right.(string)
		if !isStr {
			return false
		}
		return test(s, pattern) != negate
	}

	switch c.op {
	case "=":
		return compareEqual(value, c.right)
	case "!=":
		return !compareEqual(value, c.right)
	case "<":
		return compareLess(value, c.right)
	case ">":
		return compareGreater(value, c.right)
	case "<=":
		return atMost(c.right)
	case ">=":
		return atLeast(c.right)
	case "IN", "NOT IN":
		items, ok := c.right.([]any)
		if !ok {
			return false
		}
		return listed(items) == (c.op == "IN")
	case "BETWEEN", "NOT BETWEEN":
		bounds, ok := c.right.([]any)
		if !ok || len(bounds) != 2 {
			return false
		}
		inRange := atLeast(bounds[0]) && atMost(bounds[1])
		return inRange == (c.op == "BETWEEN")
	case "CONTAINS":
		return onStrings(strings.Contains, false)
	case "NOT CONTAINS":
		return onStrings(strings.Contains, true)
	case "STARTS WITH":
		return onStrings(strings.HasPrefix, false)
	case "NOT STARTS WITH":
		return onStrings(strings.HasPrefix, true)
	case "ENDS WITH":
		return onStrings(strings.HasSuffix, false)
	case "NOT ENDS WITH":
		return onStrings(strings.HasSuffix, true)
	}
	return false
}
// compareEqual reports whether left and right are equal.
//
// When both operands are recognised numeric types they are compared as
// float64, so int(1) equals float64(1). Otherwise the operands must have the
// same dynamic type. Values whose type supports == are compared with it;
// values that do not (slices, maps, functions, or composites containing them)
// are compared with reflect.DeepEqual instead of being passed to ==, which
// would panic at runtime for such operands.
func compareEqual(left, right any) bool {
	// Numeric comparison across differing numeric types.
	leftNum, leftIsNum := toFloat64(left)
	rightNum, rightIsNum := toFloat64(right)
	if leftIsNum && rightIsNum {
		return leftNum == rightNum
	}

	if left == nil || right == nil {
		return left == nil && right == nil
	}
	// Interface equality is false for differing dynamic types; decide that up
	// front so the remaining checks only deal with a single type.
	if reflect.TypeOf(left) != reflect.TypeOf(right) {
		return false
	}
	if reflect.ValueOf(left).Comparable() && reflect.ValueOf(right).Comparable() {
		return left == right
	}
	return reflect.DeepEqual(left, right)
}
// compareLess reports whether left < right.
//
// Two numeric operands are compared as float64 and two strings are compared
// lexically; any other combination is not ordered and yields false.
func compareLess(left, right any) bool {
	if l, ok := toFloat64(left); ok {
		if r, ok := toFloat64(right); ok {
			return l < r
		}
	}
	ls, leftIsStr := left.(string)
	rs, rightIsStr := right.(string)
	return leftIsStr && rightIsStr && ls < rs
}
// compareGreater reports whether left > right.
//
// Two numeric operands are compared as float64 and two strings are compared
// lexically; any other combination is not ordered and yields false.
func compareGreater(left, right any) bool {
	l, leftIsNum := toFloat64(left)
	r, rightIsNum := toFloat64(right)
	if leftIsNum && rightIsNum {
		return l > r
	}

	ls, ok := left.(string)
	if !ok {
		return false
	}
	rs, ok := right.(string)
	if !ok {
		return false
	}
	return ls > rs
}
// toFloat64 converts v to float64 when its dynamic type is one of the
// built-in integer or floating-point types handled below. The boolean result
// is false (and the number 0) for every other type, including nil, strings and
// named numeric types.
func toFloat64(v any) (float64, bool) {
	switch n := v.(type) {
	// Floating point.
	case float32:
		return float64(n), true
	case float64:
		return n, true

	// Signed integers.
	case int8:
		return float64(n), true
	case int16:
		return float64(n), true
	case int32:
		return float64(n), true
	case int64:
		return float64(n), true
	case int:
		return float64(n), true

	// Unsigned integers.
	case uint8:
		return float64(n), true
	case uint16:
		return float64(n), true
	case uint32:
		return float64(n), true
	case uint64:
		return float64(n), true
	case uint:
		return float64(n), true
	}
	return 0, false
}
// Eq returns an expression matching records whose field equals value.
func Eq(field string, value any) Expr {
	return compare{field: field, op: "=", right: value}
}
// NotEq returns an expression matching records whose field is present,
// non-nil and not equal to value.
func NotEq(field string, value any) Expr {
	return compare{field: field, op: "!=", right: value}
}
// Lt returns an expression matching records whose field is less than value.
func Lt(field string, value any) Expr {
	return compare{field: field, op: "<", right: value}
}
// Gt returns an expression matching records whose field is greater than value.
func Gt(field string, value any) Expr {
	return compare{field: field, op: ">", right: value}
}
// Lte returns an expression matching records whose field is less than or
// equal to value.
func Lte(field string, value any) Expr {
	return compare{field: field, op: "<=", right: value}
}
// Gte returns an expression matching records whose field is greater than or
// equal to value.
func Gte(field string, value any) Expr {
	return compare{field: field, op: ">=", right: value}
}
// In returns an expression matching records whose field equals at least one
// element of values.
func In(field string, values []any) Expr {
	return compare{field: field, op: "IN", right: values}
}
// NotIn returns an expression matching records whose field is present,
// non-nil and equal to none of the elements of values.
func NotIn(field string, values []any) Expr {
	return compare{field: field, op: "NOT IN", right: values}
}
// Between returns an expression matching records whose field lies in the
// inclusive range [min, max].
func Between(field string, min, max any) Expr {
	return compare{field: field, op: "BETWEEN", right: []any{min, max}}
}
// NotBetween returns an expression matching records whose field is present,
// non-nil and outside the inclusive range [min, max].
func NotBetween(field string, min, max any) Expr {
	return compare{field: field, op: "NOT BETWEEN", right: []any{min, max}}
}
// Contains returns an expression matching records whose string field
// contains pattern as a substring.
func Contains(field string, pattern string) Expr {
	return compare{field: field, op: "CONTAINS", right: pattern}
}
// NotContains returns an expression matching records whose string field does
// not contain pattern as a substring.
func NotContains(field string, pattern string) Expr {
	return compare{field: field, op: "NOT CONTAINS", right: pattern}
}
// StartsWith returns an expression matching records whose string field begins
// with prefix.
func StartsWith(field string, prefix string) Expr {
	return compare{field: field, op: "STARTS WITH", right: prefix}
}
// NotStartsWith returns an expression matching records whose string field
// does not begin with prefix.
func NotStartsWith(field string, prefix string) Expr {
	return compare{field: field, op: "NOT STARTS WITH", right: prefix}
}
// EndsWith returns an expression matching records whose string field ends
// with suffix.
func EndsWith(field string, suffix string) Expr {
	return compare{field: field, op: "ENDS WITH", right: suffix}
}
// NotEndsWith returns an expression matching records whose string field does
// not end with suffix.
func NotEndsWith(field string, suffix string) Expr {
	return compare{field: field, op: "NOT ENDS WITH", right: suffix}
}
// IsNull returns an expression matching records in which field is missing or
// holds nil.
func IsNull(field string) Expr {
	return compare{field: field, op: "IS NULL", right: nil}
}
// NotNull returns an expression matching records in which field is present
// and holds a non-nil value.
func NotNull(field string) Expr {
	return compare{field: field, op: "IS NOT NULL", right: nil}
}
// group is an Expr that combines several sub-expressions.
// When and is true the sub-expressions are combined with logical AND,
// otherwise with logical OR.
type group struct {
exprs []Expr
and bool
}
// Match evaluates the group against the record exposed by fs.
//
// An AND group matches when every sub-expression matches; an OR group matches
// when at least one does. Evaluation stops at the first sub-expression that
// decides the outcome. A group with no sub-expressions matches every record,
// for both AND and OR.
func (g group) Match(fs Fieldset) bool {
	for _, expr := range g.exprs {
		matched := expr.Match(fs)
		// AND is decided by the first miss, OR by the first hit; in both
		// cases the deciding result is the group's result.
		if matched != g.and {
			return matched
		}
	}
	// Nothing short-circuited: AND saw only hits, OR saw only misses. A
	// non-empty OR must therefore fail; an empty group stays match-all.
	return g.and || len(g.exprs) == 0
}
// And returns an expression that combines exprs with logical AND.
func And(exprs ...Expr) Expr {
	return group{exprs: exprs, and: true}
}
// Or returns an expression that combines exprs with logical OR.
func Or(exprs ...Expr) Expr {
	return group{exprs: exprs, and: false}
}
// QueryBuilder accumulates filter conditions, field selection, ordering and
// pagination settings for a query against a single table.
type QueryBuilder struct {
conds []Expr
fields []string // fields to select; nil selects all fields
table *Table
orderBy string // sort field; only "_seq" or an indexed field is accepted
orderDesc bool // whether to sort in descending order
offset int // number of records to skip
limit int // maximum number of records to return; 0 means no limit
}
// newQueryBuilder returns an empty QueryBuilder bound to table: no
// conditions, no field selection, no ordering and no pagination.
func newQueryBuilder(table *Table) *QueryBuilder {
	qb := new(QueryBuilder)
	qb.table = table
	return qb
}
// where appends expr to the builder's condition list and returns the builder
// so calls can be chained. All appended conditions must hold for a record to
// match.
func (qb *QueryBuilder) where(expr Expr) *QueryBuilder {
	conds := append(qb.conds, expr)
	qb.conds = conds
	return qb
}
// Match reports whether data satisfies every condition added to the builder.
// A builder without conditions matches everything; in that case the table and
// its schema are not touched.
func (qb *QueryBuilder) Match(data map[string]any) bool {
	if len(qb.conds) == 0 {
		return true
	}

	fs := newMapFieldset(data, qb.table.schema)
	ok := true
	for i := 0; ok && i < len(qb.conds); i++ {
		// The loop condition stops evaluation at the first failing condition.
		ok = qb.conds[i].Match(fs)
	}
	return ok
}
// Select restricts the fields returned for each row to the given names,
// replacing any earlier selection. Without a call to Select every field is
// returned.
func (qb *QueryBuilder) Select(fields ...string) *QueryBuilder {
	selected := fields
	qb.fields = selected
	return qb
}
// Where adds the conjunction (logical AND) of exprs as a single condition.
func (qb *QueryBuilder) Where(exprs ...Expr) *QueryBuilder {
	all := And(exprs...)
	return qb.where(all)
}
// Eq adds the condition field = value.
func (qb *QueryBuilder) Eq(field string, value any) *QueryBuilder {
	cond := Eq(field, value)
	return qb.where(cond)
}
// NotEq adds the condition field != value.
func (qb *QueryBuilder) NotEq(field string, value any) *QueryBuilder {
	cond := NotEq(field, value)
	return qb.where(cond)
}
// Lt adds the condition field < value.
func (qb *QueryBuilder) Lt(field string, value any) *QueryBuilder {
	cond := Lt(field, value)
	return qb.where(cond)
}
// Gt adds the condition field > value.
func (qb *QueryBuilder) Gt(field string, value any) *QueryBuilder {
	cond := Gt(field, value)
	return qb.where(cond)
}
// Lte adds the condition field <= value.
func (qb *QueryBuilder) Lte(field string, value any) *QueryBuilder {
	cond := Lte(field, value)
	return qb.where(cond)
}
// Gte adds the condition field >= value.
func (qb *QueryBuilder) Gte(field string, value any) *QueryBuilder {
	cond := Gte(field, value)
	return qb.where(cond)
}
// In adds the condition that field equals one of values.
func (qb *QueryBuilder) In(field string, values []any) *QueryBuilder {
	cond := In(field, values)
	return qb.where(cond)
}
// NotIn adds the condition that field equals none of values.
func (qb *QueryBuilder) NotIn(field string, values []any) *QueryBuilder {
	cond := NotIn(field, values)
	return qb.where(cond)
}
// Between adds the condition that field lies in the inclusive range
// [start, end].
func (qb *QueryBuilder) Between(field string, start, end any) *QueryBuilder {
	cond := Between(field, start, end)
	return qb.where(cond)
}
// NotBetween adds the condition that field lies outside the inclusive range
// [start, end].
//
// It delegates to the package-level NotBetween so the builder behaves like
// that expression and like the other NotXxx shortcuts: a record whose field is
// missing or nil does not match. Wrapping Between in Not would instead match
// such records.
func (qb *QueryBuilder) NotBetween(field string, start, end any) *QueryBuilder {
	return qb.where(NotBetween(field, start, end))
}
// Contains adds the condition that the string field contains pattern.
func (qb *QueryBuilder) Contains(field string, pattern string) *QueryBuilder {
	cond := Contains(field, pattern)
	return qb.where(cond)
}
// NotContains adds the condition that the string field does not contain
// pattern.
func (qb *QueryBuilder) NotContains(field string, pattern string) *QueryBuilder {
	cond := NotContains(field, pattern)
	return qb.where(cond)
}
// StartsWith adds the condition that the string field begins with pattern.
func (qb *QueryBuilder) StartsWith(field string, pattern string) *QueryBuilder {
	cond := StartsWith(field, pattern)
	return qb.where(cond)
}
// NotStartsWith adds the condition that the string field does not begin with
// pattern.
func (qb *QueryBuilder) NotStartsWith(field string, pattern string) *QueryBuilder {
	cond := NotStartsWith(field, pattern)
	return qb.where(cond)
}
// EndsWith adds the condition that the string field ends with pattern.
func (qb *QueryBuilder) EndsWith(field string, pattern string) *QueryBuilder {
	cond := EndsWith(field, pattern)
	return qb.where(cond)
}
// NotEndsWith adds the condition that the string field does not end with
// pattern.
func (qb *QueryBuilder) NotEndsWith(field string, pattern string) *QueryBuilder {
	cond := NotEndsWith(field, pattern)
	return qb.where(cond)
}
// IsNull adds the condition that field is missing or nil.
func (qb *QueryBuilder) IsNull(field string) *QueryBuilder {
	cond := IsNull(field)
	return qb.where(cond)
}
// NotNull adds the condition that field is present and non-nil.
func (qb *QueryBuilder) NotNull(field string) *QueryBuilder {
	cond := NotNull(field)
	return qb.where(cond)
}
// OrderBy sorts the result by field in ascending order.
// Only "_seq" or a field that has an index is supported; any other field makes
// Rows return an error.
func (qb *QueryBuilder) OrderBy(field string) *QueryBuilder {
	qb.orderBy, qb.orderDesc = field, false
	return qb
}
// OrderByDesc sorts the result by field in descending order.
// Only "_seq" or a field that has an index is supported; any other field makes
// Rows return an error.
func (qb *QueryBuilder) OrderByDesc(field string) *QueryBuilder {
	qb.orderBy, qb.orderDesc = field, true
	return qb
}
// Offset sets how many matching records are skipped before the first one is
// returned; it is meant for paging. Negative values are treated as 0.
func (qb *QueryBuilder) Offset(n int) *QueryBuilder {
	if n > 0 {
		qb.offset = n
	} else {
		qb.offset = 0
	}
	return qb
}
// Limit sets the maximum number of records returned; it is meant for paging.
// n = 0 means no limit, and negative values are treated as 0.
func (qb *QueryBuilder) Limit(n int) *QueryBuilder {
	if n > 0 {
		qb.limit = n
	} else {
		qb.limit = 0
	}
	return qb
}
// Paginate runs a paged query.
//
// page is the 1-based page number (values below 1 are treated as 1) and
// pageSize is the number of records per page (negative values are treated as
// 0, which means no limit).
//
// It returns:
//   - rows: the records of the requested page
//   - total: the number of records matching the conditions, for computing the
//     page count
//   - err: any error from building either result set
//
// Two queries are executed: one without paging to obtain the total, then the
// paged one. The builder's own offset and limit are overwritten with the
// values derived from page and pageSize.
func (qb *QueryBuilder) Paginate(page, pageSize int) (rows *Rows, total int, err error) {
	if page < 1 {
		page = 1
	}
	if pageSize < 0 {
		pageSize = 0
	}

	// 1. Count on a copy of the builder with ordering and paging cleared, so
	// the total covers every matching record.
	counter := *qb
	counter.orderBy, counter.orderDesc = "", false
	counter.offset, counter.limit = 0, 0

	countRows, err := counter.Rows()
	if err != nil {
		return nil, 0, err
	}
	defer countRows.Close()
	total = countRows.Len()

	// 2. Run the paged query on the builder itself.
	qb.offset, qb.limit = (page-1)*pageSize, pageSize
	if rows, err = qb.Rows(); err != nil {
		return nil, total, err
	}
	return rows, total, nil
}
// validateOrderBy checks that the configured sort field can be used.
// No sort field and "_seq" are always accepted; any other field is accepted
// only when the table's index manager has an index for it.
func (qb *QueryBuilder) validateOrderBy() error {
	switch qb.orderBy {
	case "", "_seq":
		return nil
	}
	if _, indexed := qb.table.indexManager.GetIndex(qb.orderBy); !indexed {
		return fmt.Errorf("OrderBy only supports '_seq' or indexed fields, field '%s' is not indexed", qb.orderBy)
	}
	return nil
}
// Rows builds the result set for the query.
//
// With a sort field set, or when an equality condition can be served by a
// ready index, the result is materialised immediately (cached mode).
// Otherwise only the iterators over the active memtable and the SST files are
// prepared, and records are read lazily as the cursor advances.
//
// It returns an error when the builder has no table or the sort field is not
// supported.
func (qb *QueryBuilder) Rows() (*Rows, error) {
	if qb.table == nil {
		return nil, fmt.Errorf("table is nil")
	}
	if err := qb.validateOrderBy(); err != nil {
		return nil, err
	}

	rows := &Rows{
		schema:  qb.table.schema,
		fields:  qb.fields,
		qb:      qb,
		table:   qb.table,
		visited: make(map[int64]bool),
	}

	// Sorted queries are resolved eagerly.
	if qb.orderBy != "" {
		return qb.rowsWithOrder(rows)
	}

	// An equality condition on an indexed field is also resolved eagerly,
	// because the seq list has to be fetched from the index first.
	if field, value := qb.findIndexableCondition(); field != "" {
		return qb.rowsWithIndex(rows, field, value)
	}

	// Lazy mode: set up the iterators without reading any record.
	// 1. Active memtable.
	if active := qb.table.memtableManager.GetActive(); active != nil {
		rows.memIterator = newMemtableIterator(active.Keys())
	}

	// 2. Immutable memtables are opened one at a time during iteration;
	// immutableIndex and immutableIterator keep their zero values here.

	// 3. One cursor per SST file, each starting at its first key.
	readers := qb.table.sstManager.GetReaders()
	rows.sstReaders = make([]*sstReader, 0, len(readers))
	for _, reader := range readers {
		rows.sstReaders = append(rows.sstReaders, &sstReader{keys: reader.GetAllKeys()})
	}

	// sstIndex stays 0 and cached stays false, so Next reads lazily.
	return rows, nil
}
// findIndexableCondition returns the field and operand of the first top-level
// "=" condition whose field has a ready index. It returns an empty field name
// when there is none. Conditions nested inside groups are not inspected.
func (qb *QueryBuilder) findIndexableCondition() (string, any) {
	for _, cond := range qb.conds {
		cmp, isCompare := cond.(compare)
		if !isCompare || cmp.op != "=" {
			continue
		}
		idx, found := qb.table.indexManager.GetIndex(cmp.field)
		if found && idx.IsReady() {
			return cmp.field, cmp.right
		}
	}
	return "", nil
}
// rowsWithIndex resolves the query through the index on indexField.
//
// The seqs stored under indexValue are fetched from the index, each record is
// loaded and re-checked against all conditions (the index covers only one of
// them), and offset/limit are applied. Records that cannot be loaded are
// skipped. The result is stored on rows in cached mode.
func (qb *QueryBuilder) rowsWithIndex(rows *Rows, indexField string, indexValue any) (*Rows, error) {
	idx, found := qb.table.indexManager.GetIndex(indexField)
	if !found {
		return nil, fmt.Errorf("index on field %s not found", indexField)
	}

	seqs, err := idx.Get(indexValue)
	if err != nil {
		return nil, fmt.Errorf("index lookup failed: %w", err)
	}

	// With no seqs this stays an empty, non-nil slice, which offset/limit
	// leave empty.
	matched := make([]*SSTableRow, 0, len(seqs))
	for _, seq := range seqs {
		row, err := qb.table.Get(seq)
		if err != nil {
			continue // skip records that fail to load
		}
		if qb.Match(row.Data) {
			matched = append(matched, row)
		}
	}

	rows.cachedRows = qb.applyOffsetLimit(matched)
	rows.cached = true
	rows.cachedIndex = -1
	return rows, nil
}
// rowsWithOrder resolves a sorted query, dispatching on the sort field:
// "_seq" is ordered by sequence number, anything else through the field's
// index.
func (qb *QueryBuilder) rowsWithOrder(rows *Rows) (*Rows, error) {
	switch qb.orderBy {
	case "_seq":
		return qb.rowsOrderBySeq(rows)
	default:
		return qb.rowsOrderByIndex(rows, qb.orderBy)
	}
}
// rowsOrderBySeq resolves the query ordered by _seq.
//
// Seqs are gathered from the active memtable, the immutable memtables and the
// SST files, de-duplicated, and sorted ascending or descending as configured.
// Each record is then loaded, filtered by the conditions, and offset/limit are
// applied. Records that cannot be loaded are skipped. The result is stored on
// rows in cached mode.
func (qb *QueryBuilder) rowsOrderBySeq(rows *Rows) (*Rows, error) {
	var seqs []int64

	// 1. Active memtable.
	if active := qb.table.memtableManager.GetActive(); active != nil {
		seqs = append(seqs, active.Keys()...)
	}
	// 2. Immutable memtables.
	for _, imm := range qb.table.memtableManager.GetImmutables() {
		seqs = append(seqs, imm.MemTable.Keys()...)
	}
	// 3. SST files.
	for _, reader := range qb.table.sstManager.GetReaders() {
		seqs = append(seqs, reader.GetAllKeys()...)
	}

	// Sorting first makes duplicates adjacent, so Compact removes them all.
	slices.Sort(seqs)
	seqs = slices.Compact(seqs)
	if qb.orderDesc {
		slices.Reverse(seqs)
	}

	matched := make([]*SSTableRow, 0, len(seqs))
	for _, seq := range seqs {
		row, err := qb.table.Get(seq)
		if err != nil {
			continue // skip records that fail to load
		}
		if qb.Match(row.Data) {
			matched = append(matched, row)
		}
	}

	rows.cachedRows = qb.applyOffsetLimit(matched)
	rows.cached = true
	rows.cachedIndex = -1
	return rows, nil
}
// rowsOrderByIndex resolves the query ordered by an indexed field.
//
// Strategy:
//  1. Collect every (value, seqs) pair from the index with ForEach.
//  2. Sort the pairs by the value string (not by hash), ascending or
//     descending as configured.
//  3. Load the records in that order, filter them by the conditions and apply
//     offset/limit.
//
// Although the index is used, all of its entries are sorted in memory, so a
// field with many distinct values can be expensive. Records that cannot be
// loaded are skipped. The result is stored on rows in cached mode.
//
// NOTE(review): values are ordered as strings; whether that matches numeric
// order depends on how the index encodes values - confirm.
func (qb *QueryBuilder) rowsOrderByIndex(rows *Rows, indexField string) (*Rows, error) {
	idx, found := qb.table.indexManager.GetIndex(indexField)
	if !found {
		return nil, fmt.Errorf("index on field %s not found", indexField)
	}
	if !idx.IsReady() {
		return nil, fmt.Errorf("index on field %s is not ready", indexField)
	}

	// bucket holds one index entry: a field value and the seqs stored under it.
	type bucket struct {
		value string
		seqs  []int64
	}

	var buckets []bucket
	err := idx.ForEach(func(value string, seqs []int64) bool {
		// Copy seqs so the entry does not alias the slice owned by the index.
		owned := make([]int64, len(seqs))
		copy(owned, seqs)
		buckets = append(buckets, bucket{value: value, seqs: owned})
		return true
	})
	if err != nil {
		return nil, fmt.Errorf("failed to iterate index: %w", err)
	}

	sort.Slice(buckets, func(i, j int) bool {
		if qb.orderDesc {
			return buckets[i].value > buckets[j].value
		}
		return buckets[i].value < buckets[j].value
	})

	total := 0
	for _, b := range buckets {
		total += len(b.seqs)
	}

	matched := make([]*SSTableRow, 0, total)
	for _, b := range buckets {
		for _, seq := range b.seqs {
			row, err := qb.table.Get(seq)
			if err != nil {
				continue // skip records that fail to load
			}
			if qb.Match(row.Data) {
				matched = append(matched, row)
			}
		}
	}

	rows.cachedRows = qb.applyOffsetLimit(matched)
	rows.cached = true
	rows.cachedIndex = -1
	return rows, nil
}
// applyOffsetLimit returns the window of rows selected by the builder's offset
// and limit. An offset at or beyond the end yields an empty slice; a limit of
// 0 keeps everything after the offset. The result shares storage with rows.
func (qb *QueryBuilder) applyOffsetLimit(rows []*SSTableRow) []*SSTableRow {
	skip, take := qb.offset, qb.limit

	if skip > 0 {
		if skip >= len(rows) {
			return []*SSTableRow{}
		}
		rows = rows[skip:]
	}
	if take > 0 && take < len(rows) {
		rows = rows[:take]
	}
	return rows
}
// First runs the query and returns its first matching row, or an error when
// the query fails or nothing matches.
func (qb *QueryBuilder) First() (*Row, error) {
	result, err := qb.Rows()
	if err != nil {
		return nil, err
	}
	defer result.Close()

	row, err := result.First()
	return row, err
}
// Last runs the query and returns its last matching row, or an error when the
// query fails or nothing matches.
func (qb *QueryBuilder) Last() (*Row, error) {
	result, err := qb.Rows()
	if err != nil {
		return nil, err
	}
	defer result.Close()

	row, err := result.Last()
	return row, err
}
// Scan runs the query and scans the result into value; see Rows.Scan for how
// the target type is interpreted.
func (qb *QueryBuilder) Scan(value any) error {
	result, err := qb.Rows()
	if err != nil {
		return err
	}
	defer result.Close()

	err = result.Scan(value)
	return err
}
// Row is a single result record together with the field selection that
// applies when its data is read.
type Row struct {
schema *Schema
fields []string // fields to select; nil selects all fields
inner *SSTableRow
}
// Data returns the row as a map, honouring the field selection made with
// Select.
//
// Without a selection the map holds "_seq" and "_time" plus every stored
// field (a stored field of the same name overrides the built-in entry). With
// a selection only the named fields are included; names that are not present
// in the record are omitted. A Row without an underlying record yields nil.
func (r *Row) Data() map[string]any {
	if r.inner == nil {
		return nil
	}

	if len(r.fields) == 0 {
		all := map[string]any{
			"_seq":  r.inner.Seq,
			"_time": r.inner.Time,
		}
		maps.Copy(all, r.inner.Data)
		return all
	}

	picked := make(map[string]any)
	for _, name := range r.fields {
		switch name {
		case "_seq":
			picked["_seq"] = r.inner.Seq
		case "_time":
			picked["_time"] = r.inner.Time
		default:
			if v, present := r.inner.Data[name]; present {
				picked[name] = v
			}
		}
	}
	return picked
}
// Seq returns the row's sequence number, or 0 when the Row has no underlying
// record.
func (r *Row) Seq() int64 {
	if rec := r.inner; rec != nil {
		return rec.Seq
	}
	return 0
}
// Scan copies the row into value by encoding the row data as JSON and
// decoding it into value, so value must be something json.Unmarshal accepts.
// The data comes from Data, which means the field selection is applied.
func (r *Row) Scan(value any) error {
	if r.inner == nil {
		return fmt.Errorf("row is nil")
	}

	encoded, err := json.Marshal(r.Data())
	if err != nil {
		return fmt.Errorf("marshal row data: %w", err)
	}
	if err := json.Unmarshal(encoded, value); err != nil {
		return fmt.Errorf("unmarshal to target: %w", err)
	}
	return nil
}
// Rows is a cursor over a query result. It either reads records lazily from
// the table's data sources or serves them from an in-memory cache.
type Rows struct {
schema *Schema
fields []string // fields to select; nil selects all fields
qb *QueryBuilder
table *Table
// Iteration state.
currentRow *Row
err error
closed bool
visited map[int64]bool // seqs already visited, used for de-duplication
// Data source iterators.
memIterator *memtableIterator
immutableIndex int
immutableIterator *memtableIterator
sstIndex int
sstReaders []*sstReader
// Cached mode (used by Collect/Data and similar methods).
cached bool
cachedRows []*SSTableRow
cachedIndex int // iteration position in cached mode
// Pagination state (lazy mode).
skippedCount int // records skipped so far (for offset)
returnedCount int // records returned so far (for limit)
}
// memtableIterator iterates over the key list of a memtable.
// index is the position of the most recently returned key; it starts at -1.
type memtableIterator struct {
keys []int64
index int
}
// newMemtableIterator returns an iterator positioned before the first element
// of keys. The slice is not copied.
func newMemtableIterator(keys []int64) *memtableIterator {
	it := new(memtableIterator)
	it.keys = keys
	it.index = -1
	return it
}
// next advances the iterator and returns the key at the new position. The
// boolean result is false once the keys are exhausted; the position still
// advances on such calls.
func (m *memtableIterator) next() (int64, bool) {
	m.index++
	if m.index < len(m.keys) {
		return m.keys[m.index], true
	}
	return 0, false
}
// peek returns the key that the following call to next would return, without
// advancing the iterator. It returns -1 when no key is left.
func (m *memtableIterator) peek() int64 {
	if upcoming := m.index + 1; upcoming < len(m.keys) {
		return m.keys[upcoming]
	}
	return -1
}
// sstReader holds the iteration state for one SST reader.
type sstReader struct {
keys []int64 // keys that actually exist in the file (sorted)
index int // current iteration position
}
// Next advances the cursor to the next row and reports whether there is one.
// It returns false once the cursor is closed or has recorded an error. In
// cached mode rows come from the cache; otherwise they are read lazily from
// the data sources.
func (r *Rows) Next() bool {
	switch {
	case r.closed, r.err != nil:
		return false
	case r.cached:
		return r.nextFromCache()
	default:
		return r.next()
	}
}
// next reads the next matching record from the data sources; it is the core
// of lazy loading.
//
// It merges the sources by repeatedly picking the smallest pending seq among
// the active memtable iterator, the current immutable memtable iterator and
// every SST reader. The chosen source is advanced, then the record is loaded,
// filtered by the query conditions and subjected to offset/limit. Seqs that
// were already visited, records that fail to load and records that do not
// match are skipped. On success currentRow is set and true is returned; false
// means the sources are exhausted or the limit has been reached.
//
// -1 is used as the "no value" sentinel for seqs throughout.
func (r *Rows) next() bool {
for {
// Open an iterator for the current immutable memtable if none is open.
if r.immutableIterator == nil && r.immutableIndex < len(r.table.memtableManager.GetImmutables()) {
immutables := r.table.memtableManager.GetImmutables()
if r.immutableIndex < len(immutables) {
r.immutableIterator = newMemtableIterator(immutables[r.immutableIndex].MemTable.Keys())
}
}
// Look at the next seq of every source (peek only, nothing is advanced).
minSeq := int64(-1)
minSource := -1 // 0=mem, 1=immutable, 2+=sst
// 1. Active memtable.
if r.memIterator != nil {
if seq := r.memIterator.peek(); seq != -1 {
if minSeq == -1 || seq < minSeq {
minSeq = seq
minSource = 0
}
}
}
// 2. Current immutable memtable.
if r.immutableIterator != nil {
if seq := r.immutableIterator.peek(); seq != -1 {
if minSeq == -1 || seq < minSeq {
minSeq = seq
minSource = 1
}
}
}
// 3. Every SST file.
for i, sstReader := range r.sstReaders {
if sstReader.index < len(sstReader.keys) {
seq := sstReader.keys[sstReader.index]
if minSeq == -1 || seq < minSeq {
minSeq = seq
minSource = 2 + i
}
}
}
// No source has a pending seq: iteration is finished.
if minSource == -1 {
return false
}
// Advance the source that supplied the smallest seq.
switch minSource {
case 0: // Active MemTable
r.memIterator.next()
if r.memIterator.peek() == -1 {
r.memIterator = nil
}
case 1: // Immutable MemTable
r.immutableIterator.next()
if r.immutableIterator.peek() == -1 {
// This immutable memtable is exhausted; the next loop iteration opens the
// following one.
r.immutableIterator = nil
r.immutableIndex++
}
default: // SST file
sstIndex := minSource - 2
r.sstReaders[sstIndex].index++
}
// Skip seqs that were already visited (de-duplication across sources).
if r.visited[minSeq] {
continue
}
// Load the record; a seq that cannot be loaded is marked visited and skipped.
row, err := r.table.Get(minSeq)
if err != nil {
r.visited[minSeq] = true
continue
}
// Apply the query conditions.
if !r.qb.Match(row.Data) {
r.visited[minSeq] = true
continue
}
// Apply offset: skip the first N matching records.
if r.qb.offset > 0 && r.skippedCount < r.qb.offset {
r.skippedCount++
r.visited[minSeq] = true
continue
}
// Apply limit: stop once the maximum number of records has been returned.
if r.qb.limit > 0 && r.returnedCount >= r.qb.limit {
return false
}
// A matching record was found.
r.visited[minSeq] = true
r.returnedCount++
r.currentRow = &Row{schema: r.schema, fields: r.fields, inner: row}
return true
}
}
// nextFromCache advances the cache position and makes the record there the
// current row. It returns false when the cache is exhausted; the position
// still advances on such calls.
func (r *Rows) nextFromCache() bool {
	r.cachedIndex++
	if r.cachedIndex < len(r.cachedRows) {
		r.currentRow = &Row{
			schema: r.schema,
			fields: r.fields,
			inner:  r.cachedRows[r.cachedIndex],
		}
		return true
	}
	return false
}
// Row returns the row the cursor currently points at; it is nil before the
// first successful advance.
func (r *Rows) Row() *Row {
	current := r.currentRow
	return current
}
// Err returns the error recorded on the cursor, if any.
func (r *Rows) Err() error {
	recorded := r.err
	return recorded
}
// Close marks the cursor as closed, after which Next reports false. It always
// returns nil.
func (r *Rows) Close() (err error) {
	r.closed = true
	return err
}
// ensureCached drains the lazy iterator into the cache and switches the
// cursor to cached mode; it does nothing when the cursor is already cached.
//
// It calls next directly rather than Next, which avoids a call cycle between
// the two. Rows already consumed through Next before this call are not part
// of the cache; only the remaining ones are.
func (r *Rows) ensureCached() {
	if r.cached {
		return
	}

	for r.next() {
		if r.currentRow == nil || r.currentRow.inner == nil {
			continue
		}
		r.cachedRows = append(r.cachedRows, r.currentRow.inner)
	}

	// Switch to cached mode and rewind the cache position.
	r.cached, r.cachedIndex = true, -1
}
// Len returns the number of rows in the result. This loads the complete
// result into the cache.
func (r *Rows) Len() int {
	r.ensureCached()
	count := len(r.cachedRows)
	return count
}
// Collect loads the complete result and returns the stored data map of every
// row. The result is nil when there are no rows.
//
// NOTE(review): the maps are the records' raw Data; unlike Row.Data, the field
// selection and the "_seq"/"_time" entries are not applied here - confirm this
// is intended.
func (r *Rows) Collect() []map[string]any {
	r.ensureCached()
	if len(r.cachedRows) == 0 {
		return nil
	}

	out := make([]map[string]any, len(r.cachedRows))
	for i, row := range r.cachedRows {
		out[i] = row.Data
	}
	return out
}
// Data returns the data of all rows. It is kept for backward compatibility
// and is equivalent to Collect.
func (r *Rows) Data() []map[string]any {
	collected := r.Collect()
	return collected
}
// Scan copies the result into value, which must be a pointer. The pointed-to
// type decides what is scanned:
//   - a slice receives all rows (via Collect and a JSON round trip);
//   - anything else receives the first row only (via Row.Scan).
//
// It returns an error when value is not a pointer, when there is no row for a
// single-row scan, or when the JSON conversion fails.
func (r *Rows) Scan(value any) error {
	target := reflect.ValueOf(value)
	if target.Kind() != reflect.Pointer {
		return fmt.Errorf("scan target must be a pointer")
	}

	// Single-row target: only the first row is scanned.
	if target.Elem().Kind() != reflect.Slice {
		row, err := r.First()
		if err != nil {
			return err
		}
		return row.Scan(value)
	}

	// Slice target: scan every row.
	encoded, err := json.Marshal(r.Collect())
	if err != nil {
		return fmt.Errorf("marshal rows data: %w", err)
	}
	if err := json.Unmarshal(encoded, value); err != nil {
		return fmt.Errorf("unmarshal to target: %w", err)
	}
	return nil
}
// First advances the cursor once and returns the row it lands on. Called on a
// fresh cursor this is the first row of the result. It returns an error when
// the cursor cannot advance.
func (r *Rows) First() (*Row, error) {
	if !r.Next() {
		return nil, fmt.Errorf("no rows")
	}
	return r.currentRow, nil
}
// Last loads the complete result and returns its final row, or an error when
// the result is empty.
func (r *Rows) Last() (*Row, error) {
	r.ensureCached()

	n := len(r.cachedRows)
	if n == 0 {
		return nil, fmt.Errorf("no rows")
	}

	last := &Row{
		schema: r.schema,
		fields: r.fields,
		inner:  r.cachedRows[n-1],
	}
	return last, nil
}
// Count returns the number of rows in the result; it is an alias for Len.
func (r *Rows) Count() int {
	total := r.Len()
	return total
}