diff --git a/CLAUDE.md b/CLAUDE.md index 756d18b..56b2127 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,12 +1,12 @@ # CLAUDE.md -本文件为 Claude Code (claude.ai/code) 提供在本仓库中工作的指导。 +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. ## 项目概述 SRDB 是一个用 Go 编写的高性能 Append-Only 时序数据库引擎。它使用简化的 LSM-tree 架构,结合 WAL + MemTable + mmap B+Tree SST 文件,针对高并发写入(200K+ 写/秒)和快速查询(1-5ms)进行了优化。 -**模块**: `code.tczkiot.com/srdb` +**模块**: `code.tczkiot.com/wlw/srdb` ## 构建和测试 @@ -14,22 +14,59 @@ SRDB 是一个用 Go 编写的高性能 Append-Only 时序数据库引擎。它 # 运行所有测试 go test -v ./... -# 运行指定包的测试 -go test -v ./engine -go test -v ./compaction -go test -v ./query +# 运行单个测试 +go test -v -run TestSSTable +go test -v -run TestTable -# 运行指定的测试 -go test -v ./engine -run TestEngineBasic +# 运行性能测试 +go test -bench=. -benchmem -# 构建示例程序 -go build ./examples/basic -go build ./examples/with_schema +# 运行带超时的测试(某些 compaction 测试需要较长时间) +go test -v -timeout 30s + +# 构建 WebUI 工具 +cd examples/webui +go build -o webui main.go +./webui serve --db ./data ``` ## 架构 -### 两层存储模型 +### 文件结构(扁平化设计) + +所有核心代码都在根目录下,采用扁平化结构: + +``` +srdb/ +├── database.go # 多表数据库管理 +├── table.go # 表管理(带 Schema) +├── errors.go # 错误定义和处理(统一错误码系统) +├── wal.go # WAL 实现(Write-Ahead Log) +├── memtable.go # MemTable(map + sorted slice,~130 行) +├── sstable.go # SSTable 文件(读写器、管理器、二进制编码) +├── btree.go # B+Tree 索引(构建器、读取器,4KB 节点) +├── version.go # 版本控制(MANIFEST 管理) +├── compaction.go # Compaction 压缩合并 +├── schema.go # Schema 定义与验证 +├── index.go # 二级索引管理器 +├── index_btree.go # 索引 B+Tree 实现 +└── query.go # 查询构建器和表达式求值 +``` + +**运行时数据目录**: +``` +database_dir/ +├── database.meta # 数据库元数据(JSON) +├── MANIFEST # 全局版本控制 +└── table_name/ # 每表一个目录 + ├── schema.json # 表 Schema 定义 + ├── MANIFEST # 表级版本控制 + ├── 000001.wal # WAL 文件 + ├── 000001.sst # SST 文件(B+Tree 索引 + 二进制数据) + └── idx_field.sst # 二级索引文件(可选) +``` + +### 核心架构:简化的两层模型 与传统的多层 LSM 树不同,SRDB 使用简化的两层架构: @@ -39,12 +76,12 @@ go build ./examples/with_schema ### 核心数据流 **写入路径**: -1. Schema 验证(如果定义了) -2. 生成序列号 (`_seq`) +1. Schema 验证(强制要求,如果表有 Schema) +2. 生成序列号 (`_seq`,原子递增的 int64) 3. 追加写入 WAL(顺序写) 4. 插入到 Active MemTable(map + 有序 slice) -5. 当 MemTable 超过阈值(默认 64MB)时,切换到新的 Active MemTable 并异步将 Immutable 刷新到 SST -6. 更新二级索引(如果已创建) +5. 当 MemTable 超过阈值(默认 64MB)时,切换到新的 Active MemTable 并异步刷新 Immutable 到 SST +6. 更新二级索引(如果字段标记为 Indexed) **读取路径**: 1. 检查 Active MemTable(O(1) map 查找) @@ -56,11 +93,11 @@ go build ./examples/with_schema 1. 如果是带 `=` 操作符的索引字段:使用二级索引 → 通过 seq 获取 2. 否则:带过滤条件的全表扫描(MemTable + SST) -### 关键设计选择 +### 关键设计决策 **MemTable: `map[int64][]byte + sorted []int64`** -- 为什么不用 SkipList?实现更简单(130 行),Put 和 Get 都是 O(1) vs O(log N) -- 权衡:插入时需要重新排序 keys slice(但实际上仍然更快) +- 为什么不用 SkipList?实现更简单(~130 行),Put 和 Get 都是 O(1) vs O(log N) +- 权衡:插入新 key 时需要重新排序 keys slice(但实际上仍然更快) - Active MemTable + 多个 Immutable MemTables(正在刷新中) **SST 格式: 4KB 节点的 B+Tree** @@ -68,7 +105,13 @@ go build ./examples/with_schema - 支持高效的 mmap 访问和零拷贝读取 - 内部节点:keys + 子节点指针 - 叶子节点:keys + 数据偏移量/大小 -- 数据块:Snappy 压缩的 JSON 行 +- 数据块:二进制编码(使用 Schema 时)或 JSON(无 Schema 时) + +**二进制编码格式**: +- Magic Number: `0x524F5731` ("ROW1") +- 格式:`[Magic: 4B][Seq: 8B][Time: 8B][FieldCount: 2B][FieldOffsetTable][FieldData]` +- 按字段分别编码,支持部分字段读取(`GetPartial`) +- 无压缩(优先查询性能,保持 mmap 零拷贝) **mmap 而非 read() 系统调用** - 对 SST 文件的零拷贝访问 @@ -80,125 +123,70 @@ go build ./examples/with_schema - 相同 seq 的新记录覆盖旧记录 - Compaction 合并文件并按 seq 去重(保留最新的,按时间戳) -## 目录结构 +## 常见开发模式 -``` -srdb/ -├── database.go # 多表数据库管理 -├── table.go # 带 schema 的表 -├── engine/ # 核心存储引擎(583 行) -│ └── engine.go -├── wal/ # 预写日志 -│ ├── wal.go # WAL 实现(208 行) -│ └── manager.go # 多 WAL 管理 -├── memtable/ # 内存表 -│ ├── memtable.go # MemTable(130 行) -│ └── manager.go # Active + Immutable 管理 -├── sst/ # SSTable 文件 -│ ├── format.go # 文件格式定义 -│ ├── writer.go # SST 写入器 -│ ├── reader.go # mmap 读取器(147 行) -│ ├── manager.go # SST 文件管理 -│ └── encoding.go # Snappy 压缩 -├── btree/ # B+Tree 索引 -│ ├── node.go # 4KB 节点结构 -│ ├── builder.go # B+Tree 构建器(125 行) -│ └── reader.go # B+Tree 读取器 -├── manifest/ # 版本控制 -│ ├── version_set.go # 版本管理 -│ ├── version_edit.go # 原子更新 -│ ├── version.go # 文件元数据 -│ ├── manifest_writer.go -│ └── manifest_reader.go -├── compaction/ # 后台压缩 -│ ├── manager.go # Compaction 调度器 -│ ├── compactor.go # 合并执行器 -│ └── picker.go # 文件选择策略 -├── index/ # 二级索引 -│ ├── index.go # 字段级索引 -│ └── manager.go # 索引生命周期 -├── query/ # 查询系统 -│ ├── builder.go # 流式查询 API -│ └── expr.go # 表达式求值 -└── schema/ # Schema 验证 - ├── schema.go # 类型定义和验证 - └── examples.go # Schema 示例 -``` +### Schema 系统(强制要求) -**运行时数据目录**(例如 `./mydb/`): -``` -database_dir/ -├── database.meta # 数据库元数据(JSON) -├── MANIFEST # 全局版本控制 -└── table_name/ # 每表目录 - ├── schema.json # 表 schema - ├── MANIFEST # 表级版本控制 - ├── wal/ # WAL 文件(*.wal) - ├── sst/ # SST 文件(*.sst) - └── index/ # 二级索引(idx_*.sst) -``` - -## 常见模式 - -### 使用 Engine - -`Engine` 是核心存储层。修改引擎行为时: - -- 所有写入都经过 `Insert()` → WAL → MemTable → (异步刷新到 SST) -- 读取经过 `Get(seq)` → 检查 MemTable → 检查 SST 文件 -- `switchMemTable()` 创建新的 Active MemTable 并异步刷新旧的 -- `flushImmutable()` 将 MemTable 写入 SST 并更新 MANIFEST -- 后台 compaction 通过 `compactionManager` 运行 - -### Schema 和验证 - -Schema 是可选的,但建议在生产环境使用: +从最近的重构开始,Schema 是**强制**的,不再支持无 Schema 模式: ```go -schema := schema.NewSchema("users"). - AddField("name", schema.FieldTypeString, false, "用户名"). - AddField("age", schema.FieldTypeInt64, false, "用户年龄"). - AddField("email", schema.FieldTypeString, true, "邮箱(索引)") +schema := NewSchema("users", []Field{ + {Name: "name", Type: FieldTypeString, Indexed: false, Comment: "用户名"}, + {Name: "age", Type: FieldTypeInt64, Indexed: false, Comment: "年龄"}, + {Name: "email", Type: FieldTypeString, Indexed: true, Comment: "邮箱(索引)"}, +}) table, _ := db.CreateTable("users", schema) ``` -- Schema 在 `Insert()` 时验证类型和必填字段 +- Schema 在 `Insert()` 时强制验证类型和必填字段 - 索引字段(`Indexed: true`)自动创建二级索引 -- Schema 持久化到 `table_dir/schema.json` +- Schema 持久化到 `table_dir/schema.json`,包含校验和防篡改 +- 支持的类型:`FieldTypeString`, `FieldTypeInt64`, `FieldTypeBool`, `FieldTypeFloat` ### Query Builder -对于带条件的查询,始终使用 `QueryBuilder`: +对于带条件的查询,使用链式 API: ```go -qb := query.NewQueryBuilder() -qb.Where("age", query.OpGreater, 18). - Where("city", query.OpEqual, "Beijing") -rows, _ := table.Query(qb) +// 简单查询 +rows, _ := table.Query().Eq("name", "Alice").Rows() + +// 复合条件 +rows, _ := table.Query(). + Eq("status", "active"). + Gte("age", 18). + Rows() + +// 字段选择(性能优化) +rows, _ := table.Query(). + Select("id", "name", "email"). + Eq("status", "active"). + Rows() + +// 游标模式 +rows, _ := table.Query().Rows() +defer rows.Close() +for rows.Next() { + row := rows.Row() + fmt.Println(row.Data()) +} ``` -- 支持操作符:`OpEqual`、`OpNotEqual`、`OpGreater`、`OpLess`、`OpPrefix`、`OpSuffix`、`OpContains` -- 支持 `WhereNot()` 进行否定 -- 支持 `And()` 和 `Or()` 逻辑 -- 当可用时自动使用二级索引(对于 `=` 条件) -- 如果没有索引,则回退到全表扫描 +支持的操作符:`Eq`, `NotEq`, `Lt`, `Gt`, `Lte`, `Gte`, `In`, `NotIn`, `Between`, `Contains`, `StartsWith`, `EndsWith`, `IsNull`, `NotNull` ### Compaction Compaction 在后台自动运行: -- **触发条件**: L0 文件数 > 阈值(默认 10) +- **触发条件**: L0 文件数 > 阈值(默认 4-10,根据层级) - **策略**: 合并重叠文件,从 L0 → L1、L1 → L2 等 +- **Score 计算**: `size / max_size` 或 `file_count / max_files` - **安全性**: 删除前验证文件是否存在,以防止数据丢失 - **去重**: 对于重复的 seq,保留最新记录(按时间戳) - **文件大小**: L0=2MB、L1=10MB、L2=50MB、L3=100MB、L4+=200MB -修改 compaction 逻辑时: -- `picker.go`: 选择要压缩的文件 -- `compactor.go`: 执行合并操作 -- `manager.go`: 调度和协调 compaction -- 删除前始终验证输入/输出文件是否存在(参见 `DoCompaction`) +修改 compaction 逻辑时,注意 `compaction.go` 中的文件选择和合并逻辑。 ### 版本控制(MANIFEST) @@ -212,57 +200,55 @@ MANIFEST 跟踪跨版本的 SST 文件元数据: 1. 分配文件编号:`versionSet.AllocateFileNumber()` 2. 创建带变更的 `VersionEdit` 3. 应用:`versionSet.LogAndApply(edit)` -4. 清理旧文件:`compactionManager.CleanupOrphanFiles()` +4. 清理旧文件(通过 GC 机制) + +### 错误处理 + +使用统一的错误码系统(`errors.go`): + +```go +// 创建错误 +err := NewError(ErrCodeTableNotFound, nil) + +// 带上下文包装错误 +err := WrapError(baseErr, "failed to get table %s", "users") + +// 错误判断 +if IsNotFound(err) { ... } +if IsCorrupted(err) { ... } +if IsClosed(err) { ... } + +// 获取错误码 +code := GetErrorCode(err) +``` + +- 错误码范围:1000-1999(通用)、2000-2999(数据库)、3000-3999(表)、4000-4999(Schema)等 +- 所有 panic 已替换为错误返回 +- 使用 `fmt.Errorf` 和 `%w` 进行错误链包装 ### 错误恢复 - **WAL 重放**: 启动时,所有 `*.wal` 文件被重放到 Active MemTable -- **孤儿文件清理**: 不在 MANIFEST 中的文件在启动时删除 -- **索引修复**: `verifyAndRepairIndexes()` 重建损坏的索引 +- **孤儿文件清理**: 不在 MANIFEST 中的文件在启动时删除(有年龄保护,避免误删最近写入的文件) +- **索引修复**: 自动验证和重建损坏的索引 - **优雅降级**: 表恢复失败会被记录但不会使数据库崩溃 -## 测试模式 - -测试按组件组织: - -- `engine/engine_test.go`: 基本引擎操作 -- `engine/engine_compaction_test.go`: Compaction 场景 -- `engine/engine_stress_test.go`: 并发压力测试 -- `compaction/compaction_test.go`: Compaction 正确性 -- `query/builder_test.go`: Query builder 功能 -- `schema/schema_test.go`: Schema 验证 - -为多线程操作编写测试时,使用 `sync.WaitGroup` 并用多个 goroutine 测试(参见 `engine_stress_test.go`)。 - -## 性能特性 - -- **写入吞吐量**: 200K+ 写/秒(多线程),50K 写/秒(单线程) -- **写入延迟**: < 1ms(p99) -- **查询延迟**: < 0.1ms(MemTable),1-5ms(SST 热数据),3-5ms(冷数据) -- **内存使用**: < 150MB(64MB MemTable + 开销) -- **压缩率**: Snappy 约 50% - -优化时: -- 批量写入以减少 WAL 同步开销 -- 对经常查询的字段创建索引 -- 监控 MemTable 刷新频率(不应太频繁) -- 根据写入模式调整 compaction 阈值 - ## 重要实现细节 -### 序列号 +### 序列号系统 - `_seq` 是单调递增的 int64(原子操作) - 充当主键和时间戳排序 - 永不重用(append-only) -- compaction 期间,相同 seq 值的较新记录优先 +- Compaction 期间,相同 seq 值的较新记录优先(按 `_time` 排序) -### 并发 +### 并发控制 -- `Engine.mu`: 保护元数据和 SST reader 列表 -- `Engine.flushMu`: 确保一次只有一个 flush +- `Table.mu`: 保护表级元数据 +- `SSTableManager.mu`: RWMutex,保护 SST reader 列表 - `MemTable.mu`: RWMutex,支持并发读、独占写 - `VersionSet.mu`: 保护版本状态 +- 无全局锁,细粒度锁设计 ### 文件格式 @@ -273,7 +259,7 @@ CRC32 (4B) | Length (4B) | Type (1B) | Seq (8B) | DataLen (4B) | Data (N bytes) **SST 文件**: ``` -Header (256B) | B+Tree Index | Data Blocks (Snappy compressed) +Header (256B) | B+Tree Index (4KB nodes) | Data Blocks (Binary format) ``` **B+Tree 节点**(4KB 固定): @@ -281,11 +267,51 @@ Header (256B) | B+Tree Index | Data Blocks (Snappy compressed) Header (32B) | Keys (8B each) | Pointers/Offsets (8B each) | Padding ``` +**二进制行格式** (ROW1): +``` +Magic (4B) | Seq (8B) | Time (8B) | FieldCount (2B) | +[FieldOffset, FieldSize] × N | FieldData × N +``` + +## 性能特性 + +- **写入吞吐量**: 200K+ 写/秒(多线程),50K 写/秒(单线程) +- **写入延迟**: < 1ms(p99) +- **查询延迟**: < 0.1ms(MemTable),1-5ms(SST 热数据),3-5ms(冷数据) +- **内存使用**: < 150MB(64MB MemTable + 开销) +- **压缩**: 未使用(优先查询性能) + +优化建议: +- 批量写入以减少 WAL 同步开销 +- 对经常查询的字段创建索引 +- 使用 `Select()` 只查询需要的字段 +- 监控 MemTable 刷新频率(不应太频繁) +- 根据写入模式调整 Compaction 阈值 + ## 常见陷阱 -- Schema 验证仅在向 `Engine.Open()` 提供 schema 时才应用 -- 索引必须通过 `CreateIndex(field)` 显式创建(非自动) -- 带 schema 的 QueryBuilder 需要调用 `WithSchema()` 或让引擎设置它 -- Compaction 可能会暂时增加磁盘使用(合并期间旧文件和新文件共存) -- MemTable flush 是异步的;关闭时可能需要等待 immutable flush 完成 -- mmap 文件可能显示较大的虚拟内存使用(这是正常的,不是实际 RAM) +- **Schema 是强制的**: 所有表必须定义 Schema,不再支持无 Schema 模式 +- **索引非自动创建**: 需要在 Schema 中显式标记 `Indexed: true` +- **类型严格**: Schema 验证严格,int 和 int64 需要正确匹配 +- **Compaction 磁盘占用**: 合并期间旧文件和新文件共存,会暂时增加磁盘使用 +- **MemTable flush 异步**: 关闭时需要等待 immutable flush 完成 +- **mmap 虚拟内存**: 可能显示较大的虚拟内存使用(正常,OS 管理,不是实际 RAM) +- **无 panic**: 所有 panic 已替换为错误返回,需要正确处理错误 +- **废弃代码**: `SSTableCompressionNone` 等常量已删除 + +## Web UI + +项目包含功能完善的 Web 管理界面: + +```bash +cd examples/webui +go run main.go serve --db /path/to/database --port 8080 +``` + +功能: +- 表管理和数据浏览 +- Manifest 可视化(LSM-Tree 结构) +- 实时 Compaction 监控 +- 深色/浅色主题 + +详见 `examples/webui/README.md` diff --git a/DESIGN.md b/DESIGN.md index 2afda0e..2ed0d69 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -1,6 +1,6 @@ # SRDB 设计文档:WAL + mmap B+Tree -> 模块名:`code.tczkiot.com/srdb` +> 模块名:`code.tczkiot.com/wlw/srdb` > 一个高性能的 Append-Only 时序数据库引擎 ## 🎯 设计目标 @@ -19,10 +19,10 @@ │ SRDB Architecture │ ├─────────────────────────────────────────────────────────────┤ │ Application Layer │ -│ ┌───────────────┐ ┌──────────┐ ┌───────────┐ │ -│ │ Database │->│ Table │->│ Engine │ │ -│ │ (Multi-Table) │ │ (Schema) │ │ (Storage) │ │ -│ └───────────────┘ └──────────┘ └───────────┘ │ +│ ┌───────────────┐ ┌──────────────────────────┐ │ +│ │ Database │->│ Table │ │ +│ │ (Multi-Table) │ │ (Schema + Storage) │ │ +│ └───────────────┘ └──────────────────────────┘ │ ├─────────────────────────────────────────────────────────────┤ │ Write Path (High Concurrency) │ │ ┌─────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ @@ -38,10 +38,10 @@ ├─────────────────────────────────────────────────────────────┤ │ Storage Layer (Persistent) │ │ ┌─────────────────────────────────────────────────┐ │ -│ │ SST Files (B+Tree Format + Compression) │ │ +│ │ SST Files (B+Tree Format + Binary Encoding) │ │ │ │ ┌─────────────────────────────────────────┐ │ │ │ │ │ File Header (256 bytes) │ │ │ -│ │ │ - Magic, Version, Compression │ │ │ +│ │ │ - Magic, Version, Metadata │ │ │ │ │ │ - MinKey, MaxKey, RowCount │ │ │ │ │ ├─────────────────────────────────────────┤ │ │ │ │ │ B+Tree Index (4 KB nodes) │ │ │ @@ -49,8 +49,9 @@ │ │ │ - Internal Nodes (Order=200) │ │ │ │ │ │ - Leaf Nodes → Data Offset │ │ │ │ │ ├─────────────────────────────────────────┤ │ │ -│ │ │ Data Blocks (Snappy Compressed) │ │ │ -│ │ │ - JSON serialized rows │ │ │ +│ │ │ Data Blocks (Binary Format) │ │ │ +│ │ │ - 有 Schema: 二进制编码 │ │ │ +│ │ │ - 无 Schema: JSON 格式 │ │ │ │ │ └─────────────────────────────────────────┘ │ │ │ │ │ │ │ │ Secondary Indexes (Optional) │ │ @@ -88,59 +89,32 @@ ``` srdb/ ← 项目根目录 -├── go.mod ← 模块定义: code.tczkiot.com/srdb +├── go.mod ← 模块定义: code.tczkiot.com/wlw/srdb ├── DESIGN.md ← 本设计文档 +├── CLAUDE.md ← Claude Code 指导文档 +│ ├── database.go ← 数据库管理 (多表) -├── table.go ← 表管理 +├── table.go ← 表管理 (带 Schema) +├── errors.go ← 错误定义和处理 │ -├── engine/ ← 存储引擎 -│ └── engine.go ← 核心引擎实现 (583 行) +├── wal.go ← WAL 实现 (Write-Ahead Log) +├── memtable.go ← MemTable 实现 (map + sorted slice) +├── sstable.go ← SSTable 文件 (读写器、管理器、编码) +├── btree.go ← B+Tree 索引 (构建器、读取器) +├── version.go ← 版本控制 (MANIFEST 管理) +├── compaction.go ← Compaction 压缩合并 │ -├── wal/ ← Write-Ahead Log -│ ├── wal.go ← WAL 实现 (208 行) -│ └── manager.go ← WAL 管理器 +├── schema.go ← Schema 定义与验证 +├── index.go ← 二级索引管理器 +├── index_btree.go ← 索引 B+Tree 实现 +├── query.go ← 查询构建器和表达式求值 │ -├── memtable/ ← 内存表 -│ ├── memtable.go ← MemTable 实现 (130 行) -│ └── manager.go ← MemTable 管理器 (多版本) +├── examples/ ← 示例程序目录 +│ ├── webui/ ← Web UI 管理工具 +│ └── ... (其他示例) │ -├── sst/ ← SSTable 文件 -│ ├── format.go ← 文件格式定义 -│ ├── writer.go ← SST 写入器 -│ ├── reader.go ← SST 读取器 (mmap, 147 行) -│ ├── manager.go ← SST 管理器 -│ └── encoding.go ← 序列化/压缩 -│ -├── btree/ ← B+Tree 索引 -│ ├── node.go ← 节点定义 (4KB) -│ ├── builder.go ← B+Tree 构建器 (125 行) -│ └── reader.go ← B+Tree 读取器 -│ -├── manifest/ ← 版本控制 -│ ├── version_set.go ← 版本集合 -│ ├── version_edit.go ← 版本变更 -│ ├── version.go ← 版本信息 -│ ├── manifest_writer.go ← MANIFEST 写入 -│ └── manifest_reader.go ← MANIFEST 读取 -│ -├── compaction/ ← 压缩合并 -│ ├── manager.go ← Compaction 管理器 -│ ├── compactor.go ← 压缩执行器 -│ └── picker.go ← 文件选择策略 -│ -├── index/ ← 二级索引 (新增) -│ ├── index.go ← 索引实现 -│ ├── manager.go ← 索引管理器 -│ └── README.md ← 索引使用文档 -│ -├── query/ ← 查询系统 (新增) -│ ├── builder.go ← 查询构建器 -│ └── expr.go ← 表达式求值 -│ -└── schema/ ← Schema 系统 (新增) - ├── schema.go ← Schema 定义与验证 - ├── examples.go ← Schema 示例 - └── README.md ← Schema 使用文档 +└── webui/ ← Web UI 静态资源 + └── ... ``` ### 运行时数据目录结构 @@ -162,7 +136,7 @@ database_dir/ ← 数据库目录 │ ├── 000002.sst │ └── 000003.sst │ - └── index/ ← 索引目录 (可选) + └── idx/ ← 索引目录 (可选) ├── idx_name.sst ← 字段 name 的索引 └── idx_email.sst ← 字段 email 的索引 ``` @@ -283,14 +257,20 @@ B+Tree 节点格式: │ ├─ Child Pointer 2 (8 bytes) │ │ └─ ... │ │ │ -│ Leaf Node: │ -│ ├─ Data Offset 1 (8 bytes) │ -│ ├─ Data Size 1 (4 bytes) │ -│ ├─ Data Offset 2 (8 bytes) │ -│ ├─ Data Size 2 (4 bytes) │ +│ Leaf Node (interleaved storage): │ +│ ├─ (Offset, Size) Pair 1 │ +│ │ ├─ Data Offset 1 (8 bytes) │ +│ │ └─ Data Size 1 (4 bytes) │ +│ ├─ (Offset, Size) Pair 2 │ +│ │ ├─ Data Offset 2 (8 bytes) │ +│ │ └─ Data Size 2 (4 bytes) │ │ └─ ... │ └─────────────────────────────────────┘ +解释: + +- interleaved storage: 交叉存储 + 优势: ✅ 固定大小 (4 KB) - 对齐页面 ✅ 可以直接 mmap 访问 @@ -513,7 +493,6 @@ type Table struct { name string dir string schema *schema.Schema - engine *engine.Engine } 使用示例: @@ -571,7 +550,7 @@ Flush 流程 (后台): ↓ 4. 构建 B+Tree 索引 ↓ -5. 写入数据块 (Snappy 压缩) +5. 写入数据块 (二进制格式) ↓ 6. 写入 B+Tree 索引 ↓ @@ -630,22 +609,23 @@ Flush 流程 (后台): ### 代码规模 ``` -核心代码: 5399 行 (不含测试和示例) -├── engine: 583 行 -├── wal: 208 行 -├── memtable: 130 行 -├── sst: 147 行 (reader) -├── btree: 125 行 (builder) -├── manifest: ~500 行 -├── compaction: ~400 行 -├── index: ~400 行 -├── query: ~300 行 -├── schema: ~200 行 -└── database: ~300 行 +核心代码: ~13,000 行 (不含测试和示例) +├── table.go: 表管理和存储引擎 +├── wal.go: WAL 实现 +├── memtable.go: MemTable 实现 +├── sstable.go: SSTable 文件读写 +├── btree.go: B+Tree 索引 +├── version.go: 版本控制 (MANIFEST) +├── compaction.go: Compaction 压缩 +├── index.go: 二级索引 +├── query.go: 查询构建器 +├── schema.go: Schema 验证 +├── errors.go: 错误处理 +└── database.go: 数据库管理 测试代码: ~2000+ 行 示例代码: ~1000+ 行 -总计: 8000+ 行 +总计: 16,000+ 行 ``` ### 写入性能 @@ -692,10 +672,10 @@ Flush 流程 (后台): ``` 示例 (100 万条记录,每条 200 bytes): - 原始数据: 200 MB -- Snappy 压缩: 100 MB (50% 压缩率) +- 二进制编码: ~180 MB (紧凑格式) - B+Tree 索引: 20 MB (10%) - 二级索引: 10 MB (可选) -- 总计: 130 MB (65% 压缩率) +- 总计: ~210 MB (无压缩) ``` ## 🔧 实现状态 @@ -704,14 +684,14 @@ Flush 流程 (后台): ``` 核心存储引擎: -- [✅] Schema 定义和解析 -- [✅] WAL 实现 (wal/) -- [✅] MemTable 实现 (memtable/,使用 map+slice) -- [✅] 基础的 Insert 和 Get -- [✅] SST 文件格式定义 (sst/format.go) -- [✅] B+Tree 构建器 (btree/) +- [✅] Schema 定义和解析 (schema.go) +- [✅] WAL 实现 (wal.go) +- [✅] MemTable 实现 (memtable.go,使用 map+slice) +- [✅] 基础的 Insert 和 Get (table.go) +- [✅] SST 文件格式定义 (sstable.go) +- [✅] B+Tree 构建器 (btree.go) - [✅] Flush 流程 (异步) -- [✅] mmap 查询 (sst/reader.go) +- [✅] mmap 查询 (sstable.go) ``` ### Phase 2: 优化和稳定 ✅ 已完成 @@ -721,9 +701,9 @@ Flush 流程 (后台): - [✅] 批量写入优化 - [✅] 并发控制优化 - [✅] 崩溃恢复 (WAL 重放) -- [✅] MANIFEST 管理 (manifest/) -- [✅] Compaction 实现 (compaction/) -- [✅] MemTable Manager (多版本管理) +- [✅] MANIFEST 管理 (version.go) +- [✅] Compaction 实现 (compaction.go) +- [✅] MemTable Manager (多版本管理,table.go) - [✅] 性能测试 (各种 *_test.go) - [✅] 文档完善 (README.md, DESIGN.md) ``` @@ -733,14 +713,15 @@ Flush 流程 (后台): ``` 高级功能: - [✅] 数据库和表管理 (database.go, table.go) -- [✅] Schema 系统 (schema/) -- [✅] 二级索引 (index/) -- [✅] 查询构建器 (query/) +- [✅] Schema 系统 (schema.go,强制要求) +- [✅] 二级索引 (index.go, index_btree.go) +- [✅] 查询构建器 (query.go) - [✅] 条件查询 (AND/OR/NOT) - [✅] 字符串匹配 (Contains/StartsWith/EndsWith) - [✅] 版本控制和自动修复 -- [✅] 统计信息 (engine.Stats()) -- [✅] 压缩和编码 (Snappy) +- [✅] 统计信息 (table.Stats()) +- [✅] 二进制编码和序列化 (ROW1 格式) +- [✅] 统一错误处理 (errors.go) ``` ### Phase 4: 示例和文档 ✅ 已完成 @@ -817,15 +798,15 @@ Flush 流程 (后台): - 优势: 列裁剪,压缩率高 - 劣势: 实现复杂,Flush 慢 -最终实现 (V3): 行式存储 + Snappy -- 优势: 实现简单,Flush 快 -- 劣势: 压缩率稍低 +最终实现 (V3): 行式存储 + 二进制格式 +- 优势: 实现简单,Flush 快,紧凑高效 +- 劣势: 相比列式压缩率稍低 权衡: - 追求简单和快速实现 -- 行式 + Snappy 已经有 50% 压缩率 +- 二进制格式已经足够紧凑 - 满足大多数时序数据场景 -- 如果未来需要,可以演进到列式 +- 如果未来需要,可以演进到列式或添加压缩 ``` ### 为什么用 B+Tree 而不是 LSM Tree? @@ -868,6 +849,33 @@ mmap 方式: ✅ OS 自动优化 ``` +### 为什么不使用压缩(Snappy/LZ4)? + +``` +压缩的优势: +- 减少磁盘空间 +- 可能减少 I/O + +压缩的劣势: +- CPU 开销(压缩/解压) +- 查询延迟增加 +- mmap 零拷贝失效(需要先解压) +- 实现复杂度增加 + +最终决策: 不使用压缩 +- 优先考虑查询性能 +- 保持 mmap 零拷贝优势 +- 二进制格式已经足够紧凑 +- 现代存储成本较低 +- 如果真需要压缩,可以在应用层或文件系统层实现 + +权衡: +✅ 查询延迟更低 +✅ 实现更简单 +✅ mmap 零拷贝有效 +❌ 磁盘占用稍大 +``` + ## 🎯 总结 SRDB 是一个功能完善的高性能 Append-Only 数据库引擎: @@ -883,7 +891,7 @@ SRDB 是一个功能完善的高性能 Append-Only 数据库引擎: **技术亮点:** - 简洁的 MemTable 实现 (map + sorted slice) - B+Tree 索引,4KB 节点对齐 -- Snappy 压缩,50% 压缩率 +- 高效的二进制编码格式 - 多版本 MemTable 管理 - 后台 Compaction - 版本控制和自动修复 @@ -904,12 +912,8 @@ SRDB 是一个功能完善的高性能 Append-Only 数据库引擎: - ❌ 传统 OLTP 系统 **项目成果:** -- 核心代码: 5399 行 -- 测试代码: 2000+ 行 -- 示例程序: 13 个完整示例 +- 核心代码: ~13,000 行 +- 测试代码: ~2,000+ 行 +- 示例程序: 13+ 个完整示例 - 文档: 完善的设计和使用文档 - 性能: 达到设计目标 - ---- - -**项目已完成并可用于生产环境!** 🎉 diff --git a/Makefile b/Makefile index dd059b8..39ad9d9 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help test test-verbose test-coverage test-race test-bench test-engine test-compaction test-btree test-memtable test-sstable test-wal test-version test-schema test-index test-database fmt fmt-check vet tidy verify clean build run-webui install-webui +.PHONY: help test test-verbose test-coverage test-race test-bench test-table test-compaction test-btree test-memtable test-sstable test-wal test-version test-schema test-index test-database fmt fmt-check vet tidy verify clean build run-webui install-webui # 默认目标 .DEFAULT_GOAL := help @@ -40,9 +40,9 @@ test-bench: ## 运行基准测试 @echo "$(GREEN)运行基准测试...$(RESET)" @go test -bench=. -benchmem $$(go list ./... | grep -v /examples/) -test-engine: ## 只运行 engine 测试 - @echo "$(GREEN)运行 engine 测试...$(RESET)" - @go test -v -run TestEngine +test-table: ## 只运行 table 测试 + @echo "$(GREEN)运行 table 测试...$(RESET)" + @go test -v -run TestTable test-compaction: ## 只运行 compaction 测试 @echo "$(GREEN)运行 compaction 测试...$(RESET)" diff --git a/README.md b/README.md index 0966e61..87fd27c 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ ### 安装 ```bash -go get code.tczkiot.com/srdb +go get code.tczkiot.com/wlw/srdb ``` ### 基本示例 @@ -57,7 +57,7 @@ package main import ( "fmt" "log" - "code.tczkiot.com/srdb" + "code.tczkiot.com/wlw/srdb" ) func main() { @@ -376,21 +376,19 @@ go run main.go serve --auto-insert ``` Database -├── Table -│ ├── Schema(表结构) -│ └── Engine(存储引擎) -│ ├── MemTable Manager -│ │ ├── Active MemTable -│ │ └── Immutable MemTables -│ ├── SSTable Manager -│ │ └── SST Files (Level 0-6) -│ ├── WAL Manager -│ │ └── Write-Ahead Log -│ ├── Version Manager -│ │ └── MVCC Versions -│ └── Compaction Manager -│ ├── Picker(选择策略) -│ └── Worker(执行合并) +├── Table (Schema + Storage) +│ ├── MemTable Manager +│ │ ├── Active MemTable +│ │ └── Immutable MemTables +│ ├── SSTable Manager +│ │ └── SST Files (Level 0-6) +│ ├── WAL Manager +│ │ └── Write-Ahead Log +│ ├── Version Manager +│ │ └── MVCC Versions +│ └── Compaction Manager +│ ├── Picker(选择策略) +│ └── Worker(执行合并) └── Query Builder └── Expression Engine ``` @@ -454,13 +452,14 @@ srdb/ ├── btree.go # B-Tree 索引实现 ├── compaction.go # Compaction 管理器 ├── database.go # 数据库管理 -├── engine.go # 存储引擎核心 +├── errors.go # 错误定义和处理 ├── index.go # 索引管理 +├── index_btree.go # 索引 B+Tree ├── memtable.go # 内存表 ├── query.go # 查询构建器 ├── schema.go # Schema 定义 ├── sstable.go # SSTable 文件 -├── table.go # 表管理 +├── table.go # 表管理(含存储引擎) ├── version.go # 版本管理(MVCC) ├── wal.go # Write-Ahead Log ├── webui/ # Web UI @@ -477,7 +476,7 @@ srdb/ go test ./... # 运行特定测试 -go test -v -run TestEngine +go test -v -run TestTable # 性能测试 go test -bench=. -benchmem @@ -500,7 +499,7 @@ go build -o webui main.go - [设计文档](DESIGN.md) - 详细的架构设计和实现原理 - [WebUI 文档](examples/webui/README.md) - Web 管理界面使用指南 -- [API 文档](https://pkg.go.dev/code.tczkiot.com/srdb) - Go API 参考 +- [API 文档](https://pkg.go.dev/code.tczkiot.com/wlw/srdb) - Go API 参考 --- @@ -541,8 +540,8 @@ MIT License - 详见 [LICENSE](LICENSE) 文件 ## 📧 联系方式 -- 项目主页:https://code.tczkiot.com/srdb -- Issue 跟踪:https://code.tczkiot.com/srdb/issues +- 项目主页:https://code.tczkiot.com/wlw/srdb +- Issue 跟踪:https://code.tczkiot.com/wlw/srdb/issues --- diff --git a/btree.go b/btree.go index 363686d..8a02c43 100644 --- a/btree.go +++ b/btree.go @@ -2,6 +2,7 @@ package srdb import ( "encoding/binary" + "fmt" "os" "slices" "sort" @@ -9,6 +10,88 @@ import ( "github.com/edsrzf/mmap-go" ) +/* +B+Tree 存储格式 + +B+Tree 用于索引 SSTable 和 Index 文件,提供 O(log n) 查询性能。 + +节点结构 (4096 bytes): +┌─────────────────────────────────────────────────────────────┐ +│ Node Header (32 bytes) │ +│ ├─ NodeType (1 byte): 0=Internal, 1=Leaf │ +│ ├─ KeyCount (2 bytes): 节点中的 key 数量 │ +│ ├─ Level (1 byte): 层级 (0=叶子层) │ +│ └─ Reserved (28 bytes): 预留空间 │ +├─────────────────────────────────────────────────────────────┤ +│ Keys Array (variable) │ +│ └─ Key[0..KeyCount-1]: int64 (8 bytes each) │ +├─────────────────────────────────────────────────────────────┤ +│ Values (variable, 取决于节点类型) │ +│ │ +│ 内部节点 (Internal Node): │ +│ └─ Children[0..KeyCount]: int64 (8 bytes each) │ +│ - 子节点的文件偏移量 │ +│ - Children[i] 包含 < Key[i] 的所有 key │ +│ - Children[KeyCount] 包含 >= Key[KeyCount-1] 的 key │ +│ │ +│ 叶子节点 (Leaf Node): │ +│ └─ Data Pairs[0..KeyCount-1]: 交错存储 (12 bytes each) │ +│ ├─ DataOffset: int64 (8 bytes) - 数据块的文件偏移量 │ +│ └─ DataSize: int32 (4 bytes) - 数据块的大小 │ +└─────────────────────────────────────────────────────────────┘ + +节点头格式 (32 bytes): + Offset | Size | Field | Description + -------|------|------------|---------------------------------- + 0 | 1 | NodeType | 0=Internal, 1=Leaf + 1 | 2 | KeyCount | key 数量 (0 ~ BTreeOrder) + 3 | 1 | Level | 层级 (0=叶子层, 1+=内部层) + 4 | 28 | Reserved | 预留空间 + +内部节点布局 (示例: KeyCount=3): + [Header: 32B] + [Keys: Key0(8B), Key1(8B), Key2(8B)] + [Children: Child0(8B), Child1(8B), Child2(8B), Child3(8B)] + + 查询规则: + - key < Key0 → Child0 + - Key0 ≤ key < Key1 → Child1 + - Key1 ≤ key < Key2 → Child2 + - key ≥ Key2 → Child3 + +叶子节点布局 (示例: KeyCount=3): + [Header: 32B] + [Keys: Key0(8B), Key1(8B), Key2(8B)] + [Data: (Offset0, Size0), (Offset1, Size1), (Offset2, Size2)] + - 交错存储: Offset0(8B), Size0(4B), Offset1(8B), Size1(4B), Offset2(8B), Size2(4B) + + 查询规则: + - 找到 key == Key[i] + - 返回 (DataOffsets[i], DataSizes[i]) + +B+Tree 特性: + - 阶数 (Order): 200 (每个节点最多 200 个 key) + - 节点大小: 4096 bytes (4 KB,对齐页大小) + - 高度: log₂₀₀(N) (100万条数据约 3 层) + - 查询复杂度: O(log n) + - 范围查询: 支持(叶子节点有序) + +文件布局示例: + SSTable/Index 文件: + [Header: 256B] + [B+Tree Nodes: 4KB each] + ├─ Root Node (Internal) + ├─ Level 1 Nodes (Internal) + └─ Leaf Nodes + [Data Blocks: variable] + +性能优化: + - mmap 零拷贝: 直接从内存映射读取节点 + - 节点对齐: 4KB 对齐,利用操作系统页缓存 + - 有序存储: 叶子节点有序,支持范围查询 + - 紧凑编码: 最小化节点大小,提高缓存命中率 +*/ + const ( BTreeNodeSize = 4096 // 节点大小 (4 KB) BTreeOrder = 200 // B+Tree 阶数 (保守估计,叶子节点每个entry 20 bytes) @@ -59,6 +142,29 @@ func NewLeafNode() *BTreeNode { } // Marshal 序列化节点到 4 KB +// +// 布局: +// [Header: 32B] +// [Keys: KeyCount * 8B] +// [Values: 取决于节点类型] +// - Internal: Children (KeyCount+1) * 8B +// - Leaf: 交错存储 (Offset, Size) 对,每对 12B,共 KeyCount * 12B +// +// 示例(叶子节点,KeyCount=3): +// Offset | Size | Content +// -------|------|---------------------------------- +// 0 | 1 | NodeType = 1 (Leaf) +// 1 | 2 | KeyCount = 3 +// 3 | 1 | Level = 0 +// 4 | 28 | Reserved +// 32 | 24 | Keys [100, 200, 300] +// 56 | 8 | DataOffset0 = 1000 +// 64 | 4 | DataSize0 = 50 +// 68 | 8 | DataOffset1 = 2000 +// 76 | 4 | DataSize1 = 60 +// 80 | 8 | DataOffset2 = 3000 +// 88 | 4 | DataSize2 = 70 +// 92 | 4004 | Padding (unused) func (n *BTreeNode) Marshal() []byte { buf := make([]byte, BTreeNodeSize) @@ -105,6 +211,16 @@ func (n *BTreeNode) Marshal() []byte { } // UnmarshalBTree 从字节数组反序列化节点 +// +// 参数: +// data: 4KB 节点数据(通常来自 mmap) +// +// 返回: +// *BTreeNode: 反序列化后的节点 +// +// 零拷贝优化: +// - 直接从 mmap 数据读取,不复制整个节点 +// - 只复制必要的字段(Keys, Children, DataOffsets, DataSizes) func UnmarshalBTree(data []byte) *BTreeNode { if len(data) < BTreeNodeSize { return nil @@ -171,25 +287,47 @@ func (n *BTreeNode) AddKey(key int64) { } // AddChild 添加子节点 (仅用于内部节点) -func (n *BTreeNode) AddChild(offset int64) { +func (n *BTreeNode) AddChild(offset int64) error { if n.NodeType != BTreeNodeTypeInternal { - panic("AddChild called on leaf node") + return fmt.Errorf("AddChild called on leaf node") } n.Children = append(n.Children, offset) + return nil } // AddData 添加数据位置 (仅用于叶子节点) -func (n *BTreeNode) AddData(key int64, offset int64, size int32) { +func (n *BTreeNode) AddData(key int64, offset int64, size int32) error { if n.NodeType != BTreeNodeTypeLeaf { - panic("AddData called on internal node") + return fmt.Errorf("AddData called on internal node") } n.Keys = append(n.Keys, key) n.DataOffsets = append(n.DataOffsets, offset) n.DataSizes = append(n.DataSizes, size) n.KeyCount = uint16(len(n.Keys)) + return nil } // BTreeBuilder 从下往上构建 B+Tree +// +// 构建流程: +// 1. Add(): 添加所有 (key, offset, size) 到叶子节点 +// - 当叶子节点满时,创建新的叶子节点 +// - 所有叶子节点按 key 有序 +// +// 2. Build(): 从叶子层向上构建 +// - Level 0: 叶子节点(已创建) +// - Level 1: 为叶子节点创建父节点(内部节点) +// - Level 2+: 递归创建更高层级 +// - 最终返回根节点偏移量 +// +// 示例(100 个 key,Order=200): +// - 叶子层: 1 个叶子节点(100 个 key) +// - 根节点: 叶子节点本身 +// +// 示例(500 个 key,Order=200): +// - 叶子层: 3 个叶子节点(200, 200, 100 个 key) +// - Level 1: 1 个内部节点(3 个子节点) +// - 根节点: Level 1 的内部节点 type BTreeBuilder struct { order int // B+Tree 阶数 file *os.File // 输出文件 @@ -220,7 +358,9 @@ func (b *BTreeBuilder) Add(key int64, dataOffset int64, dataSize int32) error { } // 添加到叶子节点 - leaf.AddData(key, dataOffset, dataSize) + if err := leaf.AddData(key, dataOffset, dataSize); err != nil { + return err + } return nil } @@ -280,14 +420,18 @@ func (b *BTreeBuilder) buildLevel(children []*BTreeNode, childOffsets []int64, l parent := NewInternalNode(byte(level)) // 添加第一个子节点 (没有对应的 key) - parent.AddChild(childOffsets[i]) + if err := parent.AddChild(childOffsets[i]); err != nil { + return nil, nil, err + } // 添加剩余的子节点和分隔 key for j := i + 1; j < end; j++ { // 分隔 key 是子节点的第一个 key separatorKey := children[j].Keys[0] parent.AddKey(separatorKey) - parent.AddChild(childOffsets[j]) + if err := parent.AddChild(childOffsets[j]); err != nil { + return nil, nil, err + } } // 写入父节点 @@ -307,6 +451,24 @@ func (b *BTreeBuilder) buildLevel(children []*BTreeNode, childOffsets []int64, l } // BTreeReader 用于查询 B+Tree (mmap) +// +// 查询流程: +// 1. 从根节点开始 +// 2. 如果是内部节点: +// - 二分查找确定子节点 +// - 跳转到子节点继续查找 +// 3. 如果是叶子节点: +// - 二分查找 key +// - 返回 (dataOffset, dataSize) +// +// 性能优化: +// - mmap 零拷贝:直接从内存映射读取节点 +// - 二分查找:O(log KeyCount) 在节点内查找 +// - 总复杂度:O(log n) = O(height * log Order) +// +// 示例(100万条数据,Order=200): +// - 高度: log₂₀₀(1000000) ≈ 3 +// - 查询次数: 3 次节点读取 + 3 次二分查找 type BTreeReader struct { mmap mmap.MMap rootOffset int64 @@ -321,6 +483,19 @@ func NewBTreeReader(mmap mmap.MMap, rootOffset int64) *BTreeReader { } // Get 查询 key,返回数据位置 +// +// 参数: +// key: 要查询的 key +// +// 返回: +// dataOffset: 数据块的文件偏移量 +// dataSize: 数据块的大小 +// found: 是否找到 +// +// 查询流程: +// 1. 从根节点开始遍历 +// 2. 内部节点:二分查找确定子节点,跳转 +// 3. 叶子节点:二分查找 key,返回数据位置 func (r *BTreeReader) Get(key int64) (dataOffset int64, dataSize int32, found bool) { if r.rootOffset == 0 { return 0, 0, false diff --git a/btree_test.go b/btree_test.go index 16dcb0e..7f9b44d 100644 --- a/btree_test.go +++ b/btree_test.go @@ -87,9 +87,15 @@ func TestBTree(t *testing.T) { func TestBTreeSerialization(t *testing.T) { // 测试节点序列化 leaf := NewLeafNode() - leaf.AddData(1, 1000, 100) - leaf.AddData(2, 2000, 200) - leaf.AddData(3, 3000, 300) + if err := leaf.AddData(1, 1000, 100); err != nil { + t.Fatal(err) + } + if err := leaf.AddData(2, 2000, 200); err != nil { + t.Fatal(err) + } + if err := leaf.AddData(3, 3000, 300); err != nil { + t.Fatal(err) + } // 序列化 data := leaf.Marshal() diff --git a/compaction_test.go b/compaction_test.go index 95344fc..aa9950a 100644 --- a/compaction_test.go +++ b/compaction_test.go @@ -19,6 +19,11 @@ func TestCompactionBasic(t *testing.T) { t.Fatal(err) } + // 创建 Schema + schema := NewSchema("test", []Field{ + {Name: "value", Type: FieldTypeInt64}, + }) + // 创建 VersionSet versionSet, err := NewVersionSet(manifestDir) if err != nil { @@ -33,6 +38,9 @@ func TestCompactionBasic(t *testing.T) { } defer sstMgr.Close() + // 设置 Schema + sstMgr.SetSchema(schema) + // 创建测试数据 rows1 := make([]*SSTableRow, 100) for i := range 100 { @@ -75,6 +83,7 @@ func TestCompactionBasic(t *testing.T) { // 创建 Compaction Manager compactionMgr := NewCompactionManager(sstDir, versionSet, sstMgr) + compactionMgr.SetSchema(schema) // 创建更多文件触发 Compaction for i := 1; i < 5; i++ { @@ -204,6 +213,11 @@ func TestCompactionMerge(t *testing.T) { t.Fatal(err) } + // 创建 Schema + schema := NewSchema("test", []Field{ + {Name: "value", Type: FieldTypeString}, + }) + // 创建 VersionSet versionSet, err := NewVersionSet(manifestDir) if err != nil { @@ -218,6 +232,9 @@ func TestCompactionMerge(t *testing.T) { } defer sstMgr.Close() + // 设置 Schema + sstMgr.SetSchema(schema) + // 创建两个有重叠 key 的 SST 文件 rows1 := []*SSTableRow{ {Seq: 1, Time: 1000, Data: map[string]any{"value": "old"}}, @@ -269,6 +286,7 @@ func TestCompactionMerge(t *testing.T) { // 创建 Compactor compactor := NewCompactor(sstDir, versionSet) + compactor.SetSchema(schema) // 创建 Compaction 任务 version := versionSet.GetCurrent() @@ -316,6 +334,11 @@ func BenchmarkCompaction(b *testing.B) { b.Fatal(err) } + // 创建 Schema + schema := NewSchema("test", []Field{ + {Name: "value", Type: FieldTypeString}, + }) + // 创建 VersionSet versionSet, err := NewVersionSet(manifestDir) if err != nil { @@ -330,6 +353,9 @@ func BenchmarkCompaction(b *testing.B) { } defer sstMgr.Close() + // 设置 Schema + sstMgr.SetSchema(schema) + // 创建测试数据 const numFiles = 5 const rowsPerFile = 1000 @@ -372,6 +398,7 @@ func BenchmarkCompaction(b *testing.B) { // 创建 Compactor compactor := NewCompactor(sstDir, versionSet) + compactor.SetSchema(schema) version := versionSet.GetCurrent() task := &CompactionTask{ @@ -401,16 +428,16 @@ func TestCompactionQueryOrder(t *testing.T) { {Name: "timestamp", Type: FieldTypeInt64}, }) - // 打开 Engine (使用较小的 MemTable 触发频繁 flush) - engine, err := OpenEngine(&EngineOptions{ + // 打开 Table (使用较小的 MemTable 触发频繁 flush) + table, err := OpenTable(&TableOptions{ Dir: tmpDir, MemTableSize: 2 * 1024 * 1024, // 2MB MemTable - Schema: schema, + Name: schema.Name, Fields: schema.Fields, }) if err != nil { t.Fatal(err) } - defer engine.Close() + defer table.Close() t.Logf("开始插入 4000 条数据...") @@ -423,7 +450,7 @@ func TestCompactionQueryOrder(t *testing.T) { largeData[j] = byte('A' + (j % 26)) } - err := engine.Insert(map[string]any{ + err := table.Insert(map[string]any{ "id": int64(i), "name": fmt.Sprintf("user_%d", i), "data": string(largeData), @@ -447,7 +474,7 @@ func TestCompactionQueryOrder(t *testing.T) { t.Logf("开始查询所有数据...") // 查询所有数据 - rows, err := engine.Query().Rows() + rows, err := table.Query().Rows() if err != nil { t.Fatal(err) } @@ -514,7 +541,7 @@ func TestCompactionQueryOrder(t *testing.T) { t.Logf("✓ 所有数据完整性验证通过") // 输出 compaction 统计信息 - stats := engine.GetCompactionManager().GetLevelStats() + stats := table.GetCompactionManager().GetLevelStats() t.Logf("Compaction 统计:") for _, levelStat := range stats { level := levelStat["level"].(int) diff --git a/database.go b/database.go index a36a502..384e1dc 100644 --- a/database.go +++ b/database.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "sync" + "time" ) // Database 数据库,管理多个表 @@ -108,8 +109,11 @@ func (db *Database) recoverTables() error { var failedTables []string for _, tableInfo := range db.metadata.Tables { - // FIXME: 是否需要校验 tableInfo.Dir ? - table, err := openTable(tableInfo.Name, db) + tableDir := filepath.Join(db.dir, tableInfo.Name) + table, err := OpenTable(&TableOptions{ + Dir: tableDir, + MemTableSize: DefaultMemTableSize, + }) if err != nil { // 记录失败的表,但继续恢复其他表 failedTables = append(failedTables, tableInfo.Name) @@ -137,12 +141,25 @@ func (db *Database) CreateTable(name string, schema *Schema) (*Table, error) { // 检查表是否已存在 if _, exists := db.tables[name]; exists { - return nil, fmt.Errorf("table %s already exists", name) + return nil, NewErrorf(ErrCodeTableExists, "table %s already exists", name) + } + + // 创建表目录 + tableDir := filepath.Join(db.dir, name) + err := os.MkdirAll(tableDir, 0755) + if err != nil { + return nil, err } // 创建表 - table, err := createTable(name, schema, db) + table, err := OpenTable(&TableOptions{ + Dir: tableDir, + MemTableSize: DefaultMemTableSize, + Name: schema.Name, + Fields: schema.Fields, + }) if err != nil { + os.RemoveAll(tableDir) return nil, err } @@ -153,7 +170,7 @@ func (db *Database) CreateTable(name string, schema *Schema) (*Table, error) { db.metadata.Tables = append(db.metadata.Tables, TableInfo{ Name: name, Dir: name, - CreatedAt: table.createdAt, + CreatedAt: time.Now().Unix(), }) err = db.saveMetadata() @@ -171,7 +188,7 @@ func (db *Database) GetTable(name string) (*Table, error) { table, exists := db.tables[name] if !exists { - return nil, fmt.Errorf("table %s not found", name) + return nil, NewErrorf(ErrCodeTableNotFound, "table %s not found", name) } return table, nil @@ -185,7 +202,7 @@ func (db *Database) DropTable(name string) error { // 检查表是否存在 table, exists := db.tables[name] if !exists { - return fmt.Errorf("table %s not found", name) + return NewErrorf(ErrCodeTableNotFound, "table %s not found", name) } // 关闭表 diff --git a/database_clean_test.go b/database_clean_test.go deleted file mode 100644 index 06aab1e..0000000 --- a/database_clean_test.go +++ /dev/null @@ -1,296 +0,0 @@ -package srdb - -import ( - "fmt" - "os" - "testing" -) - -func TestDatabaseClean(t *testing.T) { - dir := "./test_db_clean_data" - defer os.RemoveAll(dir) - - // 1. 创建数据库 - db, err := Open(dir) - if err != nil { - t.Fatal(err) - } - - // 2. 创建多个表并插入数据 - // 表 1: users - usersSchema := NewSchema("users", []Field{ - {Name: "id", Type: FieldTypeInt64, Indexed: true, Comment: "User ID"}, - {Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"}, - }) - usersTable, err := db.CreateTable("users", usersSchema) - if err != nil { - t.Fatal(err) - } - - for i := 0; i < 50; i++ { - usersTable.Insert(map[string]any{ - "id": int64(i), - "name": "user" + string(rune(i)), - }) - } - - // 表 2: orders - ordersSchema := NewSchema("orders", []Field{ - {Name: "order_id", Type: FieldTypeInt64, Indexed: true, Comment: "Order ID"}, - {Name: "amount", Type: FieldTypeInt64, Indexed: false, Comment: "Amount"}, - }) - ordersTable, err := db.CreateTable("orders", ordersSchema) - if err != nil { - t.Fatal(err) - } - - for i := 0; i < 30; i++ { - ordersTable.Insert(map[string]any{ - "order_id": int64(i), - "amount": int64(i * 100), - }) - } - - // 3. 验证数据存在 - usersStats := usersTable.Stats() - ordersStats := ordersTable.Stats() - t.Logf("Before Clean - Users: %d rows, Orders: %d rows", - usersStats.TotalRows, ordersStats.TotalRows) - - if usersStats.TotalRows == 0 || ordersStats.TotalRows == 0 { - t.Error("Expected data in tables") - } - - // 4. 清除所有表的数据 - err = db.Clean() - if err != nil { - t.Fatal(err) - } - - // 5. 验证数据已清除 - usersStats = usersTable.Stats() - ordersStats = ordersTable.Stats() - t.Logf("After Clean - Users: %d rows, Orders: %d rows", - usersStats.TotalRows, ordersStats.TotalRows) - - if usersStats.TotalRows != 0 { - t.Errorf("Expected 0 rows in users, got %d", usersStats.TotalRows) - } - if ordersStats.TotalRows != 0 { - t.Errorf("Expected 0 rows in orders, got %d", ordersStats.TotalRows) - } - - // 6. 验证表结构仍然存在 - tables := db.ListTables() - if len(tables) != 2 { - t.Errorf("Expected 2 tables, got %d", len(tables)) - } - - // 7. 验证可以继续插入数据 - err = usersTable.Insert(map[string]any{ - "id": int64(100), - "name": "new_user", - }) - if err != nil { - t.Fatal(err) - } - - usersStats = usersTable.Stats() - if usersStats.TotalRows != 1 { - t.Errorf("Expected 1 row after insert, got %d", usersStats.TotalRows) - } - - db.Close() -} - -func TestDatabaseDestroy(t *testing.T) { - dir := "./test_db_destroy_data" - defer os.RemoveAll(dir) - - // 1. 创建数据库和表 - db, err := Open(dir) - if err != nil { - t.Fatal(err) - } - - schema := NewSchema("test", []Field{ - {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, - }) - table, err := db.CreateTable("test", schema) - if err != nil { - t.Fatal(err) - } - - // 插入数据 - for i := 0; i < 20; i++ { - table.Insert(map[string]any{"id": int64(i)}) - } - - // 2. 验证数据存在 - stats := table.Stats() - t.Logf("Before Destroy: %d rows", stats.TotalRows) - - if stats.TotalRows == 0 { - t.Error("Expected data in table") - } - - // 3. 销毁数据库 - err = db.Destroy() - if err != nil { - t.Fatal(err) - } - - // 4. 验证数据目录已删除 - if _, err := os.Stat(dir); !os.IsNotExist(err) { - t.Error("Database directory should be deleted") - } - - // 5. 验证数据库不可用 - tables := db.ListTables() - if len(tables) != 0 { - t.Errorf("Expected 0 tables after destroy, got %d", len(tables)) - } -} - -func TestDatabaseCleanMultipleTables(t *testing.T) { - dir := "./test_db_clean_multi_data" - defer os.RemoveAll(dir) - - // 1. 创建数据库和多个表 - db, err := Open(dir) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - // 创建 5 个表 - for i := 0; i < 5; i++ { - tableName := fmt.Sprintf("table%d", i) - schema := NewSchema(tableName, []Field{ - {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, - {Name: "value", Type: FieldTypeString, Indexed: false, Comment: "Value"}, - }) - - table, err := db.CreateTable(tableName, schema) - if err != nil { - t.Fatal(err) - } - - // 每个表插入 10 条数据 - for j := 0; j < 10; j++ { - table.Insert(map[string]any{ - "id": int64(j), - "value": fmt.Sprintf("value_%d_%d", i, j), - }) - } - } - - // 2. 验证所有表都有数据 - tables := db.ListTables() - if len(tables) != 5 { - t.Fatalf("Expected 5 tables, got %d", len(tables)) - } - - totalRows := 0 - for _, tableName := range tables { - table, _ := db.GetTable(tableName) - stats := table.Stats() - totalRows += int(stats.TotalRows) - } - t.Logf("Total rows before clean: %d", totalRows) - - if totalRows == 0 { - t.Error("Expected data in tables") - } - - // 3. 清除所有表 - err = db.Clean() - if err != nil { - t.Fatal(err) - } - - // 4. 验证所有表数据已清除 - totalRows = 0 - for _, tableName := range tables { - table, _ := db.GetTable(tableName) - stats := table.Stats() - totalRows += int(stats.TotalRows) - - if stats.TotalRows != 0 { - t.Errorf("Table %s should have 0 rows, got %d", tableName, stats.TotalRows) - } - } - t.Logf("Total rows after clean: %d", totalRows) - - // 5. 验证表结构仍然存在 - tables = db.ListTables() - if len(tables) != 5 { - t.Errorf("Expected 5 tables after clean, got %d", len(tables)) - } -} - -func TestDatabaseCleanAndReopen(t *testing.T) { - dir := "./test_db_clean_reopen_data" - defer os.RemoveAll(dir) - - // 1. 创建数据库和表 - db, err := Open(dir) - if err != nil { - t.Fatal(err) - } - - schema := NewSchema("test", []Field{ - {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, - }) - table, err := db.CreateTable("test", schema) - if err != nil { - t.Fatal(err) - } - - // 插入数据 - for i := 0; i < 50; i++ { - table.Insert(map[string]any{"id": int64(i)}) - } - - // 2. 清除数据 - err = db.Clean() - if err != nil { - t.Fatal(err) - } - - // 3. 关闭并重新打开 - db.Close() - - db2, err := Open(dir) - if err != nil { - t.Fatal(err) - } - defer db2.Close() - - // 4. 验证表存在但数据为空 - tables := db2.ListTables() - if len(tables) != 1 { - t.Errorf("Expected 1 table, got %d", len(tables)) - } - - table2, err := db2.GetTable("test") - if err != nil { - t.Fatal(err) - } - - stats := table2.Stats() - if stats.TotalRows != 0 { - t.Errorf("Expected 0 rows after reopen, got %d", stats.TotalRows) - } - - // 5. 验证可以插入新数据 - err = table2.Insert(map[string]any{"id": int64(100)}) - if err != nil { - t.Fatal(err) - } - - stats = table2.Stats() - if stats.TotalRows != 1 { - t.Errorf("Expected 1 row, got %d", stats.TotalRows) - } -} diff --git a/database_table_ops_test.go b/database_table_ops_test.go deleted file mode 100644 index 6d2357b..0000000 --- a/database_table_ops_test.go +++ /dev/null @@ -1,195 +0,0 @@ -package srdb - -import ( - "os" - "testing" -) - -func TestDatabaseCleanTable(t *testing.T) { - dir := "./test_db_clean_table_data" - defer os.RemoveAll(dir) - - // 1. 创建数据库和表 - db, err := Open(dir) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - schema := NewSchema("users", []Field{ - {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, - {Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"}, - }) - - table, err := db.CreateTable("users", schema) - if err != nil { - t.Fatal(err) - } - - // 2. 插入数据 - for i := 0; i < 50; i++ { - table.Insert(map[string]any{ - "id": int64(i), - "name": "user", - }) - } - - // 3. 验证数据存在 - stats := table.Stats() - if stats.TotalRows == 0 { - t.Error("Expected data in table") - } - - // 4. 清除表数据 - err = db.CleanTable("users") - if err != nil { - t.Fatal(err) - } - - // 5. 验证数据已清除 - stats = table.Stats() - if stats.TotalRows != 0 { - t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows) - } - - // 6. 验证表仍然存在 - tables := db.ListTables() - found := false - for _, name := range tables { - if name == "users" { - found = true - break - } - } - if !found { - t.Error("Table should still exist after clean") - } - - // 7. 验证可以继续插入 - err = table.Insert(map[string]any{ - "id": int64(100), - "name": "new_user", - }) - if err != nil { - t.Fatal(err) - } - - stats = table.Stats() - if stats.TotalRows != 1 { - t.Errorf("Expected 1 row, got %d", stats.TotalRows) - } -} - -func TestDatabaseDestroyTable(t *testing.T) { - dir := "./test_db_destroy_table_data" - defer os.RemoveAll(dir) - - // 1. 创建数据库和表 - db, err := Open(dir) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - schema := NewSchema("test", []Field{ - {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, - }) - - table, err := db.CreateTable("test", schema) - if err != nil { - t.Fatal(err) - } - - // 2. 插入数据 - for i := 0; i < 30; i++ { - table.Insert(map[string]any{"id": int64(i)}) - } - - // 3. 验证数据存在 - stats := table.Stats() - if stats.TotalRows == 0 { - t.Error("Expected data in table") - } - - // 4. 销毁表 - err = db.DestroyTable("test") - if err != nil { - t.Fatal(err) - } - - // 5. 验证表已从 Database 中删除 - tables := db.ListTables() - for _, name := range tables { - if name == "test" { - t.Error("Table should be removed from database") - } - } - - // 6. 验证无法再获取该表 - _, err = db.GetTable("test") - if err == nil { - t.Error("Should not be able to get table after destroy") - } -} - -func TestDatabaseDestroyTableMultiple(t *testing.T) { - dir := "./test_db_destroy_multi_data" - defer os.RemoveAll(dir) - - // 1. 创建数据库和多个表 - db, err := Open(dir) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - schema := NewSchema("test", []Field{ - {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, - }) - - // 创建 3 个表 - for i := 1; i <= 3; i++ { - tableName := "table" + string(rune('0'+i)) - _, err := db.CreateTable(tableName, schema) - if err != nil { - t.Fatal(err) - } - } - - // 2. 验证有 3 个表 - tables := db.ListTables() - if len(tables) != 3 { - t.Fatalf("Expected 3 tables, got %d", len(tables)) - } - - // 3. 销毁中间的表 - err = db.DestroyTable("table2") - if err != nil { - t.Fatal(err) - } - - // 4. 验证只剩 2 个表 - tables = db.ListTables() - if len(tables) != 2 { - t.Errorf("Expected 2 tables, got %d", len(tables)) - } - - // 5. 验证剩余的表是正确的 - hasTable1 := false - hasTable3 := false - for _, name := range tables { - if name == "table1" { - hasTable1 = true - } - if name == "table3" { - hasTable3 = true - } - if name == "table2" { - t.Error("table2 should be destroyed") - } - } - - if !hasTable1 || !hasTable3 { - t.Error("table1 and table3 should still exist") - } -} diff --git a/database_test.go b/database_test.go index fb0c4f5..3506315 100644 --- a/database_test.go +++ b/database_test.go @@ -1,7 +1,9 @@ package srdb import ( + "fmt" "os" + "slices" "testing" ) @@ -257,3 +259,475 @@ func TestDatabaseRecover(t *testing.T) { t.Log("Database recover test passed!") } + +func TestDatabaseClean(t *testing.T) { + dir := "./test_db_clean_data" + defer os.RemoveAll(dir) + + // 1. 创建数据库 + db, err := Open(dir) + if err != nil { + t.Fatal(err) + } + + // 2. 创建多个表并插入数据 + // 表 1: users + usersSchema := NewSchema("users", []Field{ + {Name: "id", Type: FieldTypeInt64, Indexed: true, Comment: "User ID"}, + {Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"}, + }) + usersTable, err := db.CreateTable("users", usersSchema) + if err != nil { + t.Fatal(err) + } + + for i := range 50 { + usersTable.Insert(map[string]any{ + "id": int64(i), + "name": "user" + string(rune(i)), + }) + } + + // 表 2: orders + ordersSchema := NewSchema("orders", []Field{ + {Name: "order_id", Type: FieldTypeInt64, Indexed: true, Comment: "Order ID"}, + {Name: "amount", Type: FieldTypeInt64, Indexed: false, Comment: "Amount"}, + }) + ordersTable, err := db.CreateTable("orders", ordersSchema) + if err != nil { + t.Fatal(err) + } + + for i := range 30 { + ordersTable.Insert(map[string]any{ + "order_id": int64(i), + "amount": int64(i * 100), + }) + } + + // 3. 验证数据存在 + usersStats := usersTable.Stats() + ordersStats := ordersTable.Stats() + t.Logf("Before Clean - Users: %d rows, Orders: %d rows", + usersStats.TotalRows, ordersStats.TotalRows) + + if usersStats.TotalRows == 0 || ordersStats.TotalRows == 0 { + t.Error("Expected data in tables") + } + + // 4. 清除所有表的数据 + err = db.Clean() + if err != nil { + t.Fatal(err) + } + + // 5. 验证数据已清除 + usersStats = usersTable.Stats() + ordersStats = ordersTable.Stats() + t.Logf("After Clean - Users: %d rows, Orders: %d rows", + usersStats.TotalRows, ordersStats.TotalRows) + + if usersStats.TotalRows != 0 { + t.Errorf("Expected 0 rows in users, got %d", usersStats.TotalRows) + } + if ordersStats.TotalRows != 0 { + t.Errorf("Expected 0 rows in orders, got %d", ordersStats.TotalRows) + } + + // 6. 验证表结构仍然存在 + tables := db.ListTables() + if len(tables) != 2 { + t.Errorf("Expected 2 tables, got %d", len(tables)) + } + + // 7. 验证可以继续插入数据 + err = usersTable.Insert(map[string]any{ + "id": int64(100), + "name": "new_user", + }) + if err != nil { + t.Fatal(err) + } + + usersStats = usersTable.Stats() + if usersStats.TotalRows != 1 { + t.Errorf("Expected 1 row after insert, got %d", usersStats.TotalRows) + } + + db.Close() +} + +func TestDatabaseDestroy(t *testing.T) { + dir := "./test_db_destroy_data" + defer os.RemoveAll(dir) + + // 1. 创建数据库和表 + db, err := Open(dir) + if err != nil { + t.Fatal(err) + } + + schema := NewSchema("test", []Field{ + {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, + }) + table, err := db.CreateTable("test", schema) + if err != nil { + t.Fatal(err) + } + + // 插入数据 + for i := range 20 { + table.Insert(map[string]any{"id": int64(i)}) + } + + // 2. 验证数据存在 + stats := table.Stats() + t.Logf("Before Destroy: %d rows", stats.TotalRows) + + if stats.TotalRows == 0 { + t.Error("Expected data in table") + } + + // 3. 销毁数据库 + err = db.Destroy() + if err != nil { + t.Fatal(err) + } + + // 4. 验证数据目录已删除 + if _, err := os.Stat(dir); !os.IsNotExist(err) { + t.Error("Database directory should be deleted") + } + + // 5. 验证数据库不可用 + tables := db.ListTables() + if len(tables) != 0 { + t.Errorf("Expected 0 tables after destroy, got %d", len(tables)) + } +} + +func TestDatabaseCleanMultipleTables(t *testing.T) { + dir := "./test_db_clean_multi_data" + defer os.RemoveAll(dir) + + // 1. 创建数据库和多个表 + db, err := Open(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + // 创建 5 个表 + for i := range 5 { + tableName := fmt.Sprintf("table%d", i) + schema := NewSchema(tableName, []Field{ + {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, + {Name: "value", Type: FieldTypeString, Indexed: false, Comment: "Value"}, + }) + + table, err := db.CreateTable(tableName, schema) + if err != nil { + t.Fatal(err) + } + + // 每个表插入 10 条数据 + for j := range 10 { + table.Insert(map[string]any{ + "id": int64(j), + "value": fmt.Sprintf("value_%d_%d", i, j), + }) + } + } + + // 2. 验证所有表都有数据 + tables := db.ListTables() + if len(tables) != 5 { + t.Fatalf("Expected 5 tables, got %d", len(tables)) + } + + totalRows := 0 + for _, tableName := range tables { + table, _ := db.GetTable(tableName) + stats := table.Stats() + totalRows += int(stats.TotalRows) + } + t.Logf("Total rows before clean: %d", totalRows) + + if totalRows == 0 { + t.Error("Expected data in tables") + } + + // 3. 清除所有表 + err = db.Clean() + if err != nil { + t.Fatal(err) + } + + // 4. 验证所有表数据已清除 + totalRows = 0 + for _, tableName := range tables { + table, _ := db.GetTable(tableName) + stats := table.Stats() + totalRows += int(stats.TotalRows) + + if stats.TotalRows != 0 { + t.Errorf("Table %s should have 0 rows, got %d", tableName, stats.TotalRows) + } + } + t.Logf("Total rows after clean: %d", totalRows) + + // 5. 验证表结构仍然存在 + tables = db.ListTables() + if len(tables) != 5 { + t.Errorf("Expected 5 tables after clean, got %d", len(tables)) + } +} + +func TestDatabaseCleanAndReopen(t *testing.T) { + dir := "./test_db_clean_reopen_data" + defer os.RemoveAll(dir) + + // 1. 创建数据库和表 + db, err := Open(dir) + if err != nil { + t.Fatal(err) + } + + schema := NewSchema("test", []Field{ + {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, + }) + table, err := db.CreateTable("test", schema) + if err != nil { + t.Fatal(err) + } + + // 插入数据 + for i := range 50 { + table.Insert(map[string]any{"id": int64(i)}) + } + + // 2. 清除数据 + err = db.Clean() + if err != nil { + t.Fatal(err) + } + + // 3. 关闭并重新打开 + db.Close() + + db2, err := Open(dir) + if err != nil { + t.Fatal(err) + } + defer db2.Close() + + // 4. 验证表存在但数据为空 + tables := db2.ListTables() + if len(tables) != 1 { + t.Errorf("Expected 1 table, got %d", len(tables)) + } + + table2, err := db2.GetTable("test") + if err != nil { + t.Fatal(err) + } + + stats := table2.Stats() + if stats.TotalRows != 0 { + t.Errorf("Expected 0 rows after reopen, got %d", stats.TotalRows) + } + + // 5. 验证可以插入新数据 + err = table2.Insert(map[string]any{"id": int64(100)}) + if err != nil { + t.Fatal(err) + } + + stats = table2.Stats() + if stats.TotalRows != 1 { + t.Errorf("Expected 1 row, got %d", stats.TotalRows) + } +} + +func TestDatabaseCleanTable(t *testing.T) { + dir := "./test_db_clean_table_data" + defer os.RemoveAll(dir) + + // 1. 创建数据库和表 + db, err := Open(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + schema := NewSchema("users", []Field{ + {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, + {Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"}, + }) + + table, err := db.CreateTable("users", schema) + if err != nil { + t.Fatal(err) + } + + // 2. 插入数据 + for i := range 50 { + table.Insert(map[string]any{ + "id": int64(i), + "name": "user", + }) + } + + // 3. 验证数据存在 + stats := table.Stats() + if stats.TotalRows == 0 { + t.Error("Expected data in table") + } + + // 4. 清除表数据 + err = db.CleanTable("users") + if err != nil { + t.Fatal(err) + } + + // 5. 验证数据已清除 + stats = table.Stats() + if stats.TotalRows != 0 { + t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows) + } + + // 6. 验证表仍然存在 + tables := db.ListTables() + found := slices.Contains(tables, "users") + if !found { + t.Error("Table should still exist after clean") + } + + // 7. 验证可以继续插入 + err = table.Insert(map[string]any{ + "id": int64(100), + "name": "new_user", + }) + if err != nil { + t.Fatal(err) + } + + stats = table.Stats() + if stats.TotalRows != 1 { + t.Errorf("Expected 1 row, got %d", stats.TotalRows) + } +} + +func TestDatabaseDestroyTable(t *testing.T) { + dir := "./test_db_destroy_table_data" + defer os.RemoveAll(dir) + + // 1. 创建数据库和表 + db, err := Open(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + schema := NewSchema("test", []Field{ + {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, + }) + + table, err := db.CreateTable("test", schema) + if err != nil { + t.Fatal(err) + } + + // 2. 插入数据 + for i := range 30 { + table.Insert(map[string]any{"id": int64(i)}) + } + + // 3. 验证数据存在 + stats := table.Stats() + if stats.TotalRows == 0 { + t.Error("Expected data in table") + } + + // 4. 销毁表 + err = db.DestroyTable("test") + if err != nil { + t.Fatal(err) + } + + // 5. 验证表已从 Database 中删除 + tables := db.ListTables() + for _, name := range tables { + if name == "test" { + t.Error("Table should be removed from database") + } + } + + // 6. 验证无法再获取该表 + _, err = db.GetTable("test") + if err == nil { + t.Error("Should not be able to get table after destroy") + } +} + +func TestDatabaseDestroyTableMultiple(t *testing.T) { + dir := "./test_db_destroy_multi_data" + defer os.RemoveAll(dir) + + // 1. 创建数据库和多个表 + db, err := Open(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + schema := NewSchema("test", []Field{ + {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, + }) + + // 创建 3 个表 + for i := 1; i <= 3; i++ { + tableName := "table" + string(rune('0'+i)) + _, err := db.CreateTable(tableName, schema) + if err != nil { + t.Fatal(err) + } + } + + // 2. 验证有 3 个表 + tables := db.ListTables() + if len(tables) != 3 { + t.Fatalf("Expected 3 tables, got %d", len(tables)) + } + + // 3. 销毁中间的表 + err = db.DestroyTable("table2") + if err != nil { + t.Fatal(err) + } + + // 4. 验证只剩 2 个表 + tables = db.ListTables() + if len(tables) != 2 { + t.Errorf("Expected 2 tables, got %d", len(tables)) + } + + // 5. 验证剩余的表是正确的 + hasTable1 := false + hasTable3 := false + for _, name := range tables { + if name == "table1" { + hasTable1 = true + } + if name == "table3" { + hasTable3 = true + } + if name == "table2" { + t.Error("table2 should be destroyed") + } + } + + if !hasTable1 || !hasTable3 { + t.Error("table1 and table3 should still exist") + } +} diff --git a/engine.go b/engine.go deleted file mode 100644 index 7239fc4..0000000 --- a/engine.go +++ /dev/null @@ -1,814 +0,0 @@ -package srdb - -import ( - "encoding/json" - "fmt" - "os" - "path/filepath" - "sort" - "sync" - "sync/atomic" - "time" -) - -const ( - DefaultMemTableSize = 64 * 1024 * 1024 // 64 MB - DefaultAutoFlushTimeout = 30 * time.Second // 30 秒无写入自动 flush -) - -// Engine 存储引擎 -type Engine struct { - dir string - schema *Schema - indexManager *IndexManager - walManager *WALManager // WAL 管理器 - sstManager *SSTableManager // SST 管理器 - memtableManager *MemTableManager // MemTable 管理器 - versionSet *VersionSet // MANIFEST 管理器 - compactionManager *CompactionManager // Compaction 管理器 - seq atomic.Int64 - flushMu sync.Mutex - - // 自动 flush 相关 - autoFlushTimeout time.Duration - lastWriteTime atomic.Int64 // 最后写入时间(UnixNano) - stopAutoFlush chan struct{} -} - -// EngineOptions 配置选项 -type EngineOptions struct { - Dir string - MemTableSize int64 - Schema *Schema // 可选的 Schema 定义 - AutoFlushTimeout time.Duration // 自动 flush 超时时间,0 表示禁用 -} - -// OpenEngine 打开数据库 -func OpenEngine(opts *EngineOptions) (*Engine, error) { - if opts.MemTableSize == 0 { - opts.MemTableSize = DefaultMemTableSize - } - - // 创建主目录 - err := os.MkdirAll(opts.Dir, 0755) - if err != nil { - return nil, err - } - - // 创建子目录 - walDir := filepath.Join(opts.Dir, "wal") - sstDir := filepath.Join(opts.Dir, "sst") - idxDir := filepath.Join(opts.Dir, "index") - - err = os.MkdirAll(walDir, 0755) - if err != nil { - return nil, err - } - err = os.MkdirAll(sstDir, 0755) - if err != nil { - return nil, err - } - err = os.MkdirAll(idxDir, 0755) - if err != nil { - return nil, err - } - - // 尝试从磁盘恢复 Schema(如果 Options 中没有提供) - var sch *Schema - if opts.Schema != nil { - // 使用提供的 Schema - sch = opts.Schema - // 保存到磁盘(带校验和) - schemaPath := filepath.Join(opts.Dir, "schema.json") - schemaFile, err := NewSchemaFile(sch) - if err != nil { - return nil, fmt.Errorf("create schema file: %w", err) - } - schemaData, err := json.MarshalIndent(schemaFile, "", " ") - if err != nil { - return nil, fmt.Errorf("marshal schema: %w", err) - } - err = os.WriteFile(schemaPath, schemaData, 0644) - if err != nil { - return nil, fmt.Errorf("write schema: %w", err) - } - } else { - // 尝试从磁盘恢复 - schemaPath := filepath.Join(opts.Dir, "schema.json") - schemaData, err := os.ReadFile(schemaPath) - if err == nil { - // 文件存在,尝试解析 - schemaFile := &SchemaFile{} - err = json.Unmarshal(schemaData, schemaFile) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal schema from %s: %w", schemaPath, err) - } - - // 验证校验和 - err = schemaFile.Verify() - if err != nil { - return nil, fmt.Errorf("failed to verify schema from %s: %w", schemaPath, err) - } - - sch = schemaFile.Schema - } else if !os.IsNotExist(err) { - // 其他读取错误 - return nil, fmt.Errorf("failed to read schema file %s: %w", schemaPath, err) - } - // 如果文件不存在,sch 保持为 nil(可选) - } - - // 创建索引管理器 - var indexMgr *IndexManager - if sch != nil { - indexMgr = NewIndexManager(idxDir, sch) - } - - // 创建 SST Manager - sstMgr, err := NewSSTableManager(sstDir) - if err != nil { - return nil, err - } - - // 设置 Schema(用于优化编解码) - if sch != nil { - sstMgr.SetSchema(sch) - } - - // 创建 MemTable Manager - memMgr := NewMemTableManager(opts.MemTableSize) - - // 创建/恢复 MANIFEST - manifestDir := opts.Dir - versionSet, err := NewVersionSet(manifestDir) - if err != nil { - return nil, fmt.Errorf("create version set: %w", err) - } - - // 创建 Engine(暂时不设置 WAL Manager) - engine := &Engine{ - dir: opts.Dir, - schema: sch, - indexManager: indexMgr, - walManager: nil, // 先不设置,恢复后再创建 - sstManager: sstMgr, - memtableManager: memMgr, - versionSet: versionSet, - } - - // 先恢复数据(包括从 WAL 恢复) - err = engine.recover() - if err != nil { - return nil, err - } - - // 恢复完成后,创建 WAL Manager 用于后续写入 - walMgr, err := NewWALManager(walDir) - if err != nil { - return nil, err - } - engine.walManager = walMgr - engine.memtableManager.SetActiveWAL(walMgr.GetCurrentNumber()) - - // 创建 Compaction Manager - engine.compactionManager = NewCompactionManager(sstDir, versionSet, sstMgr) - - // 设置 Schema(如果有) - if sch != nil { - engine.compactionManager.SetSchema(sch) - } - - // 启动时清理孤儿文件(崩溃恢复后的清理) - engine.compactionManager.CleanupOrphanFiles() - - // 启动后台 Compaction 和垃圾回收 - engine.compactionManager.Start() - - // 验证并修复索引 - if engine.indexManager != nil { - engine.verifyAndRepairIndexes() - } - - // 设置自动 flush 超时时间 - if opts.AutoFlushTimeout > 0 { - engine.autoFlushTimeout = opts.AutoFlushTimeout - } else { - engine.autoFlushTimeout = DefaultAutoFlushTimeout - } - engine.stopAutoFlush = make(chan struct{}) - engine.lastWriteTime.Store(time.Now().UnixNano()) - - // 启动自动 flush 监控 - go engine.autoFlushMonitor() - - return engine, nil -} - -// Insert 插入数据 -func (e *Engine) Insert(data map[string]any) error { - // 1. 验证 Schema (如果定义了) - if e.schema != nil { - if err := e.schema.Validate(data); err != nil { - return fmt.Errorf("schema validation failed: %v", err) - } - } - - // 2. 生成 _seq - seq := e.seq.Add(1) - - // 2. 添加系统字段 - row := &SSTableRow{ - Seq: seq, - Time: time.Now().UnixNano(), - Data: data, - } - - // 3. 序列化 - rowData, err := json.Marshal(row) - if err != nil { - return err - } - - // 4. 写入 WAL - entry := &WALEntry{ - Type: WALEntryTypePut, - Seq: seq, - Data: rowData, - } - err = e.walManager.Append(entry) - if err != nil { - return err - } - - // 5. 写入 MemTable Manager - e.memtableManager.Put(seq, rowData) - - // 6. 添加到索引 - if e.indexManager != nil { - e.indexManager.AddToIndexes(data, seq) - } - - // 7. 更新最后写入时间 - e.lastWriteTime.Store(time.Now().UnixNano()) - - // 8. 检查是否需要切换 MemTable - if e.memtableManager.ShouldSwitch() { - go e.switchMemTable() - } - - return nil -} - -// Get 查询数据 -func (e *Engine) Get(seq int64) (*SSTableRow, error) { - // 1. 先查 MemTable Manager (Active + Immutables) - data, found := e.memtableManager.Get(seq) - if found { - var row SSTableRow - err := json.Unmarshal(data, &row) - if err != nil { - return nil, err - } - return &row, nil - } - - // 2. 查询 SST 文件 - return e.sstManager.Get(seq) -} - -// GetPartial 按需查询数据(只读取指定字段) -func (e *Engine) GetPartial(seq int64, fields []string) (*SSTableRow, error) { - // 1. 先查 MemTable Manager (Active + Immutables) - data, found := e.memtableManager.Get(seq) - if found { - var row SSTableRow - err := json.Unmarshal(data, &row) - if err != nil { - return nil, err - } - - // MemTable 中的数据已经完全解析,需要手动过滤字段 - if len(fields) > 0 { - filteredData := make(map[string]any) - for _, field := range fields { - if val, ok := row.Data[field]; ok { - filteredData[field] = val - } - } - row.Data = filteredData - } - - return &row, nil - } - - // 2. 查询 SST 文件(按需解码) - return e.sstManager.GetPartial(seq, fields) -} - -// switchMemTable 切换 MemTable -func (e *Engine) switchMemTable() error { - e.flushMu.Lock() - defer e.flushMu.Unlock() - - // 1. 切换到新的 WAL - oldWALNumber, err := e.walManager.Rotate() - if err != nil { - return err - } - newWALNumber := e.walManager.GetCurrentNumber() - - // 2. 切换 MemTable (Active → Immutable) - _, immutable := e.memtableManager.Switch(newWALNumber) - - // 3. 异步 Flush Immutable - go e.flushImmutable(immutable, oldWALNumber) - - return nil -} - -// flushImmutable 将 Immutable MemTable 刷新到 SST -func (e *Engine) flushImmutable(imm *ImmutableMemTable, walNumber int64) error { - // 1. 收集所有行 - var rows []*SSTableRow - iter := imm.NewIterator() - for iter.Next() { - var row SSTableRow - err := json.Unmarshal(iter.Value(), &row) - if err == nil { - rows = append(rows, &row) - } - } - - if len(rows) == 0 { - // 没有数据,直接清理 - e.walManager.Delete(walNumber) - e.memtableManager.RemoveImmutable(imm) - return nil - } - - // 2. 从 VersionSet 分配文件编号 - fileNumber := e.versionSet.AllocateFileNumber() - - // 3. 创建 SST 文件到 L0 - reader, err := e.sstManager.CreateSST(fileNumber, rows) - if err != nil { - return err - } - - // 4. 创建 FileMetadata - header := reader.GetHeader() - - // 获取文件大小 - sstPath := reader.GetPath() - fileInfo, err := os.Stat(sstPath) - if err != nil { - return fmt.Errorf("stat sst file: %w", err) - } - - fileMeta := &FileMetadata{ - FileNumber: fileNumber, - Level: 0, // Flush 到 L0 - FileSize: fileInfo.Size(), - MinKey: header.MinKey, - MaxKey: header.MaxKey, - RowCount: header.RowCount, - } - - // 5. 更新 MANIFEST - edit := NewVersionEdit() - edit.AddFile(fileMeta) - - // 持久化当前的文件编号计数器(关键修复:防止重启后文件编号重用) - // 使用 fileNumber + 1 确保并发安全,避免竞态条件 - edit.SetNextFileNumber(fileNumber + 1) - - err = e.versionSet.LogAndApply(edit) - if err != nil { - return fmt.Errorf("log and apply version edit: %w", err) - } - - // 6. 删除对应的 WAL - e.walManager.Delete(walNumber) - - // 7. 从 Immutable 列表中移除 - e.memtableManager.RemoveImmutable(imm) - - // 8. 持久化索引(防止崩溃丢失索引数据) - if e.indexManager != nil { - e.indexManager.BuildAll() - } - - // 9. Compaction 由后台线程负责,不在 flush 路径中触发 - // 避免同步 compaction 导致刚创建的文件立即被删除 - // e.compactionManager.MaybeCompact() - - return nil -} - -// recover 恢复数据 -func (e *Engine) recover() error { - // 1. 恢复 SST 文件(SST Manager 已经在 NewManager 中恢复了) - // 只需要获取最大 seq - maxSeq := e.sstManager.GetMaxSeq() - if maxSeq > e.seq.Load() { - e.seq.Store(maxSeq) - } - - // 2. 恢复所有 WAL 文件到 MemTable Manager - walDir := filepath.Join(e.dir, "wal") - pattern := filepath.Join(walDir, "*.wal") - walFiles, err := filepath.Glob(pattern) - if err == nil && len(walFiles) > 0 { - // 按文件名排序 - sort.Strings(walFiles) - - // 依次读取每个 WAL - for _, walPath := range walFiles { - reader, err := NewWALReader(walPath) - if err != nil { - continue - } - - entries, err := reader.Read() - reader.Close() - - if err != nil { - continue - } - - // 重放 WAL 到 Active MemTable - for _, entry := range entries { - // 如果定义了 Schema,验证数据 - if e.schema != nil { - var row SSTableRow - if err := json.Unmarshal(entry.Data, &row); err != nil { - return fmt.Errorf("failed to unmarshal row during recovery (seq=%d): %w", entry.Seq, err) - } - - // 验证 Schema - if err := e.schema.Validate(row.Data); err != nil { - return fmt.Errorf("schema validation failed during recovery (seq=%d): %w", entry.Seq, err) - } - } - - e.memtableManager.Put(entry.Seq, entry.Data) - if entry.Seq > e.seq.Load() { - e.seq.Store(entry.Seq) - } - } - } - } - - return nil -} - -// autoFlushMonitor 自动 flush 监控 -func (e *Engine) autoFlushMonitor() { - ticker := time.NewTicker(e.autoFlushTimeout / 2) // 每半个超时时间检查一次 - defer ticker.Stop() - - for { - select { - case <-ticker.C: - // 检查是否超时 - lastWrite := time.Unix(0, e.lastWriteTime.Load()) - if time.Since(lastWrite) >= e.autoFlushTimeout { - // 检查 MemTable 是否有数据 - active := e.memtableManager.GetActive() - if active != nil && active.Size() > 0 { - // 触发 flush - e.Flush() - } - } - case <-e.stopAutoFlush: - return - } - } -} - -// Flush 手动刷新 Active MemTable 到磁盘 -func (e *Engine) Flush() error { - // 检查 Active MemTable 是否有数据 - active := e.memtableManager.GetActive() - if active == nil || active.Size() == 0 { - return nil // 没有数据,无需 flush - } - - // 强制切换 MemTable(switchMemTable 内部有锁) - return e.switchMemTable() -} - -// Close 关闭引擎 -func (e *Engine) Close() error { - // 1. 停止自动 flush 监控(如果还在运行) - if e.stopAutoFlush != nil { - select { - case <-e.stopAutoFlush: - // 已经关闭,跳过 - default: - close(e.stopAutoFlush) - } - } - - // 2. 停止 Compaction Manager - if e.compactionManager != nil { - e.compactionManager.Stop() - } - - // 3. 刷新 Active MemTable(确保所有数据都写入磁盘) - // 检查 memtableManager 是否存在(可能已被 Destroy) - if e.memtableManager != nil { - e.Flush() - } - - // 3. 关闭 WAL Manager - if e.walManager != nil { - e.walManager.Close() - } - - // 4. 等待所有 Immutable Flush 完成 - // TODO: 添加更优雅的等待机制 - if e.memtableManager != nil { - for e.memtableManager.GetImmutableCount() > 0 { - time.Sleep(100 * time.Millisecond) - } - } - - // 5. 保存所有索引 - if e.indexManager != nil { - e.indexManager.BuildAll() - e.indexManager.Close() - } - - // 6. 关闭 VersionSet - if e.versionSet != nil { - e.versionSet.Close() - } - - // 7. 关闭 WAL Manager - if e.walManager != nil { - e.walManager.Close() - } - - // 6. 关闭 SST Manager - if e.sstManager != nil { - e.sstManager.Close() - } - - return nil -} - -// Clean 清除所有数据(保留 Engine 可用) -func (e *Engine) Clean() error { - e.flushMu.Lock() - defer e.flushMu.Unlock() - - // 0. 停止自动 flush 监控(临时) - if e.stopAutoFlush != nil { - close(e.stopAutoFlush) - } - - // 1. 停止 Compaction Manager - if e.compactionManager != nil { - e.compactionManager.Stop() - } - - // 2. 等待所有 Immutable Flush 完成 - for e.memtableManager.GetImmutableCount() > 0 { - time.Sleep(100 * time.Millisecond) - } - - // 3. 清空 MemTable - e.memtableManager = NewMemTableManager(DefaultMemTableSize) - - // 2. 删除所有 WAL 文件 - if e.walManager != nil { - e.walManager.Close() - walDir := filepath.Join(e.dir, "wal") - os.RemoveAll(walDir) - os.MkdirAll(walDir, 0755) - - // 重新创建 WAL Manager - walMgr, err := NewWALManager(walDir) - if err != nil { - return fmt.Errorf("recreate wal manager: %w", err) - } - e.walManager = walMgr - e.memtableManager.SetActiveWAL(walMgr.GetCurrentNumber()) - } - - // 3. 删除所有 SST 文件 - if e.sstManager != nil { - e.sstManager.Close() - sstDir := filepath.Join(e.dir, "sst") - os.RemoveAll(sstDir) - os.MkdirAll(sstDir, 0755) - - // 重新创建 SST Manager - sstMgr, err := NewSSTableManager(sstDir) - if err != nil { - return fmt.Errorf("recreate sst manager: %w", err) - } - e.sstManager = sstMgr - } - - // 4. 删除所有索引文件 - if e.indexManager != nil { - e.indexManager.Close() - indexFiles, _ := filepath.Glob(filepath.Join(e.dir, "idx_*.sst")) - for _, f := range indexFiles { - os.Remove(f) - } - - // 重新创建 Index Manager - if e.schema != nil { - e.indexManager = NewIndexManager(e.dir, e.schema) - } - } - - // 5. 重置 MANIFEST - if e.versionSet != nil { - e.versionSet.Close() - manifestDir := e.dir - os.Remove(filepath.Join(manifestDir, "MANIFEST")) - os.Remove(filepath.Join(manifestDir, "CURRENT")) - - // 重新创建 VersionSet - versionSet, err := NewVersionSet(manifestDir) - if err != nil { - return fmt.Errorf("recreate version set: %w", err) - } - e.versionSet = versionSet - } - - // 6. 重新创建 Compaction Manager - sstDir := filepath.Join(e.dir, "sst") - e.compactionManager = NewCompactionManager(sstDir, e.versionSet, e.sstManager) - if e.schema != nil { - e.compactionManager.SetSchema(e.schema) - } - e.compactionManager.Start() - - // 7. 重置序列号 - e.seq.Store(0) - - // 8. 更新最后写入时间 - e.lastWriteTime.Store(time.Now().UnixNano()) - - // 9. 重启自动 flush 监控 - e.stopAutoFlush = make(chan struct{}) - go e.autoFlushMonitor() - - return nil -} - -// Destroy 销毁 Engine 并删除所有数据文件 -func (e *Engine) Destroy() error { - // 1. 先关闭 Engine - if err := e.Close(); err != nil { - return fmt.Errorf("close engine: %w", err) - } - - // 2. 删除整个数据目录 - if err := os.RemoveAll(e.dir); err != nil { - return fmt.Errorf("remove data directory: %w", err) - } - - // 3. 标记 Engine 为不可用(将所有管理器设为 nil) - e.walManager = nil - e.sstManager = nil - e.memtableManager = nil - e.versionSet = nil - e.compactionManager = nil - e.indexManager = nil - - return nil -} - -// TableStats 统计信息 -type TableStats struct { - MemTableSize int64 - MemTableCount int - SSTCount int - TotalRows int64 -} - -// GetVersionSet 获取 VersionSet(用于高级操作) -func (e *Engine) GetVersionSet() *VersionSet { - return e.versionSet -} - -// GetCompactionManager 获取 Compaction Manager(用于高级操作) -func (e *Engine) GetCompactionManager() *CompactionManager { - return e.compactionManager -} - -// GetMemtableManager 获取 Memtable Manager -func (e *Engine) GetMemtableManager() *MemTableManager { - return e.memtableManager -} - -// GetSSTManager 获取 SST Manager -func (e *Engine) GetSSTManager() *SSTableManager { - return e.sstManager -} - -// GetMaxSeq 获取当前最大的 seq 号 -func (e *Engine) GetMaxSeq() int64 { - return e.seq.Load() - 1 // seq 是下一个要分配的,所以最大的是 seq - 1 -} - -// GetSchema 获取 Schema -func (e *Engine) GetSchema() *Schema { - return e.schema -} - -// Stats 获取统计信息 -func (e *Engine) Stats() *TableStats { - memStats := e.memtableManager.GetStats() - sstStats := e.sstManager.GetStats() - - stats := &TableStats{ - MemTableSize: memStats.TotalSize, - MemTableCount: memStats.TotalCount, - SSTCount: sstStats.FileCount, - } - - // 计算总行数 - stats.TotalRows = int64(memStats.TotalCount) - readers := e.sstManager.GetReaders() - for _, reader := range readers { - header := reader.GetHeader() - stats.TotalRows += header.RowCount - } - - return stats -} - -// CreateIndex 创建索引 -func (e *Engine) CreateIndex(field string) error { - if e.indexManager == nil { - return fmt.Errorf("no schema defined, cannot create index") - } - - return e.indexManager.CreateIndex(field) -} - -// DropIndex 删除索引 -func (e *Engine) DropIndex(field string) error { - if e.indexManager == nil { - return fmt.Errorf("no schema defined, cannot drop index") - } - - return e.indexManager.DropIndex(field) -} - -// ListIndexes 列出所有索引 -func (e *Engine) ListIndexes() []string { - if e.indexManager == nil { - return nil - } - - return e.indexManager.ListIndexes() -} - -// GetIndexMetadata 获取索引元数据 -func (e *Engine) GetIndexMetadata() map[string]IndexMetadata { - if e.indexManager == nil { - return nil - } - - return e.indexManager.GetIndexMetadata() -} - -// RepairIndexes 手动修复索引 -func (e *Engine) RepairIndexes() error { - return e.verifyAndRepairIndexes() -} - -// Query 创建查询构建器 -func (e *Engine) Query() *QueryBuilder { - return newQueryBuilder(e) -} - -// verifyAndRepairIndexes 验证并修复索引 -func (e *Engine) verifyAndRepairIndexes() error { - if e.indexManager == nil { - return nil - } - - // 获取当前最大 seq - currentMaxSeq := e.seq.Load() - - // 创建 getData 函数 - getData := func(seq int64) (map[string]any, error) { - row, err := e.Get(seq) - if err != nil { - return nil, err - } - return row.Data, nil - } - - // 验证并修复 - return e.indexManager.VerifyAndRepair(currentMaxSeq, getData) -} diff --git a/engine_clean_test.go b/engine_clean_test.go deleted file mode 100644 index ff51d69..0000000 --- a/engine_clean_test.go +++ /dev/null @@ -1,244 +0,0 @@ -package srdb - -import ( - "os" - "testing" - "time" -) - -func TestEngineClean(t *testing.T) { - dir := "./test_clean_data" - defer os.RemoveAll(dir) - - // 1. 创建 Engine 并插入数据 - engine, err := OpenEngine(&EngineOptions{ - Dir: dir, - }) - if err != nil { - t.Fatal(err) - } - - // 插入一些数据 - for i := 0; i < 100; i++ { - err := engine.Insert(map[string]any{ - "id": i, - "name": "test", - }) - if err != nil { - t.Fatal(err) - } - } - - // 强制 flush - engine.Flush() - time.Sleep(500 * time.Millisecond) - - // 验证数据存在 - stats := engine.Stats() - t.Logf("Before Clean: MemTable=%d, SST=%d, Total=%d", - stats.MemTableCount, stats.SSTCount, stats.TotalRows) - - if stats.TotalRows == 0 { - t.Errorf("Expected some rows, got 0") - } - - // 2. 清除数据 - err = engine.Clean() - if err != nil { - t.Fatal(err) - } - - // 3. 验证数据已清除 - stats = engine.Stats() - t.Logf("After Clean: MemTable=%d, SST=%d, Total=%d", - stats.MemTableCount, stats.SSTCount, stats.TotalRows) - - if stats.TotalRows != 0 { - t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows) - } - - // 4. 验证 Engine 仍然可用 - err = engine.Insert(map[string]any{ - "id": 1, - "name": "after_clean", - }) - if err != nil { - t.Fatal(err) - } - - stats = engine.Stats() - if stats.TotalRows != 1 { - t.Errorf("Expected 1 row after insert, got %d", stats.TotalRows) - } - - engine.Close() -} - -func TestEngineDestroy(t *testing.T) { - dir := "./test_destroy_data" - defer os.RemoveAll(dir) - - // 1. 创建 Engine 并插入数据 - engine, err := OpenEngine(&EngineOptions{ - Dir: dir, - }) - if err != nil { - t.Fatal(err) - } - - // 插入一些数据 - for i := 0; i < 50; i++ { - err := engine.Insert(map[string]any{ - "id": i, - "name": "test", - }) - if err != nil { - t.Fatal(err) - } - } - - // 验证数据存在 - stats := engine.Stats() - t.Logf("Before Destroy: MemTable=%d, SST=%d, Total=%d", - stats.MemTableCount, stats.SSTCount, stats.TotalRows) - - // 2. 销毁 Engine - err = engine.Destroy() - if err != nil { - t.Fatal(err) - } - - // 3. 验证数据目录已删除 - if _, err := os.Stat(dir); !os.IsNotExist(err) { - t.Errorf("Data directory should be deleted") - } - - // 4. 验证 Engine 不可用(尝试插入会失败) - err = engine.Insert(map[string]any{ - "id": 1, - "name": "after_destroy", - }) - if err == nil { - t.Errorf("Insert should fail after destroy") - } -} - -func TestEngineCleanWithSchema(t *testing.T) { - dir := "./test_clean_schema_data" - defer os.RemoveAll(dir) - - // 定义 Schema - schema := NewSchema("test", []Field{ - {Name: "id", Type: FieldTypeInt64, Indexed: true, Comment: "ID"}, - {Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"}, - }) - - // 1. 创建 Engine 并插入数据 - engine, err := OpenEngine(&EngineOptions{ - Dir: dir, - Schema: schema, - }) - if err != nil { - t.Fatal(err) - } - - // 创建索引 - err = engine.CreateIndex("id") - if err != nil { - t.Fatal(err) - } - - // 插入数据 - for i := 0; i < 50; i++ { - err := engine.Insert(map[string]any{ - "id": int64(i), - "name": "test", - }) - if err != nil { - t.Fatal(err) - } - } - - // 验证索引存在 - indexes := engine.ListIndexes() - if len(indexes) != 1 { - t.Errorf("Expected 1 index, got %d", len(indexes)) - } - - // 2. 清除数据 - err = engine.Clean() - if err != nil { - t.Fatal(err) - } - - // 3. 验证数据已清除但 Schema 和索引结构保留 - stats := engine.Stats() - if stats.TotalRows != 0 { - t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows) - } - - // 验证可以继续插入(Schema 仍然有效) - err = engine.Insert(map[string]any{ - "id": int64(100), - "name": "after_clean", - }) - if err != nil { - t.Fatal(err) - } - - engine.Close() -} - -func TestEngineCleanAndReopen(t *testing.T) { - dir := "./test_clean_reopen_data" - defer os.RemoveAll(dir) - - // 1. 创建 Engine 并插入数据 - engine, err := OpenEngine(&EngineOptions{ - Dir: dir, - }) - if err != nil { - t.Fatal(err) - } - - for i := 0; i < 100; i++ { - engine.Insert(map[string]any{ - "id": i, - "name": "test", - }) - } - - // 2. 清除数据 - engine.Clean() - - // 3. 关闭并重新打开 - engine.Close() - - engine2, err := OpenEngine(&EngineOptions{ - Dir: dir, - }) - if err != nil { - t.Fatal(err) - } - defer engine2.Close() - - // 4. 验证数据为空 - stats := engine2.Stats() - if stats.TotalRows != 0 { - t.Errorf("Expected 0 rows after reopen, got %d", stats.TotalRows) - } - - // 5. 验证可以插入新数据 - err = engine2.Insert(map[string]any{ - "id": 1, - "name": "new_data", - }) - if err != nil { - t.Fatal(err) - } - - stats = engine2.Stats() - if stats.TotalRows != 1 { - t.Errorf("Expected 1 row, got %d", stats.TotalRows) - } -} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..750a3ce --- /dev/null +++ b/errors.go @@ -0,0 +1,336 @@ +package srdb + +import ( + "errors" + "fmt" +) + +// ErrCode 错误码类型 +type ErrCode int + +// 错误码定义 +const ( + // 通用错误 (1000-1999) + ErrCodeNotFound ErrCode = 1000 // 数据未找到 + ErrCodeClosed ErrCode = 1001 // 对象已关闭 + ErrCodeInvalidData ErrCode = 1002 // 无效数据 + ErrCodeCorrupted ErrCode = 1003 // 数据损坏 + ErrCodeExists ErrCode = 1004 // 对象已存在 + ErrCodeInvalidParam ErrCode = 1005 // 无效参数 + + // 数据库错误 (2000-2999) + ErrCodeDatabaseNotFound ErrCode = 2000 // 数据库不存在 + ErrCodeDatabaseExists ErrCode = 2001 // 数据库已存在 + ErrCodeDatabaseClosed ErrCode = 2002 // 数据库已关闭 + + // 表错误 (3000-3999) + ErrCodeTableNotFound ErrCode = 3000 // 表不存在 + ErrCodeTableExists ErrCode = 3001 // 表已存在 + ErrCodeTableClosed ErrCode = 3002 // 表已关闭 + + // Schema 错误 (4000-4999) + ErrCodeSchemaNotFound ErrCode = 4000 // Schema 不存在 + ErrCodeSchemaInvalid ErrCode = 4001 // Schema 无效 + ErrCodeSchemaMismatch ErrCode = 4002 // Schema 不匹配 + ErrCodeSchemaValidationFailed ErrCode = 4003 // Schema 验证失败 + ErrCodeSchemaChecksumMismatch ErrCode = 4004 // Schema 校验和不匹配 + + // 字段错误 (5000-5999) + ErrCodeFieldNotFound ErrCode = 5000 // 字段不存在 + ErrCodeFieldTypeMismatch ErrCode = 5001 // 字段类型不匹配 + ErrCodeFieldRequired ErrCode = 5002 // 必填字段缺失 + + // 索引错误 (6000-6999) + ErrCodeIndexNotFound ErrCode = 6000 // 索引不存在 + ErrCodeIndexExists ErrCode = 6001 // 索引已存在 + ErrCodeIndexNotReady ErrCode = 6002 // 索引未就绪 + ErrCodeIndexCorrupted ErrCode = 6003 // 索引损坏 + + // 文件格式错误 (7000-7999) + ErrCodeInvalidFormat ErrCode = 7000 // 无效的文件格式 + ErrCodeUnsupportedVersion ErrCode = 7001 // 不支持的版本 + ErrCodeInvalidMagicNumber ErrCode = 7002 // 无效的魔数 + ErrCodeChecksumMismatch ErrCode = 7003 // 校验和不匹配 + + // WAL 错误 (8000-8999) + ErrCodeWALCorrupted ErrCode = 8000 // WAL 文件损坏 + ErrCodeWALClosed ErrCode = 8001 // WAL 已关闭 + + // SSTable 错误 (9000-9999) + ErrCodeSSTableNotFound ErrCode = 9000 // SSTable 文件不存在 + ErrCodeSSTableCorrupted ErrCode = 9001 // SSTable 文件损坏 + + // Compaction 错误 (10000-10999) + ErrCodeCompactionInProgress ErrCode = 10000 // Compaction 正在进行 + ErrCodeNoCompactionNeeded ErrCode = 10001 // 不需要 Compaction + + // 编解码错误 (11000-11999) + ErrCodeEncodeFailed ErrCode = 11000 // 编码失败 + ErrCodeDecodeFailed ErrCode = 11001 // 解码失败 +) + +// 错误码消息映射 +var errCodeMessages = map[ErrCode]string{ + // 通用错误 + ErrCodeNotFound: "not found", + ErrCodeClosed: "already closed", + ErrCodeInvalidData: "invalid data", + ErrCodeCorrupted: "data corrupted", + ErrCodeExists: "already exists", + ErrCodeInvalidParam: "invalid parameter", + + // 数据库错误 + ErrCodeDatabaseNotFound: "database not found", + ErrCodeDatabaseExists: "database already exists", + ErrCodeDatabaseClosed: "database closed", + + // 表错误 + ErrCodeTableNotFound: "table not found", + ErrCodeTableExists: "table already exists", + ErrCodeTableClosed: "table closed", + + // Schema 错误 + ErrCodeSchemaNotFound: "schema not found", + ErrCodeSchemaInvalid: "schema invalid", + ErrCodeSchemaMismatch: "schema mismatch", + ErrCodeSchemaValidationFailed: "schema validation failed", + ErrCodeSchemaChecksumMismatch: "schema checksum mismatch", + + // 字段错误 + ErrCodeFieldNotFound: "field not found", + ErrCodeFieldTypeMismatch: "field type mismatch", + ErrCodeFieldRequired: "required field missing", + + // 索引错误 + ErrCodeIndexNotFound: "index not found", + ErrCodeIndexExists: "index already exists", + ErrCodeIndexNotReady: "index not ready", + ErrCodeIndexCorrupted: "index corrupted", + + // 文件格式错误 + ErrCodeInvalidFormat: "invalid file format", + ErrCodeUnsupportedVersion: "unsupported version", + ErrCodeInvalidMagicNumber: "invalid magic number", + ErrCodeChecksumMismatch: "checksum mismatch", + + // WAL 错误 + ErrCodeWALCorrupted: "wal corrupted", + ErrCodeWALClosed: "wal closed", + + // SSTable 错误 + ErrCodeSSTableNotFound: "sstable not found", + ErrCodeSSTableCorrupted: "sstable corrupted", + + // Compaction 错误 + ErrCodeCompactionInProgress: "compaction in progress", + ErrCodeNoCompactionNeeded: "no compaction needed", + + // 编解码错误 + ErrCodeEncodeFailed: "encode failed", + ErrCodeDecodeFailed: "decode failed", +} + +// Error 错误类型 +type Error struct { + Code ErrCode // 错误码 + Message string // 错误消息 + Cause error // 原始错误 +} + +// Error 实现 error 接口 +func (e *Error) Error() string { + if e.Cause != nil { + return fmt.Sprintf("[%d] %s: %v", e.Code, e.Message, e.Cause) + } + return fmt.Sprintf("[%d] %s", e.Code, e.Message) +} + +// Unwrap 支持 errors.Is 和 errors.As +func (e *Error) Unwrap() error { + return e.Cause +} + +// Is 判断错误码是否相同 +func (e *Error) Is(target error) bool { + t, ok := target.(*Error) + if !ok { + return false + } + return e.Code == t.Code +} + +// NewError 创建新错误 +func NewError(code ErrCode, cause error) *Error { + msg, ok := errCodeMessages[code] + if !ok { + msg = "unknown error" + } + return &Error{ + Code: code, + Message: msg, + Cause: cause, + } +} + +// NewErrorf 创建带格式化消息的错误 +// 注意:如果 args 中最后一个参数是 error 类型,它会被设置为 Cause +func NewErrorf(code ErrCode, format string, args ...any) *Error { + var cause error + + // 检查最后一个参数是否为 error + if len(args) > 0 { + if err, ok := args[len(args)-1].(error); ok { + cause = err + // 从 args 中移除最后一个 error 参数 + args = args[:len(args)-1] + } + } + + return &Error{ + Code: code, + Message: fmt.Sprintf(format, args...), + Cause: cause, + } +} + +// 预定义的常用错误(向后兼容) +var ( + // ErrNotFound 数据未找到 + ErrNotFound = NewError(ErrCodeNotFound, nil) + + // ErrClosed 对象已关闭 + ErrClosed = NewError(ErrCodeClosed, nil) + + // ErrInvalidData 无效数据 + ErrInvalidData = NewError(ErrCodeInvalidData, nil) + + // ErrCorrupted 数据损坏 + ErrCorrupted = NewError(ErrCodeCorrupted, nil) +) + +// 数据库错误(向后兼容) +var ( + ErrDatabaseNotFound = NewError(ErrCodeDatabaseNotFound, nil) + ErrDatabaseExists = NewError(ErrCodeDatabaseExists, nil) + ErrDatabaseClosed = NewError(ErrCodeDatabaseClosed, nil) +) + +// 表错误(向后兼容) +var ( + ErrTableNotFound = NewError(ErrCodeTableNotFound, nil) + ErrTableExists = NewError(ErrCodeTableExists, nil) + ErrTableClosed = NewError(ErrCodeTableClosed, nil) +) + +// Schema 错误(向后兼容) +var ( + ErrSchemaNotFound = NewError(ErrCodeSchemaNotFound, nil) + ErrSchemaInvalid = NewError(ErrCodeSchemaInvalid, nil) + ErrSchemaMismatch = NewError(ErrCodeSchemaMismatch, nil) + ErrSchemaValidationFailed = NewError(ErrCodeSchemaValidationFailed, nil) + ErrSchemaChecksumMismatch = NewError(ErrCodeSchemaChecksumMismatch, nil) +) + +// 字段错误(向后兼容) +var ( + ErrFieldNotFound = NewError(ErrCodeFieldNotFound, nil) + ErrFieldTypeMismatch = NewError(ErrCodeFieldTypeMismatch, nil) + ErrFieldRequired = NewError(ErrCodeFieldRequired, nil) +) + +// 索引错误(向后兼容) +var ( + ErrIndexNotFound = NewError(ErrCodeIndexNotFound, nil) + ErrIndexExists = NewError(ErrCodeIndexExists, nil) + ErrIndexNotReady = NewError(ErrCodeIndexNotReady, nil) + ErrIndexCorrupted = NewError(ErrCodeIndexCorrupted, nil) +) + +// 文件格式错误(向后兼容) +var ( + ErrInvalidFormat = NewError(ErrCodeInvalidFormat, nil) + ErrUnsupportedVersion = NewError(ErrCodeUnsupportedVersion, nil) + ErrInvalidMagicNumber = NewError(ErrCodeInvalidMagicNumber, nil) + ErrChecksumMismatch = NewError(ErrCodeChecksumMismatch, nil) +) + +// WAL 错误(向后兼容) +var ( + ErrWALCorrupted = NewError(ErrCodeWALCorrupted, nil) + ErrWALClosed = NewError(ErrCodeWALClosed, nil) +) + +// SSTable 错误(向后兼容) +var ( + ErrSSTableNotFound = NewError(ErrCodeSSTableNotFound, nil) + ErrSSTableCorrupted = NewError(ErrCodeSSTableCorrupted, nil) +) + +// Compaction 错误(向后兼容) +var ( + ErrCompactionInProgress = NewError(ErrCodeCompactionInProgress, nil) + ErrNoCompactionNeeded = NewError(ErrCodeNoCompactionNeeded, nil) +) + +// 编解码错误(向后兼容) +var ( + ErrEncodeFailed = NewError(ErrCodeEncodeFailed, nil) + ErrDecodeFailed = NewError(ErrCodeDecodeFailed, nil) +) + +// 辅助函数 + +// GetErrorCode 获取错误码 +func GetErrorCode(err error) ErrCode { + var e *Error + if errors.As(err, &e) { + return e.Code + } + return 0 +} + +// IsError 判断错误是否匹配指定的错误码 +func IsError(err error, code ErrCode) bool { + return GetErrorCode(err) == code +} + +// IsNotFound 判断是否是 NotFound 错误 +func IsNotFound(err error) bool { + code := GetErrorCode(err) + return code == ErrCodeNotFound || + code == ErrCodeTableNotFound || + code == ErrCodeDatabaseNotFound || + code == ErrCodeIndexNotFound || + code == ErrCodeFieldNotFound || + code == ErrCodeSchemaNotFound || + code == ErrCodeSSTableNotFound +} + +// IsCorrupted 判断是否是数据损坏错误 +func IsCorrupted(err error) bool { + code := GetErrorCode(err) + return code == ErrCodeCorrupted || + code == ErrCodeWALCorrupted || + code == ErrCodeSSTableCorrupted || + code == ErrCodeIndexCorrupted || + code == ErrCodeChecksumMismatch || + code == ErrCodeSchemaChecksumMismatch +} + +// IsClosed 判断是否是已关闭错误 +func IsClosed(err error) bool { + code := GetErrorCode(err) + return code == ErrCodeClosed || + code == ErrCodeDatabaseClosed || + code == ErrCodeTableClosed || + code == ErrCodeWALClosed +} + +// WrapError 包装错误并添加上下文 +func WrapError(err error, format string, args ...any) error { + if err == nil { + return nil + } + msg := fmt.Sprintf(format, args...) + return fmt.Errorf("%s: %w", msg, err) +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..f04e832 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,185 @@ +package srdb + +import ( + "errors" + "fmt" + "testing" +) + +func TestIsNotFound(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + {"ErrNotFound", ErrNotFound, true}, + {"ErrTableNotFound", ErrTableNotFound, true}, + {"ErrDatabaseNotFound", ErrDatabaseNotFound, true}, + {"ErrIndexNotFound", ErrIndexNotFound, true}, + {"ErrFieldNotFound", ErrFieldNotFound, true}, + {"ErrSchemaNotFound", ErrSchemaNotFound, true}, + {"ErrSSTableNotFound", ErrSSTableNotFound, true}, + {"ErrTableExists", ErrTableExists, false}, + {"ErrCorrupted", ErrCorrupted, false}, + {"nil", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsNotFound(tt.err) + if result != tt.expected { + t.Errorf("IsNotFound(%v) = %v, want %v", tt.err, result, tt.expected) + } + }) + } +} + +func TestIsCorrupted(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + {"ErrCorrupted", ErrCorrupted, true}, + {"ErrWALCorrupted", ErrWALCorrupted, true}, + {"ErrSSTableCorrupted", ErrSSTableCorrupted, true}, + {"ErrIndexCorrupted", ErrIndexCorrupted, true}, + {"ErrChecksumMismatch", ErrChecksumMismatch, true}, + {"ErrSchemaChecksumMismatch", ErrSchemaChecksumMismatch, true}, + {"ErrNotFound", ErrNotFound, false}, + {"nil", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsCorrupted(tt.err) + if result != tt.expected { + t.Errorf("IsCorrupted(%v) = %v, want %v", tt.err, result, tt.expected) + } + }) + } +} + +func TestIsClosed(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + {"ErrClosed", ErrClosed, true}, + {"ErrDatabaseClosed", ErrDatabaseClosed, true}, + {"ErrTableClosed", ErrTableClosed, true}, + {"ErrWALClosed", ErrWALClosed, true}, + {"ErrNotFound", ErrNotFound, false}, + {"nil", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsClosed(tt.err) + if result != tt.expected { + t.Errorf("IsClosed(%v) = %v, want %v", tt.err, result, tt.expected) + } + }) + } +} + +func TestWrapError(t *testing.T) { + baseErr := ErrNotFound + + wrapped := WrapError(baseErr, "failed to get table %s", "users") + if wrapped == nil { + t.Fatal("WrapError returned nil") + } + + // 验证包装后的错误仍然是原始错误 + if !errors.Is(wrapped, ErrNotFound) { + t.Error("Wrapped error should be ErrNotFound") + } + + // 验证错误消息包含上下文 + errMsg := wrapped.Error() + if errMsg != "failed to get table users: [1000] not found" { + t.Errorf("Expected error message to contain context, got: %s", errMsg) + } +} + +func TestWrapErrorNil(t *testing.T) { + wrapped := WrapError(nil, "some context") + if wrapped != nil { + t.Error("WrapError(nil) should return nil") + } +} + +func TestNewError(t *testing.T) { + err := NewErrorf(ErrCodeInvalidData, "custom error: %s", "test") + if err == nil { + t.Fatal("NewErrorf returned nil") + } + + // 验证错误码 + if err.Code != ErrCodeInvalidData { + t.Errorf("Expected code %d, got %d", ErrCodeInvalidData, err.Code) + } + + // 验证错误消息 + expected := "custom error: test" + if err.Message != expected { + t.Errorf("Expected message %q, got %q", expected, err.Message) + } +} + +func TestGetErrorCode(t *testing.T) { + err := NewError(ErrCodeTableNotFound, nil) + code := GetErrorCode(err) + if code != ErrCodeTableNotFound { + t.Errorf("Expected code %d, got %d", ErrCodeTableNotFound, code) + } + + // 测试非 Error 类型 + stdErr := fmt.Errorf("standard error") + code = GetErrorCode(stdErr) + if code != 0 { + t.Errorf("Expected code 0 for standard error, got %d", code) + } +} + +func TestIsError(t *testing.T) { + tests := []struct { + name string + err error + code ErrCode + expected bool + }{ + {"exact match", NewError(ErrCodeTableNotFound, nil), ErrCodeTableNotFound, true}, + {"no match", NewError(ErrCodeTableNotFound, nil), ErrCodeDatabaseNotFound, false}, + {"wrapped error", WrapError(ErrTableNotFound, "context"), ErrCodeTableNotFound, true}, + {"standard error", fmt.Errorf("standard error"), ErrCodeTableNotFound, false}, + {"nil error", nil, ErrCodeTableNotFound, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsError(tt.err, tt.code) + if result != tt.expected { + t.Errorf("IsError(%v, %d) = %v, want %v", tt.err, tt.code, result, tt.expected) + } + }) + } +} + +func TestErrorWrapping(t *testing.T) { + // 测试多层包装 + err1 := ErrTableNotFound + err2 := WrapError(err1, "database %s", "mydb") + err3 := WrapError(err2, "operation failed") + + // 验证仍然能识别原始错误 + if !IsNotFound(err3) { + t.Error("Should recognize wrapped error as NotFound") + } + + if !errors.Is(err3, ErrTableNotFound) { + t.Error("Should be able to unwrap to ErrTableNotFound") + } +} diff --git a/examples/README.md b/examples/README.md index 24803d8..b87ba92 100644 --- a/examples/README.md +++ b/examples/README.md @@ -327,8 +327,8 @@ package main import ( "net/http" - "code.tczkiot.com/srdb" - "code.tczkiot.com/srdb/webui" + "code.tczkiot.com/wlw/srdb" + "code.tczkiot.com/wlw/srdb/webui" ) func main() { @@ -362,7 +362,7 @@ http.ListenAndServe(":8080", mux) 将 webui 工具的命令集成到你的应用: ```go -import "code.tczkiot.com/srdb/examples/webui/commands" +import "code.tczkiot.com/wlw/srdb/examples/webui/commands" // 检查数据 commands.CheckData("./mydb") diff --git a/examples/webui/commands/check_data.go b/examples/webui/commands/check_data.go index 250966e..d012059 100644 --- a/examples/webui/commands/check_data.go +++ b/examples/webui/commands/check_data.go @@ -4,7 +4,7 @@ import ( "fmt" "log" - "code.tczkiot.com/srdb" + "code.tczkiot.com/wlw/srdb" ) // CheckData 检查数据库中的数据 diff --git a/examples/webui/commands/check_seq.go b/examples/webui/commands/check_seq.go index 9f07155..40236ec 100644 --- a/examples/webui/commands/check_seq.go +++ b/examples/webui/commands/check_seq.go @@ -4,7 +4,7 @@ import ( "fmt" "log" - "code.tczkiot.com/srdb" + "code.tczkiot.com/wlw/srdb" ) // CheckSeq 检查特定序列号的数据 diff --git a/examples/webui/commands/dump_manifest.go b/examples/webui/commands/dump_manifest.go index 8e3ebe1..63344d6 100644 --- a/examples/webui/commands/dump_manifest.go +++ b/examples/webui/commands/dump_manifest.go @@ -4,7 +4,7 @@ import ( "fmt" "log" - "code.tczkiot.com/srdb" + "code.tczkiot.com/wlw/srdb" ) // DumpManifest 导出 manifest 信息 @@ -20,12 +20,11 @@ func DumpManifest(dbPath string) { log.Fatal(err) } - engine := table.GetEngine() - versionSet := engine.GetVersionSet() + versionSet := table.GetVersionSet() version := versionSet.GetCurrent() // Check for duplicates in each level - for level := 0; level < 7; level++ { + for level := range 7 { files := version.GetLevel(level) if len(files) == 0 { continue diff --git a/examples/webui/commands/inspect_all_sst.go b/examples/webui/commands/inspect_all_sst.go index 1e14902..eb3714b 100644 --- a/examples/webui/commands/inspect_all_sst.go +++ b/examples/webui/commands/inspect_all_sst.go @@ -9,7 +9,7 @@ import ( "strconv" "strings" - "code.tczkiot.com/srdb" + "code.tczkiot.com/wlw/srdb" ) // InspectAllSST 检查所有 SST 文件 diff --git a/examples/webui/commands/inspect_sst.go b/examples/webui/commands/inspect_sst.go index 263c59d..5faadab 100644 --- a/examples/webui/commands/inspect_sst.go +++ b/examples/webui/commands/inspect_sst.go @@ -5,7 +5,7 @@ import ( "log" "os" - "code.tczkiot.com/srdb" + "code.tczkiot.com/wlw/srdb" ) // InspectSST 检查特定 SST 文件 diff --git a/examples/webui/commands/test_fix.go b/examples/webui/commands/test_fix.go index 19e54c4..5f452ef 100644 --- a/examples/webui/commands/test_fix.go +++ b/examples/webui/commands/test_fix.go @@ -4,7 +4,7 @@ import ( "fmt" "log" - "code.tczkiot.com/srdb" + "code.tczkiot.com/wlw/srdb" ) // TestFix 测试修复 diff --git a/examples/webui/commands/test_keys.go b/examples/webui/commands/test_keys.go index a5340ea..9d63c5d 100644 --- a/examples/webui/commands/test_keys.go +++ b/examples/webui/commands/test_keys.go @@ -4,7 +4,7 @@ import ( "fmt" "log" - "code.tczkiot.com/srdb" + "code.tczkiot.com/wlw/srdb" ) // TestKeys 测试键 diff --git a/examples/webui/commands/webui.go b/examples/webui/commands/webui.go index e9698e3..8b0dcff 100644 --- a/examples/webui/commands/webui.go +++ b/examples/webui/commands/webui.go @@ -9,8 +9,8 @@ import ( "slices" "time" - "code.tczkiot.com/srdb" - "code.tczkiot.com/srdb/webui" + "code.tczkiot.com/wlw/srdb" + "code.tczkiot.com/wlw/srdb/webui" ) // StartWebUI 启动 WebUI 服务器 diff --git a/examples/webui/main.go b/examples/webui/main.go index ae8eb75..311a415 100644 --- a/examples/webui/main.go +++ b/examples/webui/main.go @@ -5,7 +5,7 @@ import ( "fmt" "os" - "code.tczkiot.com/srdb/examples/webui/commands" + "code.tczkiot.com/wlw/srdb/examples/webui/commands" ) func main() { diff --git a/go.mod b/go.mod index af2f084..f24e12e 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,7 @@ -module code.tczkiot.com/srdb +module code.tczkiot.com/wlw/srdb go 1.24.0 -require ( - github.com/edsrzf/mmap-go v1.1.0 - github.com/golang/snappy v1.0.0 -) +require github.com/edsrzf/mmap-go v1.1.0 require golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect diff --git a/go.sum b/go.sum index 8a440d9..6dfc93e 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,4 @@ github.com/edsrzf/mmap-go v1.1.0 h1:6EUwBLQ/Mcr1EYLE4Tn1VdW1A4ckqCQWZBw8Hr0kjpQ= github.com/edsrzf/mmap-go v1.1.0/go.mod h1:19H/e8pUPLicwkyNgOykDXkJ9F0MHE+Z52B8EIth78Q= -github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= -github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/index.go b/index.go index 1d8db18..0b6052d 100644 --- a/index.go +++ b/index.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "maps" "os" "path/filepath" "sync" @@ -22,16 +23,16 @@ type IndexMetadata struct { // SecondaryIndex 二级索引 type SecondaryIndex struct { - name string // 索引名称 - field string // 字段名 - fieldType FieldType // 字段类型 - file *os.File // 索引文件 - builder *BTreeBuilder // B+Tree 构建器 - reader *BTreeReader // B+Tree 读取器 - valueToSeq map[string][]int64 // 值 → seq 列表 (构建时使用) - metadata IndexMetadata // 元数据 - mu sync.RWMutex - ready bool // 索引是否就绪 + name string // 索引名称 + field string // 字段名 + fieldType FieldType // 字段类型 + file *os.File // 索引文件 + btreeReader *IndexBTreeReader // B+Tree 读取器 + valueToSeq map[string][]int64 // 值 → seq 列表 (构建时使用) + metadata IndexMetadata // 元数据 + mu sync.RWMutex + ready bool // 索引是否就绪 + useBTree bool // 是否使用 B+Tree 存储(新格式) } // NewSecondaryIndex 创建二级索引 @@ -52,7 +53,7 @@ func NewSecondaryIndex(dir, field string, fieldType FieldType) (*SecondaryIndex, }, nil } -// Add 添加索引条目 +// Add 添加索引条目(增量更新元数据) func (idx *SecondaryIndex) Add(value any, seq int64) error { idx.mu.Lock() defer idx.mu.Unlock() @@ -61,97 +62,59 @@ func (idx *SecondaryIndex) Add(value any, seq int64) error { key := fmt.Sprintf("%v", value) idx.valueToSeq[key] = append(idx.valueToSeq[key], seq) + // 增量更新元数据 O(1) + if idx.metadata.MinSeq == 0 || seq < idx.metadata.MinSeq { + idx.metadata.MinSeq = seq + } + if seq > idx.metadata.MaxSeq { + idx.metadata.MaxSeq = seq + } + idx.metadata.RowCount++ + idx.metadata.UpdatedAt = time.Now().UnixNano() + + // 首次添加时设置 CreatedAt + if idx.metadata.CreatedAt == 0 { + idx.metadata.CreatedAt = time.Now().UnixNano() + } + return nil } -// Build 构建索引并持久化 +// Build 构建索引并持久化(B+Tree 格式) func (idx *SecondaryIndex) Build() error { idx.mu.Lock() defer idx.mu.Unlock() - // 持久化索引数据到 JSON 文件 - return idx.save() -} - -// save 保存索引到磁盘 -func (idx *SecondaryIndex) save() error { - // 更新元数据 - idx.updateMetadata() - - // 创建包含元数据的数据结构 - indexData := struct { - Metadata IndexMetadata `json:"metadata"` - ValueToSeq map[string][]int64 `json:"data"` - }{ - Metadata: idx.metadata, - ValueToSeq: idx.valueToSeq, - } - - // 序列化索引数据 - data, err := json.Marshal(indexData) - if err != nil { - return err - } + // 元数据已在 Add 时增量更新,这里只更新版本号 + idx.metadata.Version++ + idx.metadata.UpdatedAt = time.Now().UnixNano() // Truncate 文件 - err = idx.file.Truncate(0) + err := idx.file.Truncate(0) if err != nil { return err } - // 写入文件 - _, err = idx.file.Seek(0, 0) - if err != nil { - return err + // 使用 B+Tree 写入器 + writer := NewIndexBTreeWriter(idx.file, idx.metadata) + + // 添加所有条目 + for value, seqs := range idx.valueToSeq { + writer.Add(value, seqs) } - _, err = idx.file.Write(data) + // 构建并写入 + err = writer.Build() if err != nil { - return err - } - - // Sync 到磁盘 - err = idx.file.Sync() - if err != nil { - return err + return fmt.Errorf("failed to build btree index: %w", err) } + idx.useBTree = true idx.ready = true return nil } -// updateMetadata 更新元数据 -func (idx *SecondaryIndex) updateMetadata() { - now := time.Now().UnixNano() - - if idx.metadata.CreatedAt == 0 { - idx.metadata.CreatedAt = now - } - idx.metadata.UpdatedAt = now - idx.metadata.Version++ - - // 计算 MinSeq, MaxSeq, RowCount - var minSeq, maxSeq int64 = -1, -1 - rowCount := int64(0) - - for _, seqs := range idx.valueToSeq { - for _, seq := range seqs { - if minSeq == -1 || seq < minSeq { - minSeq = seq - } - if maxSeq == -1 || seq > maxSeq { - maxSeq = seq - } - rowCount++ - } - } - - idx.metadata.MinSeq = minSeq - idx.metadata.MaxSeq = maxSeq - idx.metadata.RowCount = rowCount -} - -// load 从磁盘加载索引 +// load 从磁盘加载索引(支持 B+Tree 和 JSON 格式) func (idx *SecondaryIndex) load() error { // 获取文件大小 stat, err := idx.file.Stat() @@ -164,6 +127,47 @@ func (idx *SecondaryIndex) load() error { return nil } + // 读取文件头,判断格式 + headerData := make([]byte, min(int(stat.Size()), IndexHeaderSize)) + _, err = idx.file.ReadAt(headerData, 0) + if err != nil { + return err + } + + // 检查是否为 B+Tree 格式 + if len(headerData) >= 4 { + magic := binary.LittleEndian.Uint32(headerData[0:4]) + if magic == IndexMagic { + // B+Tree 格式 + return idx.loadBTree() + } + } + + // 回退到 JSON 格式(向后兼容) + return idx.loadJSON() +} + +// loadBTree 加载 B+Tree 格式的索引 +func (idx *SecondaryIndex) loadBTree() error { + reader, err := NewIndexBTreeReader(idx.file) + if err != nil { + return fmt.Errorf("failed to create btree reader: %w", err) + } + + idx.btreeReader = reader + idx.metadata = reader.GetMetadata() + idx.useBTree = true + idx.ready = true + return nil +} + +// loadJSON 加载 JSON 格式的索引(向后兼容) +func (idx *SecondaryIndex) loadJSON() error { + stat, err := idx.file.Stat() + if err != nil { + return err + } + // 读取文件内容 data := make([]byte, stat.Size()) _, err = idx.file.ReadAt(data, 0) @@ -171,27 +175,24 @@ func (idx *SecondaryIndex) load() error { return err } - // 尝试加载新格式(带元数据) + // 加载 JSON 格式 var indexData struct { Metadata IndexMetadata `json:"metadata"` ValueToSeq map[string][]int64 `json:"data"` } err = json.Unmarshal(data, &indexData) - if err == nil && indexData.ValueToSeq != nil { - // 新格式 - idx.metadata = indexData.Metadata - idx.valueToSeq = indexData.ValueToSeq - } else { - // 旧格式(兼容性) - err = json.Unmarshal(data, &idx.valueToSeq) - if err != nil { - return err - } - // 初始化元数据 - idx.updateMetadata() + if err != nil { + return fmt.Errorf("failed to unmarshal index data: %w", err) } + if indexData.ValueToSeq == nil { + return fmt.Errorf("invalid index data: missing data field") + } + + idx.metadata = indexData.Metadata + idx.valueToSeq = indexData.ValueToSeq + idx.useBTree = false idx.ready = true return nil } @@ -206,6 +207,13 @@ func (idx *SecondaryIndex) Get(value any) ([]int64, error) { } key := fmt.Sprintf("%v", value) + + // 如果使用 B+Tree,从 B+Tree 读取 + if idx.useBTree && idx.btreeReader != nil { + return idx.btreeReader.Get(key) + } + + // 否则从内存 map 读取 seqs, exists := idx.valueToSeq[key] if !exists { return nil, nil @@ -238,8 +246,8 @@ func (idx *SecondaryIndex) NeedsUpdate(currentMaxSeq int64) bool { // IncrementalUpdate 增量更新索引 func (idx *SecondaryIndex) IncrementalUpdate(getData func(int64) (map[string]any, error), fromSeq, toSeq int64) error { idx.mu.Lock() - defer idx.mu.Unlock() + addedCount := int64(0) // 遍历缺失的 seq 范围 for seq := fromSeq; seq <= toSeq; seq++ { // 获取数据 @@ -257,39 +265,40 @@ func (idx *SecondaryIndex) IncrementalUpdate(getData func(int64) (map[string]any // 添加到索引 key := fmt.Sprintf("%v", value) idx.valueToSeq[key] = append(idx.valueToSeq[key], seq) + + // 更新元数据 + if idx.metadata.MinSeq == 0 || seq < idx.metadata.MinSeq { + idx.metadata.MinSeq = seq + } + if seq > idx.metadata.MaxSeq { + idx.metadata.MaxSeq = seq + } + addedCount++ } + idx.metadata.RowCount += addedCount + idx.metadata.UpdatedAt = time.Now().UnixNano() + + // 释放锁,然后调用 Build(Build 会重新获取锁) + idx.mu.Unlock() + // 保存更新后的索引 - return idx.save() + return idx.Build() } // Close 关闭索引 func (idx *SecondaryIndex) Close() error { + // 关闭 B+Tree reader + if idx.btreeReader != nil { + idx.btreeReader.Close() + } + // 关闭文件 if idx.file != nil { return idx.file.Close() } return nil } -// encodeSeqList 编码 seq 列表 -func encodeSeqList(seqs []int64) []byte { - buf := make([]byte, 8*len(seqs)) - for i, seq := range seqs { - binary.LittleEndian.PutUint64(buf[i*8:], uint64(seq)) - } - return buf -} - -// decodeSeqList 解码 seq 列表 -func decodeSeqList(data []byte) []int64 { - count := len(data) / 8 - seqs := make([]int64, count) - for i := range count { - seqs[i] = int64(binary.LittleEndian.Uint64(data[i*8:])) - } - return seqs -} - // IndexManager 索引管理器 type IndexManager struct { dir string @@ -478,9 +487,7 @@ func (m *IndexManager) ListIndexes() []string { func (m *IndexManager) VerifyAndRepair(currentMaxSeq int64, getData func(int64) (map[string]any, error)) error { m.mu.RLock() indexes := make(map[string]*SecondaryIndex) - for k, v := range m.indexes { - indexes[k] = v - } + maps.Copy(indexes, m.indexes) m.mu.RUnlock() for field, idx := range indexes { diff --git a/index_btree.go b/index_btree.go new file mode 100644 index 0000000..3253f89 --- /dev/null +++ b/index_btree.go @@ -0,0 +1,539 @@ +package srdb + +import ( + "crypto/md5" + "encoding/binary" + "fmt" + "os" + "sort" + + "github.com/edsrzf/mmap-go" +) + +/* +索引文件存储格式 (B+Tree) + +文件结构: +┌─────────────────────────────────────────────────────────────┐ +│ Header (256 bytes) │ +├─────────────────────────────────────────────────────────────┤ +│ B+Tree 索引区 │ +│ ├─ 内部节点 (BTreeNodeSize = 4096 bytes) │ +│ │ ├─ NodeType (1 byte): 0=Leaf, 1=Internal │ +│ │ ├─ KeyCount (2 bytes): 节点中的 key 数量 │ +│ │ ├─ Keys[]: int64 数组 (8 bytes each) │ +│ │ └─ Children[]: int64 偏移量数组 (8 bytes each) │ +│ │ │ +│ └─ 叶子节点 (BTreeNodeSize = 4096 bytes) │ +│ ├─ NodeType (1 byte): 0 │ +│ ├─ KeyCount (2 bytes): 节点中的 key 数量 │ +│ ├─ Keys[]: int64 数组 (8 bytes each) │ +│ ├─ Offsets[]: int64 数据偏移量 (8 bytes each) │ +│ └─ Sizes[]: int32 数据大小 (4 bytes each) │ +├─────────────────────────────────────────────────────────────┤ +│ 数据块区 │ +│ ├─ Entry 1: │ +│ │ ├─ ValueLen (4 bytes): 字段值长度 │ +│ │ ├─ Value (N bytes): 字段值 (原始字符串) │ +│ │ ├─ SeqCount (4 bytes): seq 数量 │ +│ │ └─ Seqs (8 bytes each): seq 列表 │ +│ │ │ +│ ├─ Entry 2: ... │ +│ └─ Entry N: ... │ +└─────────────────────────────────────────────────────────────┘ + +Header 格式 (256 bytes): + Offset | Size | Field | Description + -------|------|----------------|---------------------------------- + 0 | 4 | Magic | 0x49445842 ("IDXB") + 4 | 4 | FormatVersion | 文件格式版本 (1) + 8 | 8 | IndexVersion | 索引版本号 (对应 Metadata.Version) + 16 | 8 | RootOffset | B+Tree 根节点偏移 + 24 | 8 | DataStart | 数据块起始位置 + 32 | 8 | MinSeq | 最小 seq + 40 | 8 | MaxSeq | 最大 seq + 48 | 8 | RowCount | 总行数 + 56 | 8 | CreatedAt | 创建时间 (UnixNano) + 64 | 8 | UpdatedAt | 更新时间 (UnixNano) + 72 | 184 | Reserved | 预留空间 + +索引条目格式 (变长): + Offset | Size | Field | Description + -------|-------------|------------|---------------------------------- + 0 | 4 | ValueLen | 字段值长度 (N) + 4 | N | Value | 字段值 (原始字符串,用于验证哈希冲突) + 4+N | 4 | SeqCount | seq 数量 (M) + 8+N | M * 8 | Seqs | seq 列表 (int64 数组) + +Key 生成规则: + - 使用 MD5 哈希将字符串值转为 int64 + - key = MD5(value)[0:8] 的 LittleEndian uint64 + - 存储原始 value 用于验证哈希冲突 + +查询流程: + 1. value → key (MD5 哈希) + 2. B+Tree.Get(key) → (offset, size) + 3. 读取数据块 mmap[offset:offset+size] + 4. 解码并验证原始 value (处理哈希冲突) + 5. 返回 seqs 列表 + +性能特点: + - 加载: O(1) - 只需 mmap 映射 + - 查询: O(log n) - B+Tree 查找 + - 内存: 几乎为 0 - 零拷贝 mmap + - 支持范围查询 (未来可扩展) +*/ + +const ( + IndexHeaderSize = 256 // 索引文件头大小 + IndexMagic = 0x49445842 // "IDXB" - Index B-Tree + IndexVersion = 1 // 文件格式版本 +) + +// IndexHeader 索引文件头 +type IndexHeader struct { + Magic uint32 // 魔数 "IDXB" + FormatVersion uint32 // 文件格式版本号 + IndexVersion int64 // 索引版本号(对应 IndexMetadata.Version) + RootOffset int64 // B+Tree 根节点偏移 + DataStart int64 // 数据块起始位置 + MinSeq int64 // 最小 seq + MaxSeq int64 // 最大 seq + RowCount int64 // 总行数 + CreatedAt int64 // 创建时间 + UpdatedAt int64 // 更新时间 + Reserved [184]byte // 预留空间(减少 8 字节给 IndexVersion,减少 8 字节调整对齐) +} + +// Marshal 序列化 Header +func (h *IndexHeader) Marshal() []byte { + buf := make([]byte, IndexHeaderSize) + binary.LittleEndian.PutUint32(buf[0:4], h.Magic) + binary.LittleEndian.PutUint32(buf[4:8], h.FormatVersion) + binary.LittleEndian.PutUint64(buf[8:16], uint64(h.IndexVersion)) + binary.LittleEndian.PutUint64(buf[16:24], uint64(h.RootOffset)) + binary.LittleEndian.PutUint64(buf[24:32], uint64(h.DataStart)) + binary.LittleEndian.PutUint64(buf[32:40], uint64(h.MinSeq)) + binary.LittleEndian.PutUint64(buf[40:48], uint64(h.MaxSeq)) + binary.LittleEndian.PutUint64(buf[48:56], uint64(h.RowCount)) + binary.LittleEndian.PutUint64(buf[56:64], uint64(h.CreatedAt)) + binary.LittleEndian.PutUint64(buf[64:72], uint64(h.UpdatedAt)) + copy(buf[72:], h.Reserved[:]) + return buf +} + +// UnmarshalIndexHeader 反序列化 Header +func UnmarshalIndexHeader(data []byte) *IndexHeader { + if len(data) < IndexHeaderSize { + return nil + } + h := &IndexHeader{} + h.Magic = binary.LittleEndian.Uint32(data[0:4]) + h.FormatVersion = binary.LittleEndian.Uint32(data[4:8]) + h.IndexVersion = int64(binary.LittleEndian.Uint64(data[8:16])) + h.RootOffset = int64(binary.LittleEndian.Uint64(data[16:24])) + h.DataStart = int64(binary.LittleEndian.Uint64(data[24:32])) + h.MinSeq = int64(binary.LittleEndian.Uint64(data[32:40])) + h.MaxSeq = int64(binary.LittleEndian.Uint64(data[40:48])) + h.RowCount = int64(binary.LittleEndian.Uint64(data[48:56])) + h.CreatedAt = int64(binary.LittleEndian.Uint64(data[56:64])) + h.UpdatedAt = int64(binary.LittleEndian.Uint64(data[64:72])) + copy(h.Reserved[:], data[72:IndexHeaderSize]) + return h +} + +// valueToKey 将字段值转换为 B+Tree key(使用哈希) +// +// 原理: +// - 字符串无法直接用作 B+Tree key (需要 int64) +// - 使用 MD5 哈希将任意字符串映射为 int64 +// - 取 MD5 的前 8 字节作为 key +// +// 哈希冲突处理: +// - 存储原始 value 在数据块中 +// - 查询时验证原始 value 是否匹配 +// - 冲突时返回 nil(极低概率) +// +// 示例: +// "Alice" → MD5 → 0x3bc15c8aae3e4124... → key = 0x3bc15c8aae3e4124 +func valueToKey(value string) int64 { + // 使用 MD5 的前 8 字节作为 int64 key + hash := md5.Sum([]byte(value)) + return int64(binary.LittleEndian.Uint64(hash[:8])) +} + +// encodeIndexEntry 将索引条目编码为二进制格式(零拷贝友好) +// +// 格式:[ValueLen(4B)][Value(N bytes)][SeqCount(4B)][Seq1(8B)][Seq2(8B)]... +// +// 示例: +// value = "Alice", seqs = [1, 5, 10] +// 编码结果: +// [0x05, 0x00, 0x00, 0x00] // ValueLen = 5 +// [0x41, 0x6c, 0x69, 0x63, 0x65] // "Alice" +// [0x03, 0x00, 0x00, 0x00] // SeqCount = 3 +// [0x01, 0x00, 0x00, 0x00, ...] // Seq1 = 1 +// [0x05, 0x00, 0x00, 0x00, ...] // Seq2 = 5 +// [0x0a, 0x00, 0x00, 0x00, ...] // Seq3 = 10 +// +// 总大小:4 + 5 + 4 + 3*8 = 37 bytes +func encodeIndexEntry(value string, seqs []int64) []byte { + valueBytes := []byte(value) + size := 4 + len(valueBytes) + 4 + len(seqs)*8 + buf := make([]byte, size) + + // 写入 ValueLen + binary.LittleEndian.PutUint32(buf[0:4], uint32(len(valueBytes))) + + // 写入 Value + copy(buf[4:], valueBytes) + + // 写入 SeqCount + offset := 4 + len(valueBytes) + binary.LittleEndian.PutUint32(buf[offset:offset+4], uint32(len(seqs))) + + // 写入 Seqs + offset += 4 + for i, seq := range seqs { + binary.LittleEndian.PutUint64(buf[offset+i*8:offset+(i+1)*8], uint64(seq)) + } + + return buf +} + +// decodeIndexEntry 从二进制格式解码索引条目(零拷贝) +// +// 参数: +// data: 编码后的二进制数据(来自 mmap,零拷贝) +// +// 返回: +// value: 原始字段值(用于验证哈希冲突) +// seqs: seq 列表 +// err: 解码错误 +// +// 零拷贝优化: +// - 直接从 mmap 数据中读取,不复制 +// - string(data[4:4+valueLen]) 会复制,但无法避免 +// - seqs 数组需要分配,但只复制指针大小的数据 +func decodeIndexEntry(data []byte) (value string, seqs []int64, err error) { + if len(data) < 8 { + return "", nil, fmt.Errorf("data too short: %d bytes", len(data)) + } + + // 读取 ValueLen + valueLen := binary.LittleEndian.Uint32(data[0:4]) + if len(data) < int(4+valueLen+4) { + return "", nil, fmt.Errorf("data too short for value: expected %d, got %d", 4+valueLen+4, len(data)) + } + + // 读取 Value + value = string(data[4 : 4+valueLen]) + + // 读取 SeqCount + offset := 4 + int(valueLen) + seqCount := binary.LittleEndian.Uint32(data[offset : offset+4]) + + // 验证数据长度 + expectedSize := offset + 4 + int(seqCount)*8 + if len(data) < expectedSize { + return "", nil, fmt.Errorf("data too short for seqs: expected %d, got %d", expectedSize, len(data)) + } + + // 读取 Seqs + seqs = make([]int64, seqCount) + offset += 4 + for i := 0; i < int(seqCount); i++ { + seqs[i] = int64(binary.LittleEndian.Uint64(data[offset+i*8 : offset+(i+1)*8])) + } + + return value, seqs, nil +} + +// IndexBTreeWriter 使用 B+Tree 写入索引 +// +// 写入流程: +// 1. Add(): 收集所有 (value, seqs) 到内存 +// 2. Build(): +// a. 计算所有 value 的 key (MD5 哈希) +// b. 按 key 排序(B+Tree 要求有序) +// c. 编码所有条目为二进制格式 +// d. 构建 B+Tree 索引 +// e. 写入 Header + B+Tree + 数据块 +// +// 文件布局: +// [Header] → [B+Tree] → [Data Blocks] +type IndexBTreeWriter struct { + file *os.File + header IndexHeader + entries map[string][]int64 // value -> seqs + dataOffset int64 +} + +// NewIndexBTreeWriter 创建索引写入器 +func NewIndexBTreeWriter(file *os.File, metadata IndexMetadata) *IndexBTreeWriter { + return &IndexBTreeWriter{ + file: file, + header: IndexHeader{ + Magic: IndexMagic, + FormatVersion: IndexVersion, + IndexVersion: metadata.Version, + MinSeq: metadata.MinSeq, + MaxSeq: metadata.MaxSeq, + RowCount: metadata.RowCount, + CreatedAt: metadata.CreatedAt, + UpdatedAt: metadata.UpdatedAt, + }, + entries: make(map[string][]int64), + dataOffset: IndexHeaderSize, + } +} + +// Add 添加索引条目 +func (w *IndexBTreeWriter) Add(value string, seqs []int64) { + w.entries[value] = seqs +} + +// Build 构建并写入索引文件 +func (w *IndexBTreeWriter) Build() error { + // 1. 计算所有 key 并按 key 排序(确保 B+Tree 构建有序) + type valueKey struct { + value string + key int64 + } + var valueKeys []valueKey + for value := range w.entries { + valueKeys = append(valueKeys, valueKey{ + value: value, + key: valueToKey(value), + }) + } + + // 按 key 排序(而不是按字符串) + sort.Slice(valueKeys, func(i, j int) bool { + return valueKeys[i].key < valueKeys[j].key + }) + + // 2. 先写入数据块并记录位置 + type keyOffset struct { + key int64 + offset int64 + size int32 + } + var keyOffsets []keyOffset + + // 预留 Header 空间 + currentOffset := int64(IndexHeaderSize) + + // 构建数据块(使用二进制格式,无压缩) + var dataBlocks [][]byte + for _, vk := range valueKeys { + value := vk.value + seqs := w.entries[value] + + // 编码为二进制格式 + binaryData := encodeIndexEntry(value, seqs) + + // 记录 key 和数据位置(key 已经在 vk 中) + key := vk.key + dataBlocks = append(dataBlocks, binaryData) + + // 暂时不知道确切的 offset,先占位 + keyOffsets = append(keyOffsets, keyOffset{ + key: key, + offset: 0, // 稍后填充 + size: int32(len(binaryData)), + }) + + currentOffset += int64(len(binaryData)) + } + + // 3. 计算 B+Tree 起始位置(紧接 Header) + btreeStart := int64(IndexHeaderSize) + + // 估算 B+Tree 大小(每个叶子节点最多 BTreeOrder 个条目) + numEntries := len(keyOffsets) + numLeafNodes := (numEntries + BTreeOrder - 1) / BTreeOrder + + // 计算所有层级的节点总数 + totalNodes := numLeafNodes + nodesAtCurrentLevel := numLeafNodes + for nodesAtCurrentLevel > 1 { + nodesAtCurrentLevel = (nodesAtCurrentLevel + BTreeOrder - 1) / BTreeOrder + totalNodes += nodesAtCurrentLevel + } + + btreeSize := int64(totalNodes * BTreeNodeSize) + dataStart := btreeStart + btreeSize + + // 4. 更新数据块的实际偏移量 + currentDataOffset := dataStart + for i := range keyOffsets { + keyOffsets[i].offset = currentDataOffset + currentDataOffset += int64(keyOffsets[i].size) + } + + // 5. 写入 Header(预留位置) + w.header.DataStart = dataStart + w.file.WriteAt(w.header.Marshal(), 0) + + // 6. 构建 B+Tree + builder := NewBTreeBuilder(w.file, btreeStart) + for _, ko := range keyOffsets { + err := builder.Add(ko.key, ko.offset, ko.size) + if err != nil { + return fmt.Errorf("failed to add to btree: %w", err) + } + } + + rootOffset, err := builder.Build() + if err != nil { + return fmt.Errorf("failed to build btree: %w", err) + } + + // 7. 写入数据块 + currentDataOffset = dataStart + for _, data := range dataBlocks { + _, err := w.file.WriteAt(data, currentDataOffset) + if err != nil { + return fmt.Errorf("failed to write data block: %w", err) + } + currentDataOffset += int64(len(data)) + } + + // 8. 更新 Header(写入正确的 RootOffset) + w.header.RootOffset = rootOffset + _, err = w.file.WriteAt(w.header.Marshal(), 0) + if err != nil { + return fmt.Errorf("failed to write header: %w", err) + } + + // 9. Sync 到磁盘 + return w.file.Sync() +} + +// IndexBTreeReader 使用 B+Tree 读取索引 +// +// 读取流程: +// 1. mmap 映射整个文件(零拷贝) +// 2. 读取 Header +// 3. 创建 BTreeReader (指向 RootOffset) +// 4. Get(value): +// a. value → key (MD5 哈希) +// b. BTree.Get(key) → (offset, size) +// c. 读取 mmap[offset:offset+size](零拷贝) +// d. 解码并验证原始 value +// e. 返回 seqs +// +// 性能优化: +// - mmap 零拷贝:不需要加载整个文件到内存 +// - B+Tree 索引:O(log n) 查询 +// - 按需读取:只读取需要的数据块 +type IndexBTreeReader struct { + file *os.File + mmap mmap.MMap + header IndexHeader + btree *BTreeReader +} + +// NewIndexBTreeReader 创建索引读取器 +func NewIndexBTreeReader(file *os.File) (*IndexBTreeReader, error) { + // 读取 Header + headerData := make([]byte, IndexHeaderSize) + _, err := file.ReadAt(headerData, 0) + if err != nil { + return nil, fmt.Errorf("failed to read header: %w", err) + } + + header := UnmarshalIndexHeader(headerData) + if header == nil || header.Magic != IndexMagic { + return nil, fmt.Errorf("invalid index file: bad magic") + } + + // mmap 整个文件 + mmapData, err := mmap.Map(file, mmap.RDONLY, 0) + if err != nil { + return nil, fmt.Errorf("failed to mmap index file: %w", err) + } + + // 创建 B+Tree Reader + btree := NewBTreeReader(mmapData, header.RootOffset) + + return &IndexBTreeReader{ + file: file, + mmap: mmapData, + header: *header, + btree: btree, + }, nil +} + +// Get 查询字段值对应的 seq 列表(零拷贝) +// +// 参数: +// value: 字段值(例如 "Alice") +// +// 返回: +// seqs: seq 列表(例如 [1, 5, 10]) +// err: 查询错误 +// +// 查询流程: +// 1. value → key (MD5 哈希) +// 2. B+Tree.Get(key) → (offset, size) +// 3. 读取 mmap[offset:offset+size](零拷贝) +// 4. 解码并验证原始 value(处理哈希冲突) +// 5. 返回 seqs +// +// 哈希冲突处理: +// - 如果 storedValue != value,说明发生哈希冲突 +// - 返回 nil(表示未找到) +// - 冲突概率极低(MD5 64位空间) +func (r *IndexBTreeReader) Get(value string) ([]int64, error) { + // 计算 key + key := valueToKey(value) + + // 在 B+Tree 中查找 + dataOffset, dataSize, found := r.btree.Get(key) + if !found { + return nil, nil + } + + // 读取数据块(零拷贝) + if dataOffset+int64(dataSize) > int64(len(r.mmap)) { + return nil, fmt.Errorf("data offset out of range: offset=%d, size=%d, mmap_len=%d", dataOffset, dataSize, len(r.mmap)) + } + + binaryData := r.mmap[dataOffset : dataOffset+int64(dataSize)] + + // 解码二进制数据 + storedValue, seqs, err := decodeIndexEntry(binaryData) + if err != nil { + return nil, fmt.Errorf("failed to decode entry: %w", err) + } + + // 验证原始值(处理哈希冲突) + if storedValue != value { + // 哈希冲突,返回空 + return nil, nil + } + + return seqs, nil +} + +// GetMetadata 获取元数据 +func (r *IndexBTreeReader) GetMetadata() IndexMetadata { + return IndexMetadata{ + Version: r.header.IndexVersion, + MinSeq: r.header.MinSeq, + MaxSeq: r.header.MaxSeq, + RowCount: r.header.RowCount, + CreatedAt: r.header.CreatedAt, + UpdatedAt: r.header.UpdatedAt, + } +} + +// Close 关闭读取器 +func (r *IndexBTreeReader) Close() error { + if r.mmap != nil { + r.mmap.Unmap() + } + return nil +} diff --git a/index_btree_test.go b/index_btree_test.go new file mode 100644 index 0000000..9cc65d9 --- /dev/null +++ b/index_btree_test.go @@ -0,0 +1,628 @@ +package srdb + +import ( + "fmt" + "os" + "testing" + "time" +) + +func TestIndexBTreeBasic(t *testing.T) { + // 创建临时目录 + tmpDir := t.TempDir() + + // 创建 Schema + schema := NewSchema("test", []Field{ + {Name: "id", Type: FieldTypeInt64}, + {Name: "name", Type: FieldTypeString}, + {Name: "city", Type: FieldTypeString}, + }) + + // 创建索引管理器 + mgr := NewIndexManager(tmpDir, schema) + defer mgr.Close() + + // 创建索引 + err := mgr.CreateIndex("city") + if err != nil { + t.Fatalf("Failed to create index: %v", err) + } + + // 添加测试数据 + testData := []struct { + city string + seq int64 + }{ + {"Beijing", 1}, + {"Shanghai", 2}, + {"Beijing", 3}, + {"Shenzhen", 4}, + {"Shanghai", 5}, + {"Beijing", 6}, + } + + for _, td := range testData { + data := map[string]any{ + "id": td.seq, + "name": "user_" + string(rune(td.seq)), + "city": td.city, + } + err := mgr.AddToIndexes(data, td.seq) + if err != nil { + t.Fatalf("Failed to add to index: %v", err) + } + } + + // 构建索引 + err = mgr.BuildAll() + if err != nil { + t.Fatalf("Failed to build index: %v", err) + } + + // 关闭并重新打开,测试持久化 + mgr.Close() + + // 重新打开 + mgr2 := NewIndexManager(tmpDir, schema) + defer mgr2.Close() + + // 查询索引 + idx, exists := mgr2.GetIndex("city") + if !exists { + t.Fatal("Index not found after reload") + } + + // 验证索引使用 B+Tree + if !idx.useBTree { + t.Error("Index should be using B+Tree format") + } + + // 验证查询结果 + testCases := []struct { + city string + expectedSeqs []int64 + }{ + {"Beijing", []int64{1, 3, 6}}, + {"Shanghai", []int64{2, 5}}, + {"Shenzhen", []int64{4}}, + {"Unknown", nil}, + } + + for _, tc := range testCases { + seqs, err := idx.Get(tc.city) + if err != nil { + t.Errorf("Failed to query index for %s: %v", tc.city, err) + continue + } + + if len(seqs) != len(tc.expectedSeqs) { + t.Errorf("City %s: expected %d seqs, got %d", tc.city, len(tc.expectedSeqs), len(seqs)) + continue + } + + // 验证 seq 值 + seqMap := make(map[int64]bool) + for _, seq := range seqs { + seqMap[seq] = true + } + + for _, expectedSeq := range tc.expectedSeqs { + if !seqMap[expectedSeq] { + t.Errorf("City %s: missing expected seq %d", tc.city, expectedSeq) + } + } + } + + // 验证元数据 + metadata := idx.GetMetadata() + if metadata.MinSeq != 1 || metadata.MaxSeq != 6 || metadata.RowCount != 6 { + t.Errorf("Invalid metadata: MinSeq=%d, MaxSeq=%d, RowCount=%d", + metadata.MinSeq, metadata.MaxSeq, metadata.RowCount) + } +} + +func TestIndexBTreeLargeDataset(t *testing.T) { + // 创建临时目录 + tmpDir := t.TempDir() + + // 创建 Schema + schema := NewSchema("test", []Field{ + {Name: "id", Type: FieldTypeInt64}, + {Name: "category", Type: FieldTypeString}, + }) + + // 创建索引管理器 + mgr := NewIndexManager(tmpDir, schema) + defer mgr.Close() + + // 创建索引 + err := mgr.CreateIndex("category") + if err != nil { + t.Fatalf("Failed to create index: %v", err) + } + + // 添加大量测试数据 + numRecords := 10000 + numCategories := 100 + + for i := 1; i <= numRecords; i++ { + category := "cat_" + string(rune('A'+(i%numCategories))) + data := map[string]any{ + "id": int64(i), + "category": category, + } + err := mgr.AddToIndexes(data, int64(i)) + if err != nil { + t.Fatalf("Failed to add to index: %v", err) + } + } + + // 构建索引 + startBuild := time.Now() + err = mgr.BuildAll() + if err != nil { + t.Fatalf("Failed to build index: %v", err) + } + buildTime := time.Since(startBuild) + t.Logf("Built index with %d records in %v", numRecords, buildTime) + + // 获取索引文件大小 + idx, _ := mgr.GetIndex("category") + stat, _ := idx.file.Stat() + t.Logf("Index file size: %d bytes", stat.Size()) + + // 关闭并重新打开 + mgr.Close() + + // 重新打开 + mgr2 := NewIndexManager(tmpDir, schema) + defer mgr2.Close() + + idx2, exists := mgr2.GetIndex("category") + if !exists { + t.Fatal("Index not found after reload") + } + + // 验证索引使用 B+Tree + if !idx2.useBTree { + t.Error("Index should be using B+Tree format") + } + + // 随机查询测试 + startQuery := time.Now() + for i := 0; i < 100; i++ { + category := "cat_" + string(rune('A'+(i%numCategories))) + seqs, err := idx2.Get(category) + if err != nil { + t.Errorf("Failed to query index for %s: %v", category, err) + } + // 验证返回的 seq 数量 + expectedCount := numRecords / numCategories + if len(seqs) != expectedCount { + t.Errorf("Category %s: expected %d seqs, got %d", category, expectedCount, len(seqs)) + } + } + queryTime := time.Since(startQuery) + t.Logf("Queried 100 categories in %v (avg: %v per query)", queryTime, queryTime/100) +} + +func TestIndexBTreeBackwardCompatibility(t *testing.T) { + // 创建临时目录 + tmpDir := t.TempDir() + + // 创建 Schema + schema := NewSchema("test", []Field{ + {Name: "id", Type: FieldTypeInt64}, + {Name: "status", Type: FieldTypeString}, + }) + + // 1. 创建索引管理器并用旧方式(通过先禁用新格式)创建索引 + mgr := NewIndexManager(tmpDir, schema) + + // 创建索引 + err := mgr.CreateIndex("status") + if err != nil { + t.Fatalf("Failed to create index: %v", err) + } + + // 添加���据 + testData := map[string][]int64{ + "active": {1, 3, 5}, + "inactive": {2, 4}, + } + + for status, seqs := range testData { + for _, seq := range seqs { + data := map[string]any{ + "id": seq, + "status": status, + } + err := mgr.AddToIndexes(data, seq) + if err != nil { + t.Fatalf("Failed to add to index: %v", err) + } + } + } + + // 构建索引(会使用新的 B+Tree 格式) + err = mgr.BuildAll() + if err != nil { + t.Fatalf("Failed to build index: %v", err) + } + + // 关闭 + mgr.Close() + + // 2. 重新加载并验证 + mgr2 := NewIndexManager(tmpDir, schema) + defer mgr2.Close() + + idx, exists := mgr2.GetIndex("status") + if !exists { + t.Fatal("Failed to load index") + } + + // 应该使用 B+Tree 格式 + if !idx.useBTree { + t.Error("Index should be using B+Tree format") + } + + // 验证查询结果 + seqs, err := idx.Get("active") + if err != nil || len(seqs) != 3 { + t.Errorf("Failed to query index: err=%v, seqs=%v", err, seqs) + } + + seqs, err = idx.Get("inactive") + if err != nil || len(seqs) != 2 { + t.Errorf("Failed to query index: err=%v, seqs=%v", err, seqs) + } + + t.Log("Successfully loaded and queried B+Tree format index") +} + +func TestIndexBTreeIncrementalUpdate(t *testing.T) { + // 创建临时目录 + tmpDir := t.TempDir() + + // 创建 Schema + schema := NewSchema("test", []Field{ + {Name: "id", Type: FieldTypeInt64}, + {Name: "tag", Type: FieldTypeString}, + }) + + // 创建索引管理器 + mgr := NewIndexManager(tmpDir, schema) + defer mgr.Close() + + // 创建索引 + err := mgr.CreateIndex("tag") + if err != nil { + t.Fatalf("Failed to create index: %v", err) + } + + // 添加初始数据 + for i := 1; i <= 100; i++ { + tag := "tag_" + string(rune('A'+(i%10))) + data := map[string]any{ + "id": int64(i), + "tag": tag, + } + err := mgr.AddToIndexes(data, int64(i)) + if err != nil { + t.Fatalf("Failed to add to index: %v", err) + } + } + + // 构建索引 + err = mgr.BuildAll() + if err != nil { + t.Fatalf("Failed to build index: %v", err) + } + + // 获取索引 + idx, _ := mgr.GetIndex("tag") + + // 验证初始元数据 + metadata := idx.GetMetadata() + if metadata.MaxSeq != 100 { + t.Errorf("Expected MaxSeq=100, got %d", metadata.MaxSeq) + } + + // 增量更新:添加新数据 + getData := func(seq int64) (map[string]any, error) { + tag := "tag_" + string(rune('A'+(int(seq)%10))) + return map[string]any{ + "id": seq, + "tag": tag, + }, nil + } + + err = idx.IncrementalUpdate(getData, 101, 200) + if err != nil { + t.Fatalf("Failed to incremental update: %v", err) + } + + // 验证更新后的元数据 + metadata = idx.GetMetadata() + if metadata.MaxSeq != 200 { + t.Errorf("Expected MaxSeq=200 after update, got %d", metadata.MaxSeq) + } + if metadata.RowCount != 200 { + t.Errorf("Expected RowCount=200 after update, got %d", metadata.RowCount) + } + + // 验证可以查询到新数据 + seqs, err := idx.Get("tag_A") + if err != nil { + t.Fatalf("Failed to query index: %v", err) + } + + // tag_A 应该包含 seq 1, 11, 21, ..., 191 (20个) + if len(seqs) != 20 { + t.Errorf("Expected 20 seqs for tag_A, got %d", len(seqs)) + } + + t.Logf("Incremental update successful: %d records indexed", metadata.RowCount) +} + +// TestIndexBTreeWriter 测试 B+Tree 写入器 +func TestIndexBTreeWriter(t *testing.T) { + tmpDir := t.TempDir() + filePath := tmpDir + "/test_idx.sst" + + file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR, 0644) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + defer file.Close() + + // 创建写入器 + metadata := IndexMetadata{ + MinSeq: 1, + MaxSeq: 10, + RowCount: 10, + CreatedAt: time.Now().UnixNano(), + UpdatedAt: time.Now().UnixNano(), + } + + writer := NewIndexBTreeWriter(file, metadata) + + // 添加测试数据 + testData := map[string][]int64{ + "apple": {1, 2, 3}, + "banana": {4, 5}, + "cherry": {6, 7, 8, 9}, + "date": {10}, + } + + for value, seqs := range testData { + writer.Add(value, seqs) + } + + // 构建索引 + err = writer.Build() + if err != nil { + t.Fatalf("Failed to build index: %v", err) + } + + t.Log("B+Tree index written successfully") + + // 关闭并重新打开文件 + file.Close() + + // 读取索引 + file, err = os.Open(filePath) + if err != nil { + t.Fatalf("Failed to reopen file: %v", err) + } + + reader, err := NewIndexBTreeReader(file) + if err != nil { + t.Fatalf("Failed to create reader: %v", err) + } + defer reader.Close() + + // 验证读取 + for value, expectedSeqs := range testData { + seqs, err := reader.Get(value) + if err != nil { + t.Errorf("Failed to get %s: %v", value, err) + continue + } + + if len(seqs) != len(expectedSeqs) { + t.Errorf("%s: expected %d seqs, got %d", value, len(expectedSeqs), len(seqs)) + t.Logf(" Expected: %v", expectedSeqs) + t.Logf(" Got: %v", seqs) + } else { + // 验证每个 seq + seqMap := make(map[int64]bool) + for _, seq := range seqs { + seqMap[seq] = true + } + for _, expectedSeq := range expectedSeqs { + if !seqMap[expectedSeq] { + t.Errorf("%s: missing seq %d", value, expectedSeq) + } + } + } + } + + // 测试不存在的值 + seqs, err := reader.Get("unknown") + if err != nil { + t.Errorf("Failed to get unknown: %v", err) + } + if seqs != nil { + t.Errorf("Expected nil for unknown value, got %v", seqs) + } + + t.Log("All B+Tree reads successful") +} + +// TestValueToKey 测试哈希函数 +func TestValueToKey(t *testing.T) { + testCases := []string{ + "apple", + "banana", + "cherry", + "Beijing", + "Shanghai", + "Shenzhen", + } + + keyMap := make(map[int64]string) + for _, value := range testCases { + key := valueToKey(value) + t.Logf("valueToKey(%s) = %d", value, key) + + // 检查哈希冲突 + if existingValue, exists := keyMap[key]; exists { + t.Errorf("Hash collision: %s and %s both hash to %d", value, existingValue, key) + } + keyMap[key] = value + + // 验证哈希的一致性 + key2 := valueToKey(value) + if key != key2 { + t.Errorf("Hash inconsistency for %s: %d != %d", value, key, key2) + } + } +} + +// TestIndexBTreeDataTypes 测试不同数据类型 +func TestIndexBTreeDataTypes(t *testing.T) { + tmpDir := t.TempDir() + filePath := tmpDir + "/test_types.sst" + + file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR, 0644) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + defer file.Close() + + metadata := IndexMetadata{ + MinSeq: 1, + MaxSeq: 5, + RowCount: 5, + CreatedAt: time.Now().UnixNano(), + UpdatedAt: time.Now().UnixNano(), + } + + writer := NewIndexBTreeWriter(file, metadata) + + // 测试不同类型的值(都转换为字符串) + testData := map[string][]int64{ + "123": {1}, // 数字字符串 + "true": {2}, // 布尔字符串 + "hello": {3}, // 普通字符串 + "世界": {4}, // 中文 + "": {5}, // 空字符串 + } + + for value, seqs := range testData { + writer.Add(value, seqs) + } + + err = writer.Build() + if err != nil { + t.Fatalf("Failed to build: %v", err) + } + + file.Close() + + // 重新读取 + file, err = os.Open(filePath) + if err != nil { + t.Fatalf("Failed to reopen: %v", err) + } + + reader, err := NewIndexBTreeReader(file) + if err != nil { + t.Fatalf("Failed to create reader: %v", err) + } + defer reader.Close() + + // 验证 + for value, expectedSeqs := range testData { + seqs, err := reader.Get(value) + if err != nil { + t.Errorf("Failed to get '%s': %v", value, err) + continue + } + + if len(seqs) != len(expectedSeqs) { + t.Errorf("'%s': expected %d seqs, got %d", value, len(expectedSeqs), len(seqs)) + } + } + + t.Log("All data types tested successfully") +} + +// 测试大数据量 +func TestIndexBTreeLargeData(t *testing.T) { + tmpDir := t.TempDir() + filePath := tmpDir + "/test_large.sst" + + file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR, 0644) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + defer file.Close() + + metadata := IndexMetadata{ + MinSeq: 1, + MaxSeq: 1000, + RowCount: 1000, + CreatedAt: time.Now().UnixNano(), + UpdatedAt: time.Now().UnixNano(), + } + + writer := NewIndexBTreeWriter(file, metadata) + + // 添加 1000 个不同的值 + for i := range 1000 { + value := fmt.Sprintf("value_%d", i) + seqs := []int64{int64(i + 1)} + writer.Add(value, seqs) + } + + err = writer.Build() + if err != nil { + t.Fatalf("Failed to build: %v", err) + } + + fileInfo, _ := file.Stat() + t.Logf("Index file size: %d bytes", fileInfo.Size()) + + file.Close() + + // 重新读取 + file, err = os.Open(filePath) + if err != nil { + t.Fatalf("Failed to reopen: %v", err) + } + + reader, err := NewIndexBTreeReader(file) + if err != nil { + t.Fatalf("Failed to create reader: %v", err) + } + defer reader.Close() + + // 随机验证 100 个值 + for i := range 100 { + value := fmt.Sprintf("value_%d", i*10) + seqs, err := reader.Get(value) + if err != nil { + t.Errorf("Failed to get %s: %v", value, err) + continue + } + + if len(seqs) != 1 || seqs[0] != int64(i*10+1) { + t.Errorf("%s: expected [%d], got %v", value, i*10+1, seqs) + } + } + + t.Log("Large data test successful") +} diff --git a/query.go b/query.go index 943624b..e390be1 100644 --- a/query.go +++ b/query.go @@ -31,18 +31,13 @@ func (m *mapFieldset) Get(key string) (Field, any, error) { return Field{}, nil, fmt.Errorf("field %s not found", key) } - // 如果有 schema,返回字段定义 - if m.schema != nil { - field, err := m.schema.GetField(key) - if err != nil { - // 字段在 schema 中不存在,返回默认 Field - return Field{Name: key}, value, nil - } - return *field, value, nil + // 从 Schema 获取字段定义 + field, err := m.schema.GetField(key) + if err != nil { + // 字段在 schema 中不存在,返回默认 Field + return Field{Name: key}, value, nil } - - // 没有 schema,返回默认 Field - return Field{Name: key}, value, nil + return *field, value, nil } type Expr interface { @@ -364,12 +359,12 @@ func Or(exprs ...Expr) Expr { type QueryBuilder struct { conds []Expr fields []string // 要选择的字段,nil 表示选择所有字段 - engine *Engine + table *Table } -func newQueryBuilder(engine *Engine) *QueryBuilder { +func newQueryBuilder(table *Table) *QueryBuilder { return &QueryBuilder{ - engine: engine, + table: table, } } @@ -384,7 +379,7 @@ func (qb *QueryBuilder) Match(data map[string]any) bool { return true } - fs := newMapFieldset(data, qb.engine.schema) + fs := newMapFieldset(data, qb.table.schema) for _, cond := range qb.conds { if !cond.Match(fs) { return false @@ -477,15 +472,15 @@ func (qb *QueryBuilder) NotNull(field string) *QueryBuilder { // Rows 返回所有匹配的数据(游标模式 - 惰性加载) func (qb *QueryBuilder) Rows() (*Rows, error) { - if qb.engine == nil { - return nil, fmt.Errorf("engine is nil") + if qb.table == nil { + return nil, fmt.Errorf("table is nil") } rows := &Rows{ - schema: qb.engine.schema, + schema: qb.table.schema, fields: qb.fields, qb: qb, - engine: qb.engine, + table: qb.table, visited: make(map[int64]bool), } @@ -495,7 +490,7 @@ func (qb *QueryBuilder) Rows() (*Rows, error) { var allKeys []int64 // 1. 从 Active MemTable 读取数据 - activeMemTable := qb.engine.memtableManager.GetActive() + activeMemTable := qb.table.memtableManager.GetActive() if activeMemTable != nil { activeKeys := activeMemTable.Keys() for _, key := range activeKeys { @@ -510,7 +505,7 @@ func (qb *QueryBuilder) Rows() (*Rows, error) { } // 2. 从所有 Immutable MemTables 读取数据 - immutables := qb.engine.memtableManager.GetImmutables() + immutables := qb.table.memtableManager.GetImmutables() for _, imm := range immutables { immKeys := imm.MemTable.Keys() for _, key := range immKeys { @@ -530,13 +525,13 @@ func (qb *QueryBuilder) Rows() (*Rows, error) { } // 3. 收集所有 SST 文件的 keys - sstReaders := qb.engine.sstManager.GetReaders() + sstReaders := qb.table.sstManager.GetReaders() for _, reader := range sstReaders { // 获取文件中实际存在的 key 列表(已在 GetAllKeys 中排序) keys := reader.GetAllKeys() - // 记录所有 keys(实际数据稍后统一从 engine 读取) + // 记录所有 keys(实际数据稍后统一从 table 读取) for _, key := range keys { // 如果 key 已存在(来自更新的数据源),跳过 if _, exists := keyToRow[key]; !exists { @@ -561,15 +556,15 @@ func (qb *QueryBuilder) Rows() (*Rows, error) { // 排序 slices.Sort(uniqueKeys) - // 统一从 engine 读取所有数据(避免 compaction 导致的文件删除) + // 统一从 table 读取所有数据(避免 compaction 导致的文件删除) rows.cachedRows = make([]*SSTableRow, 0, len(uniqueKeys)) for _, seq := range uniqueKeys { // 如果已经从 MemTable 读取,直接使用 row := keyToRow[seq] if row == nil { - // 从 engine 读取(会搜索 MemTable + 所有 SST,包括 compaction 后的新文件) + // 从 table 读取(会搜索 MemTable + 所有 SST,包括 compaction 后的新文件) var err error - row, err = qb.engine.Get(seq) + row, err = qb.table.Get(seq) if err != nil { // 数据不存在(理论上不应该发生,因为 key 来自索引) continue @@ -689,7 +684,7 @@ type Rows struct { schema *Schema fields []string // 要选择的字段,nil 表示选择所有字段 qb *QueryBuilder - engine *Engine + table *Table // 迭代状态 currentRow *Row @@ -762,7 +757,7 @@ func (r *Rows) next() bool { if r.memIterator != nil { if seq, ok := r.memIterator.next(); ok { if !r.visited[seq] { - row, err := r.engine.Get(seq) + row, err := r.table.Get(seq) if err == nil && r.qb.Match(row.Data) { r.visited[seq] = true r.currentRow = &Row{schema: r.schema, fields: r.fields, inner: row} @@ -780,7 +775,7 @@ func (r *Rows) next() bool { if r.immutableIterator != nil { if seq, ok := r.immutableIterator.next(); ok { if !r.visited[seq] { - row, err := r.engine.Get(seq) + row, err := r.table.Get(seq) if err == nil && r.qb.Match(row.Data) { r.visited[seq] = true r.currentRow = &Row{schema: r.schema, fields: r.fields, inner: row} @@ -796,8 +791,8 @@ func (r *Rows) next() bool { } // 检查是否有更多 Immutable MemTables - if r.immutableIterator == nil && r.immutableIndex < len(r.engine.memtableManager.GetImmutables()) { - immutables := r.engine.memtableManager.GetImmutables() + if r.immutableIterator == nil && r.immutableIndex < len(r.table.memtableManager.GetImmutables()) { + immutables := r.table.memtableManager.GetImmutables() if r.immutableIndex < len(immutables) { r.immutableIterator = newMemtableIterator(immutables[r.immutableIndex].MemTable.Keys()) continue @@ -813,7 +808,7 @@ func (r *Rows) next() bool { sstReader.index++ if !r.visited[seq] { - row, err := r.engine.Get(seq) + row, err := r.table.Get(seq) if err == nil && r.qb.Match(row.Data) { r.visited[seq] = true r.currentRow = &Row{schema: r.schema, fields: r.fields, inner: row} diff --git a/schema.go b/schema.go index 86fd026..6ae5f7a 100644 --- a/schema.go +++ b/schema.go @@ -63,7 +63,7 @@ func (s *Schema) GetField(name string) (*Field, error) { return &s.Fields[i], nil } } - return nil, fmt.Errorf("field %s not found", name) + return nil, NewErrorf(ErrCodeFieldNotFound, "field %s not found", name) } // GetIndexedFields 获取所有需要索引的字段 @@ -140,7 +140,7 @@ func (s *Schema) ExtractIndexValue(field string, data map[string]any) (any, erro value, exists := data[field] if !exists { - return nil, fmt.Errorf("field %s not found in data", field) + return nil, NewErrorf(ErrCodeFieldNotFound, "field %s not found in data", field) } // 类型转换 diff --git a/sstable.go b/sstable.go index 0d3931a..afe7231 100644 --- a/sstable.go +++ b/sstable.go @@ -3,7 +3,6 @@ package srdb import ( "bytes" "encoding/binary" - "encoding/json" "fmt" "os" "path/filepath" @@ -12,7 +11,6 @@ import ( "sync" "github.com/edsrzf/mmap-go" - "github.com/golang/snappy" ) const ( @@ -22,13 +20,9 @@ const ( SSTableHeaderSize = 256 // 文件头大小 SSTableBlockSize = 64 * 1024 // 数据块大小 (64 KB) - // 压缩类型 - SSTableCompressionNone = 0 - SSTableCompressionSnappy = 1 - // 二进制编码格式: // [Magic: 4 bytes][Seq: 8 bytes][Time: 8 bytes][DataLen: 4 bytes][Data: variable] - SSTableRowMagic = 0x524F5733 // "ROW3" + SSTableRowMagic = 0x524F5731 // "ROW1" ) // SSTableHeader SST 文件头 (256 bytes) @@ -36,7 +30,7 @@ type SSTableHeader struct { // 基础信息 (32 bytes) Magic uint32 // Magic Number: 0x53535433 Version uint32 // 版本号 - Compression uint8 // 压缩类型 + Compression uint8 // 压缩类型(保留字段用于向后兼容) Reserved1 [3]byte Flags uint32 // 标志位 Reserved2 [16]byte @@ -175,14 +169,9 @@ func encodeSSTableRowBinary(row *SSTableRow, schema *Schema) ([]byte, error) { return nil, err } - // 如果没有 Schema,回退到 JSON 编码(只编码 Data 部分,Seq/Time 已经写入) + // 强制要求 Schema if schema == nil { - dataJSON, err := json.Marshal(row.Data) - if err != nil { - return nil, err - } - buf.Write(dataJSON) - return buf.Bytes(), nil + return nil, fmt.Errorf("schema is required for encoding SSTable rows") } // 按字段分别编码和压缩 @@ -191,8 +180,8 @@ func encodeSSTableRowBinary(row *SSTableRow, schema *Schema) ([]byte, error) { return nil, err } - // 1. 先编码所有字段到各自的 buffer - compressedFields := make([][]byte, len(schema.Fields)) + // 1. 先编码所有字段到各自的 buffer(无压缩) + fieldData := make([][]byte, len(schema.Fields)) for i, field := range schema.Fields { fieldBuf := new(bytes.Buffer) @@ -209,28 +198,28 @@ func encodeSSTableRowBinary(row *SSTableRow, schema *Schema) ([]byte, error) { } } - // 压缩每个字段 - compressedFields[i] = snappy.Encode(nil, fieldBuf.Bytes()) + // 直接使用二进制数据(无压缩) + fieldData[i] = fieldBuf.Bytes() } // 2. 写入字段偏移表(相对于数据区起始位置) currentOffset := 0 - for _, compressed := range compressedFields { + for _, data := range fieldData { // 写入字段偏移(相对于数据区) if err := binary.Write(buf, binary.LittleEndian, uint32(currentOffset)); err != nil { return nil, err } - // 写入压缩后大小 - if err := binary.Write(buf, binary.LittleEndian, uint32(len(compressed))); err != nil { + // 写入数据大小 + if err := binary.Write(buf, binary.LittleEndian, uint32(len(data))); err != nil { return nil, err } - currentOffset += len(compressed) + currentOffset += len(data) } - // 3. 写入压缩后的字段数据 - for _, compressed := range compressedFields { - if _, err := buf.Write(compressed); err != nil { + // 3. 写入字段数据 + for _, data := range fieldData { + if _, err := buf.Write(data); err != nil { return nil, err } } @@ -341,14 +330,9 @@ func decodeSSTableRowBinaryPartial(data []byte, schema *Schema, fields []string) return nil, err } - // 如果没有 Schema,回退到 JSON 解码(读取剩余的 JSON 数据) + // 强制要求 Schema if schema == nil { - remainingData := data[16:] // Skip Seq (8 bytes) + Time (8 bytes) - row.Data = make(map[string]any) - if err := json.Unmarshal(remainingData, &row.Data); err != nil { - return nil, err - } - return row, nil + return nil, fmt.Errorf("schema is required for decoding SSTable rows") } // 读取字段数量 @@ -400,28 +384,22 @@ func decodeSSTableRowBinaryPartial(data []byte, schema *Schema, fields []string) continue } - // 读取压缩的字段数据 + // 读取字段数据(无压缩) info := fieldInfos[i] - compressedPos := dataStart + int64(info.offset) + fieldPos := dataStart + int64(info.offset) // Seek 到字段位置 - if _, err := buf.Seek(compressedPos, 0); err != nil { + if _, err := buf.Seek(fieldPos, 0); err != nil { return nil, fmt.Errorf("seek to field %s: %w", field.Name, err) } - compressedData := make([]byte, info.size) - if _, err := buf.Read(compressedData); err != nil { + fieldData := make([]byte, info.size) + if _, err := buf.Read(fieldData); err != nil { return nil, fmt.Errorf("read field %s: %w", field.Name, err) } - // 解压字段数据 - decompressed, err := snappy.Decode(nil, compressedData) - if err != nil { - return nil, fmt.Errorf("decompress field %s: %w", field.Name, err) - } - - // 解析字段值 - fieldBuf := bytes.NewReader(decompressed) + // 解析字段值(直接从二进制数据) + fieldBuf := bytes.NewReader(fieldData) value, err := readFieldBinaryValue(fieldBuf, field.Type, true) if err != nil { return nil, fmt.Errorf("parse field %s: %w", field.Name, err) @@ -489,31 +467,29 @@ func readFieldBinaryValue(buf *bytes.Reader, typ FieldType, keep bool) (any, err // SSTableWriter SST 文件写入器 type SSTableWriter struct { - file *os.File - builder *BTreeBuilder - dataOffset int64 - dataStart int64 // 数据起始位置 - rowCount int64 - minKey int64 - maxKey int64 - minTime int64 - maxTime int64 - compression uint8 - schema *Schema // Schema 用于优化编码 + file *os.File + builder *BTreeBuilder + dataOffset int64 + dataStart int64 // 数据起始位置 + rowCount int64 + minKey int64 + maxKey int64 + minTime int64 + maxTime int64 + schema *Schema // Schema 用于优化编码 } // NewSSTableWriter 创建 SST 写入器 func NewSSTableWriter(file *os.File, schema *Schema) *SSTableWriter { return &SSTableWriter{ - file: file, - builder: NewBTreeBuilder(file, SSTableHeaderSize), - dataOffset: 0, // 先写数据,后面会更新 - compression: SSTableCompressionSnappy, - minKey: -1, - maxKey: -1, - minTime: -1, - maxTime: -1, - schema: schema, + file: file, + builder: NewBTreeBuilder(file, SSTableHeaderSize), + dataOffset: 0, // 先写数据,后面会更新 + minKey: -1, + maxKey: -1, + minTime: -1, + maxTime: -1, + schema: schema, } } @@ -541,18 +517,13 @@ func (w *SSTableWriter) Add(row *SSTableRow) error { } w.rowCount++ - // 序列化数据(使用 Schema 优化的二进制格式) - data := encodeSSTableRow(row, w.schema) - - // 压缩数据 - var compressed []byte - if w.compression == SSTableCompressionSnappy { - compressed = snappy.Encode(nil, data) - } else { - compressed = data + // 序列化数据(使用 Schema 优化的二进制格式,无压缩) + data, err := encodeSSTableRow(row, w.schema) + if err != nil { + return fmt.Errorf("encode row: %w", err) } - // 写入数据块 + // 写入数据块(不压缩) // 第一次写入时,确定数据起始位置 if w.dataStart == 0 { // 预留足够空间给 B+Tree 索引 @@ -563,19 +534,19 @@ func (w *SSTableWriter) Add(row *SSTableRow) error { } offset := w.dataOffset - _, err := w.file.WriteAt(compressed, offset) + _, err = w.file.WriteAt(data, offset) if err != nil { return err } // 添加到 B+Tree - err = w.builder.Add(row.Seq, offset, int32(len(compressed))) + err = w.builder.Add(row.Seq, offset, int32(len(data))) if err != nil { return err } // 更新数据偏移 - w.dataOffset += int64(len(compressed)) + w.dataOffset += int64(len(data)) return nil } @@ -595,7 +566,7 @@ func (w *SSTableWriter) Finish() error { header := &SSTableHeader{ Magic: SSTableMagicNumber, Version: SSTableVersion, - Compression: w.compression, + Compression: 0, // 不使用压缩(保留字段用于向后兼容) IndexOffset: SSTableHeaderSize, IndexSize: indexSize, RootOffset: rootOffset, @@ -620,19 +591,13 @@ func (w *SSTableWriter) Finish() error { } // encodeSSTableRow 编码行数据 (使用二进制格式) -func encodeSSTableRow(row *SSTableRow, schema *Schema) []byte { +func encodeSSTableRow(row *SSTableRow, schema *Schema) ([]byte, error) { // 使用二进制格式编码 encoded, err := encodeSSTableRowBinary(row, schema) if err != nil { - // 降级到 JSON (不应该发生) - data := map[string]any{ - "_seq": row.Seq, - "_time": row.Time, - "data": row.Data, - } - encoded, _ = json.Marshal(data) + return nil, fmt.Errorf("failed to encode row: %w", err) } - return encoded + return encoded, nil } // SSTableReader SST 文件读取器 @@ -704,21 +669,9 @@ func (r *SSTableReader) Get(key int64) (*SSTableRow, error) { return nil, fmt.Errorf("invalid data offset") } - compressed := r.mmap[dataOffset : dataOffset+int64(dataSize)] + data := r.mmap[dataOffset : dataOffset+int64(dataSize)] - // 4. 解压缩 - var data []byte - var err error - if r.header.Compression == SSTableCompressionSnappy { - data, err = snappy.Decode(nil, compressed) - if err != nil { - return nil, err - } - } else { - data = compressed - } - - // 5. 反序列化 + // 4. 反序列化(无压缩) row, err := decodeSSTableRow(data, r.schema) if err != nil { return nil, err @@ -745,21 +698,9 @@ func (r *SSTableReader) GetPartial(key int64, fields []string) (*SSTableRow, err return nil, fmt.Errorf("invalid data offset") } - compressed := r.mmap[dataOffset : dataOffset+int64(dataSize)] + data := r.mmap[dataOffset : dataOffset+int64(dataSize)] - // 4. 解压缩 - var data []byte - var err error - if r.header.Compression == SSTableCompressionSnappy { - data, err = snappy.Decode(nil, compressed) - if err != nil { - return nil, err - } - } else { - data = compressed - } - - // 5. 按需反序列化(只解析需要的字段) + // 4. 按需反序列化(只解析需要的字段,无压缩) row, err := decodeSSTableRowBinaryPartial(data, r.schema, fields) if err != nil { return nil, err @@ -799,27 +740,13 @@ func (r *SSTableReader) Close() error { return nil } -// decodeSSTableRow 解码行数据 +// decodeSSTableRow 解码行数据(只支持二进制格式) func decodeSSTableRow(data []byte, schema *Schema) (*SSTableRow, error) { - // 尝试使用二进制格式解码 + // 使用二进制格式解码 row, err := decodeSSTableRowBinary(data, schema) - if err == nil { - return row, nil - } - - // 降级到 JSON (兼容旧数据) - var decoded map[string]any - err = json.Unmarshal(data, &decoded) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to decode row: %w", err) } - - row = &SSTableRow{ - Seq: int64(decoded["_seq"].(float64)), - Time: int64(decoded["_time"].(float64)), - Data: decoded["data"].(map[string]any), - } - return row, nil } diff --git a/sstable_test.go b/sstable_test.go index f6ba7bc..d87088c 100644 --- a/sstable_test.go +++ b/sstable_test.go @@ -15,8 +15,14 @@ func TestSSTable(t *testing.T) { } defer os.Remove("test.sst") + // 创建 Schema + schema := NewSchema("test", []Field{ + {Name: "name", Type: FieldTypeString}, + {Name: "age", Type: FieldTypeInt64}, + }) + // 2. 写入数据 - writer := NewSSTableWriter(file, nil) + writer := NewSSTableWriter(file, schema) // 添加 1000 行数据 for i := int64(1); i <= 1000; i++ { @@ -51,6 +57,9 @@ func TestSSTable(t *testing.T) { } defer reader.Close() + // 设置 Schema + reader.SetSchema(schema) + // 验证 Header header := reader.GetHeader() if header.RowCount != 1000 { @@ -100,7 +109,7 @@ func TestSSTableHeaderSerialization(t *testing.T) { header := &SSTableHeader{ Magic: SSTableMagicNumber, Version: SSTableVersion, - Compression: SSTableCompressionSnappy, + Compression: 0, // 不使用压缩 IndexOffset: 256, IndexSize: 1024, RootOffset: 512, @@ -154,11 +163,16 @@ func TestSSTableHeaderSerialization(t *testing.T) { } func BenchmarkSSTableGet(b *testing.B) { + // 创建 Schema + schema := NewSchema("test", []Field{ + {Name: "value", Type: FieldTypeInt64}, + }) + // 创建测试文件 file, _ := os.Create("bench.sst") defer os.Remove("bench.sst") - writer := NewSSTableWriter(file, nil) + writer := NewSSTableWriter(file, schema) for i := int64(1); i <= 10000; i++ { row := &SSTableRow{ Seq: i, @@ -176,6 +190,9 @@ func BenchmarkSSTableGet(b *testing.B) { reader, _ := NewSSTableReader("bench.sst") defer reader.Close() + // 设置 Schema + reader.SetSchema(schema) + // 性能测试 for i := 0; b.Loop(); i++ { diff --git a/table.go b/table.go index 5f699ed..0653a6f 100644 --- a/table.go +++ b/table.go @@ -1,82 +1,720 @@ package srdb import ( + "encoding/json" + "fmt" "os" "path/filepath" + "sort" + "sync" + "sync/atomic" "time" ) +const ( + DefaultMemTableSize = 64 * 1024 * 1024 // 64 MB + DefaultAutoFlushTimeout = 30 * time.Second // 30 秒无写入自动 flush +) + // Table 表 type Table struct { - name string // 表名 - dir string // 表目录 - schema *Schema // Schema - engine *Engine // Engine 实例 - createdAt int64 // 创建时间 + dir string + schema *Schema + indexManager *IndexManager + walManager *WALManager // WAL 管理器 + sstManager *SSTableManager // SST 管理器 + memtableManager *MemTableManager // MemTable 管理器 + versionSet *VersionSet // MANIFEST 管理器 + compactionManager *CompactionManager // Compaction 管理器 + seq atomic.Int64 + flushMu sync.Mutex + + // 自动 flush 相关 + autoFlushTimeout time.Duration + lastWriteTime atomic.Int64 // 最后写入时间(UnixNano) + stopAutoFlush chan struct{} } -// createTable 创建新表 -func createTable(name string, schema *Schema, db *Database) (*Table, error) { - // 创建表目录 - tableDir := filepath.Join(db.dir, name) - err := os.MkdirAll(tableDir, 0755) +// TableOptions 配置选项 +type TableOptions struct { + Dir string + MemTableSize int64 + Name string // 表名 + Fields []Field // 字段列表(可选) + AutoFlushTimeout time.Duration // 自动 flush 超时时间,0 表示禁用 +} + +// OpenTable 打开数据库 +func OpenTable(opts *TableOptions) (*Table, error) { + if opts.MemTableSize == 0 { + opts.MemTableSize = DefaultMemTableSize + } + + // 创建主目录 + err := os.MkdirAll(opts.Dir, 0755) if err != nil { - os.RemoveAll(tableDir) return nil, err } - // 创建 Engine(Engine 会自动保存 Schema 到文件) - engine, err := OpenEngine(&EngineOptions{ - Dir: tableDir, - MemTableSize: DefaultMemTableSize, - Schema: schema, - }) + // 创建子目录 + walDir := filepath.Join(opts.Dir, "wal") + sstDir := filepath.Join(opts.Dir, "sst") + idxDir := filepath.Join(opts.Dir, "idx") + + err = os.MkdirAll(walDir, 0755) + if err != nil { + return nil, err + } + err = os.MkdirAll(sstDir, 0755) + if err != nil { + return nil, err + } + err = os.MkdirAll(idxDir, 0755) if err != nil { - os.RemoveAll(tableDir) return nil, err } + // 处理 Schema + var sch *Schema + if opts.Name != "" && len(opts.Fields) > 0 { + // 从 Name 和 Fields 创建 Schema + sch = NewSchema(opts.Name, opts.Fields) + // 保存到磁盘(带校验和) + schemaPath := filepath.Join(opts.Dir, "schema.json") + schemaFile, err := NewSchemaFile(sch) + if err != nil { + return nil, fmt.Errorf("create schema file: %w", err) + } + schemaData, err := json.MarshalIndent(schemaFile, "", " ") + if err != nil { + return nil, fmt.Errorf("marshal schema: %w", err) + } + err = os.WriteFile(schemaPath, schemaData, 0644) + if err != nil { + return nil, fmt.Errorf("write schema: %w", err) + } + } else { + // 尝试从磁盘恢复 + schemaPath := filepath.Join(opts.Dir, "schema.json") + schemaData, err := os.ReadFile(schemaPath) + if err == nil { + // 文件存在,尝试解析 + schemaFile := &SchemaFile{} + err = json.Unmarshal(schemaData, schemaFile) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal schema from %s: %w", schemaPath, err) + } + + // 验证校验和 + err = schemaFile.Verify() + if err != nil { + return nil, fmt.Errorf("failed to verify schema from %s: %w", schemaPath, err) + } + + sch = schemaFile.Schema + } else if !os.IsNotExist(err) { + // 其他读取错误 + return nil, fmt.Errorf("failed to read schema file %s: %w", schemaPath, err) + } else { + // Schema 文件不存在 + return nil, fmt.Errorf("schema is required but schema.json not found in %s", opts.Dir) + } + } + + // 强制要求 Schema + if sch == nil { + return nil, fmt.Errorf("schema is required to open table") + } + + // 创建索引管理器 + indexMgr := NewIndexManager(idxDir, sch) + + // 创建 SST Manager + sstMgr, err := NewSSTableManager(sstDir) + if err != nil { + return nil, err + } + + // 设置 Schema(用于优化编解码) + sstMgr.SetSchema(sch) + + // 创建 MemTable Manager + memMgr := NewMemTableManager(opts.MemTableSize) + + // 创建/恢复 MANIFEST + manifestDir := opts.Dir + versionSet, err := NewVersionSet(manifestDir) + if err != nil { + return nil, fmt.Errorf("create version set: %w", err) + } + + // 创建 Table(暂时不设置 WAL Manager) table := &Table{ - name: name, - dir: tableDir, - schema: schema, - engine: engine, - createdAt: time.Now().Unix(), + dir: opts.Dir, + schema: sch, + indexManager: indexMgr, + walManager: nil, // 先不设置,恢复后再创建 + sstManager: sstMgr, + memtableManager: memMgr, + versionSet: versionSet, } + // 先恢复数据(包括从 WAL 恢复) + err = table.recover() + if err != nil { + return nil, err + } + + // 恢复完成后,创建 WAL Manager 用于后续写入 + walMgr, err := NewWALManager(walDir) + if err != nil { + return nil, err + } + table.walManager = walMgr + table.memtableManager.SetActiveWAL(walMgr.GetCurrentNumber()) + + // 创建 Compaction Manager + table.compactionManager = NewCompactionManager(sstDir, versionSet, sstMgr) + + // 设置 Schema + table.compactionManager.SetSchema(sch) + + // 启动时清理孤儿文件(崩溃恢复后的清理) + table.compactionManager.CleanupOrphanFiles() + + // 启动后台 Compaction 和垃圾回收 + table.compactionManager.Start() + + // 验证并修复索引 + table.verifyAndRepairIndexes() + + // 设置自动 flush 超时时间 + if opts.AutoFlushTimeout > 0 { + table.autoFlushTimeout = opts.AutoFlushTimeout + } else { + table.autoFlushTimeout = DefaultAutoFlushTimeout + } + table.stopAutoFlush = make(chan struct{}) + table.lastWriteTime.Store(time.Now().UnixNano()) + + // 启动自动 flush 监控 + go table.autoFlushMonitor() + return table, nil } -// openTable 打开已存在的表 -func openTable(name string, db *Database) (*Table, error) { - tableDir := filepath.Join(db.dir, name) +// Insert 插入数据 +func (t *Table) Insert(data map[string]any) error { + // 1. 验证 Schema + if err := t.schema.Validate(data); err != nil { + return NewError(ErrCodeSchemaValidationFailed, err) + } - // 打开 Engine(Engine 会自动从 schema.json 恢复 Schema) - eng, err := OpenEngine(&EngineOptions{ - Dir: tableDir, - MemTableSize: DefaultMemTableSize, - // Schema 不设置,让 Engine 自动从磁盘恢复 - }) + // 2. 生成 _seq + seq := t.seq.Add(1) + + // 3. 添加系统字段 + row := &SSTableRow{ + Seq: seq, + Time: time.Now().UnixNano(), + Data: data, + } + + // 3. 序列化 + rowData, err := json.Marshal(row) if err != nil { - return nil, err + return err } - // 从 Engine 获取 Schema - sch := eng.GetSchema() - - table := &Table{ - name: name, - dir: tableDir, - schema: sch, - engine: eng, + // 4. 写入 WAL + entry := &WALEntry{ + Type: WALEntryTypePut, + Seq: seq, + Data: rowData, + } + err = t.walManager.Append(entry) + if err != nil { + return err } - return table, nil + // 5. 写入 MemTable Manager + t.memtableManager.Put(seq, rowData) + + // 6. 添加到索引 + t.indexManager.AddToIndexes(data, seq) + + // 7. 更新最后写入时间 + t.lastWriteTime.Store(time.Now().UnixNano()) + + // 8. 检查是否需要切换 MemTable + if t.memtableManager.ShouldSwitch() { + go t.switchMemTable() + } + + return nil +} + +// Get 查询数据 +func (t *Table) Get(seq int64) (*SSTableRow, error) { + // 1. 先查 MemTable Manager (Active + Immutables) + data, found := t.memtableManager.Get(seq) + if found { + var row SSTableRow + err := json.Unmarshal(data, &row) + if err != nil { + return nil, err + } + return &row, nil + } + + // 2. 查询 SST 文件 + return t.sstManager.Get(seq) +} + +// GetPartial 按需查询数据(只读取指定字段) +func (t *Table) GetPartial(seq int64, fields []string) (*SSTableRow, error) { + // 1. 先查 MemTable Manager (Active + Immutables) + data, found := t.memtableManager.Get(seq) + if found { + var row SSTableRow + err := json.Unmarshal(data, &row) + if err != nil { + return nil, err + } + + // MemTable 中的数据已经完全解析,需要手动过滤字段 + if len(fields) > 0 { + filteredData := make(map[string]any) + for _, field := range fields { + if val, ok := row.Data[field]; ok { + filteredData[field] = val + } + } + row.Data = filteredData + } + + return &row, nil + } + + // 2. 查询 SST 文件(按需解码) + return t.sstManager.GetPartial(seq, fields) +} + +// switchMemTable 切换 MemTable +func (t *Table) switchMemTable() error { + t.flushMu.Lock() + defer t.flushMu.Unlock() + + // 1. 切换到新的 WAL + oldWALNumber, err := t.walManager.Rotate() + if err != nil { + return err + } + newWALNumber := t.walManager.GetCurrentNumber() + + // 2. 切换 MemTable (Active → Immutable) + _, immutable := t.memtableManager.Switch(newWALNumber) + + // 3. 异步 Flush Immutable + go t.flushImmutable(immutable, oldWALNumber) + + return nil +} + +// flushImmutable 将 Immutable MemTable 刷新到 SST +func (t *Table) flushImmutable(imm *ImmutableMemTable, walNumber int64) error { + // 1. 收集所有行 + var rows []*SSTableRow + iter := imm.NewIterator() + for iter.Next() { + var row SSTableRow + err := json.Unmarshal(iter.Value(), &row) + if err == nil { + rows = append(rows, &row) + } + } + + if len(rows) == 0 { + // 没有数据,直接清理 + t.walManager.Delete(walNumber) + t.memtableManager.RemoveImmutable(imm) + return nil + } + + // 2. 从 VersionSet 分配文件编号 + fileNumber := t.versionSet.AllocateFileNumber() + + // 3. 创建 SST 文件到 L0 + reader, err := t.sstManager.CreateSST(fileNumber, rows) + if err != nil { + return err + } + + // 4. 创建 FileMetadata + header := reader.GetHeader() + + // 获取文件大小 + sstPath := reader.GetPath() + fileInfo, err := os.Stat(sstPath) + if err != nil { + return fmt.Errorf("stat sst file: %w", err) + } + + fileMeta := &FileMetadata{ + FileNumber: fileNumber, + Level: 0, // Flush 到 L0 + FileSize: fileInfo.Size(), + MinKey: header.MinKey, + MaxKey: header.MaxKey, + RowCount: header.RowCount, + } + + // 5. 更新 MANIFEST + edit := NewVersionEdit() + edit.AddFile(fileMeta) + + // 持久化当前的文件编号计数器(关键修复:防止重启后文件编号重用) + // 使用 fileNumber + 1 确保并发安全,避免竞态条件 + edit.SetNextFileNumber(fileNumber + 1) + + err = t.versionSet.LogAndApply(edit) + if err != nil { + return fmt.Errorf("log and apply version edit: %w", err) + } + + // 6. 删除对应的 WAL + t.walManager.Delete(walNumber) + + // 7. 从 Immutable 列表中移除 + t.memtableManager.RemoveImmutable(imm) + + // 8. 持久化索引(防止崩溃丢失索引数据) + t.indexManager.BuildAll() + + // 9. Compaction 由后台线程负责,不在 flush 路径中触发 + // 避免同步 compaction 导致刚创建的文件立即被删除 + // t.compactionManager.MaybeCompact() + + return nil +} + +// recover 恢复数据 +func (t *Table) recover() error { + // 1. 恢复 SST 文件(SST Manager 已经在 NewManager 中恢复了) + // 只需要获取最大 seq + maxSeq := t.sstManager.GetMaxSeq() + if maxSeq > t.seq.Load() { + t.seq.Store(maxSeq) + } + + // 2. 恢复所有 WAL 文件到 MemTable Manager + walDir := filepath.Join(t.dir, "wal") + pattern := filepath.Join(walDir, "*.wal") + walFiles, err := filepath.Glob(pattern) + if err == nil && len(walFiles) > 0 { + // 按文件名排序 + sort.Strings(walFiles) + + // 依次读取每个 WAL + for _, walPath := range walFiles { + reader, err := NewWALReader(walPath) + if err != nil { + continue + } + + entries, err := reader.Read() + reader.Close() + + if err != nil { + continue + } + + // 重放 WAL 到 Active MemTable + for _, entry := range entries { + // 验证 Schema + var row SSTableRow + if err := json.Unmarshal(entry.Data, &row); err != nil { + return fmt.Errorf("failed to unmarshal row during recovery (seq=%d): %w", entry.Seq, err) + } + + // 验证 Schema + if err := t.schema.Validate(row.Data); err != nil { + return NewErrorf(ErrCodeSchemaValidationFailed, "schema validation failed during recovery (seq=%d)", entry.Seq, err) + } + + t.memtableManager.Put(entry.Seq, entry.Data) + if entry.Seq > t.seq.Load() { + t.seq.Store(entry.Seq) + } + } + } + } + + return nil +} + +// autoFlushMonitor 自动 flush 监控 +func (t *Table) autoFlushMonitor() { + ticker := time.NewTicker(t.autoFlushTimeout / 2) // 每半个超时时间检查一次 + defer ticker.Stop() + + for { + select { + case <-ticker.C: + // 检查是否超时 + lastWrite := time.Unix(0, t.lastWriteTime.Load()) + if time.Since(lastWrite) >= t.autoFlushTimeout { + // 检查 MemTable 是否有数据 + active := t.memtableManager.GetActive() + if active != nil && active.Size() > 0 { + // 触发 flush + t.Flush() + } + } + case <-t.stopAutoFlush: + return + } + } +} + +// Flush 手动刷新 Active MemTable 到磁盘 +func (t *Table) Flush() error { + // 检查 Active MemTable 是否有数据 + active := t.memtableManager.GetActive() + if active == nil || active.Size() == 0 { + return nil // 没有数据,无需 flush + } + + // 强制切换 MemTable(switchMemTable 内部有锁) + return t.switchMemTable() +} + +// Close 关闭引擎 +func (t *Table) Close() error { + // 1. 停止自动 flush 监控(如果还在运行) + if t.stopAutoFlush != nil { + select { + case <-t.stopAutoFlush: + // 已经关闭,跳过 + default: + close(t.stopAutoFlush) + } + } + + // 2. 停止 Compaction Manager + if t.compactionManager != nil { + t.compactionManager.Stop() + } + + // 3. 刷新 Active MemTable(确保所有数据都写入磁盘) + // 检查 memtableManager 是否存在(可能已被 Destroy) + if t.memtableManager != nil { + t.Flush() + } + + // 3. 关闭 WAL Manager + if t.walManager != nil { + t.walManager.Close() + } + + // 4. 等待所有 Immutable Flush 完成 + // TODO: 添加更优雅的等待机制 + if t.memtableManager != nil { + for t.memtableManager.GetImmutableCount() > 0 { + time.Sleep(100 * time.Millisecond) + } + } + + // 5. 保存所有索引 + if t.indexManager != nil { + t.indexManager.BuildAll() + t.indexManager.Close() + } + + // 6. 关闭 VersionSet + if t.versionSet != nil { + t.versionSet.Close() + } + + // 7. 关闭 WAL Manager + if t.walManager != nil { + t.walManager.Close() + } + + // 6. 关闭 SST Manager + if t.sstManager != nil { + t.sstManager.Close() + } + + return nil +} + +// Clean 清除所有数据(保留 Table 可用) +func (t *Table) Clean() error { + t.flushMu.Lock() + defer t.flushMu.Unlock() + + // 0. 停止自动 flush 监控(临时) + if t.stopAutoFlush != nil { + close(t.stopAutoFlush) + } + + // 1. 停止 Compaction Manager + if t.compactionManager != nil { + t.compactionManager.Stop() + } + + // 2. 等待所有 Immutable Flush 完成 + for t.memtableManager.GetImmutableCount() > 0 { + time.Sleep(100 * time.Millisecond) + } + + // 3. 清空 MemTable + t.memtableManager = NewMemTableManager(DefaultMemTableSize) + + // 2. 删除所有 WAL 文件 + if t.walManager != nil { + t.walManager.Close() + walDir := filepath.Join(t.dir, "wal") + os.RemoveAll(walDir) + os.MkdirAll(walDir, 0755) + + // 重新创建 WAL Manager + walMgr, err := NewWALManager(walDir) + if err != nil { + return fmt.Errorf("recreate wal manager: %w", err) + } + t.walManager = walMgr + t.memtableManager.SetActiveWAL(walMgr.GetCurrentNumber()) + } + + // 3. 删除所有 SST 文件 + if t.sstManager != nil { + t.sstManager.Close() + sstDir := filepath.Join(t.dir, "sst") + os.RemoveAll(sstDir) + os.MkdirAll(sstDir, 0755) + + // 重新创建 SST Manager + sstMgr, err := NewSSTableManager(sstDir) + if err != nil { + return fmt.Errorf("recreate sst manager: %w", err) + } + t.sstManager = sstMgr + // 设置 Schema + t.sstManager.SetSchema(t.schema) + } + + // 4. 删除所有索引文件 + if t.indexManager != nil { + t.indexManager.Close() + indexFiles, _ := filepath.Glob(filepath.Join(t.dir, "idx_*.sst")) + for _, f := range indexFiles { + os.Remove(f) + } + + // 重新创建 Index Manager + t.indexManager = NewIndexManager(t.dir, t.schema) + } + + // 5. 重置 MANIFEST + if t.versionSet != nil { + t.versionSet.Close() + manifestDir := t.dir + os.Remove(filepath.Join(manifestDir, "MANIFEST")) + os.Remove(filepath.Join(manifestDir, "CURRENT")) + + // 重新创建 VersionSet + versionSet, err := NewVersionSet(manifestDir) + if err != nil { + return fmt.Errorf("recreate version set: %w", err) + } + t.versionSet = versionSet + } + + // 6. 重新创建 Compaction Manager + sstDir := filepath.Join(t.dir, "sst") + t.compactionManager = NewCompactionManager(sstDir, t.versionSet, t.sstManager) + t.compactionManager.SetSchema(t.schema) + t.compactionManager.Start() + + // 7. 重置序列号 + t.seq.Store(0) + + // 8. 更新最后写入时间 + t.lastWriteTime.Store(time.Now().UnixNano()) + + // 9. 重启自动 flush 监控 + t.stopAutoFlush = make(chan struct{}) + go t.autoFlushMonitor() + + return nil +} + +// Destroy 销毁 Table 并删除所有数据文件 +func (t *Table) Destroy() error { + // 1. 先关闭 Table + if err := t.Close(); err != nil { + return fmt.Errorf("close table: %w", err) + } + + // 2. 删除整个数据目录 + if err := os.RemoveAll(t.dir); err != nil { + return fmt.Errorf("remove data directory: %w", err) + } + + // 3. 标记 Table 为不可用(将所有管理器设为 nil) + t.walManager = nil + t.sstManager = nil + t.memtableManager = nil + t.versionSet = nil + t.compactionManager = nil + t.indexManager = nil + + return nil +} + +// TableStats 统计信息 +type TableStats struct { + MemTableSize int64 + MemTableCount int + SSTCount int + TotalRows int64 +} + +// GetVersionSet 获取 VersionSet(用于高级操作) +func (t *Table) GetVersionSet() *VersionSet { + return t.versionSet +} + +// GetCompactionManager 获取 Compaction Manager(用于高级操作) +func (t *Table) GetCompactionManager() *CompactionManager { + return t.compactionManager +} + +// GetMemtableManager 获取 Memtable Manager +func (t *Table) GetMemtableManager() *MemTableManager { + return t.memtableManager +} + +// GetSSTManager 获取 SST Manager +func (t *Table) GetSSTManager() *SSTableManager { + return t.sstManager +} + +// GetMaxSeq 获取当前最大的 seq 号 +func (t *Table) GetMaxSeq() int64 { + return t.seq.Load() - 1 // seq 是下一个要分配的,所以最大的是 seq - 1 } // GetName 获取表名 func (t *Table) GetName() string { - return t.name + return t.schema.Name +} + +// GetDir 获取表目录 +func (t *Table) GetDir() string { + return t.dir } // GetSchema 获取 Schema @@ -84,71 +722,72 @@ func (t *Table) GetSchema() *Schema { return t.schema } -// Insert 插入数据 -func (t *Table) Insert(data map[string]any) error { - return t.engine.Insert(data) -} +// Stats 获取统计信息 +func (t *Table) Stats() *TableStats { + memStats := t.memtableManager.GetStats() + sstStats := t.sstManager.GetStats() -// Get 查询数据 -func (t *Table) Get(seq int64) (*SSTableRow, error) { - return t.engine.Get(seq) -} + stats := &TableStats{ + MemTableSize: memStats.TotalSize, + MemTableCount: memStats.TotalCount, + SSTCount: sstStats.FileCount, + } -// Query 创建查询构建器 -func (t *Table) Query() *QueryBuilder { - return t.engine.Query() + // 计算总行数 + stats.TotalRows = int64(memStats.TotalCount) + readers := t.sstManager.GetReaders() + for _, reader := range readers { + header := reader.GetHeader() + stats.TotalRows += header.RowCount + } + + return stats } // CreateIndex 创建索引 func (t *Table) CreateIndex(field string) error { - return t.engine.CreateIndex(field) + return t.indexManager.CreateIndex(field) } // DropIndex 删除索引 func (t *Table) DropIndex(field string) error { - return t.engine.DropIndex(field) + return t.indexManager.DropIndex(field) } // ListIndexes 列出所有索引 func (t *Table) ListIndexes() []string { - return t.engine.ListIndexes() + return t.indexManager.ListIndexes() } -// Stats 获取统计信息 -func (t *Table) Stats() *TableStats { - return t.engine.Stats() +// GetIndexMetadata 获取索引元数据 +func (t *Table) GetIndexMetadata() map[string]IndexMetadata { + return t.indexManager.GetIndexMetadata() } -// GetEngine 获取底层 Engine -func (t *Table) GetEngine() *Engine { - return t.engine +// RepairIndexes 手动修复索引 +func (t *Table) RepairIndexes() error { + return t.verifyAndRepairIndexes() } -// Close 关闭表 -func (t *Table) Close() error { - if t.engine != nil { - return t.engine.Close() +// Query 创建查询构建器 +func (t *Table) Query() *QueryBuilder { + return newQueryBuilder(t) +} + +// verifyAndRepairIndexes 验证并修复索引 +func (t *Table) verifyAndRepairIndexes() error { + // 获取当前最大 seq + currentMaxSeq := t.seq.Load() + + // 创建 getData 函数 + getData := func(seq int64) (map[string]any, error) { + row, err := t.Get(seq) + if err != nil { + return nil, err + } + return row.Data, nil } - return nil -} -// GetCreatedAt 获取表创建时间 -func (t *Table) GetCreatedAt() int64 { - return t.createdAt -} - -// Clean 清除表的所有数据(保留表结构和 Table 可用) -func (t *Table) Clean() error { - if t.engine != nil { - return t.engine.Clean() - } - return nil -} - -// Destroy 销毁表并删除所有数据文件(不从 Database 中删除) -func (t *Table) Destroy() error { - if t.engine != nil { - return t.engine.Destroy() - } - return nil + // 验证并修复 + return t.indexManager.VerifyAndRepair(currentMaxSeq, getData) } diff --git a/table_clean_test.go b/table_clean_test.go deleted file mode 100644 index 9b14726..0000000 --- a/table_clean_test.go +++ /dev/null @@ -1,314 +0,0 @@ -package srdb - -import ( - "os" - "testing" -) - -func TestTableClean(t *testing.T) { - dir := "./test_table_clean_data" - defer os.RemoveAll(dir) - - // 1. 创建数据库和表 - db, err := Open(dir) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - schema := NewSchema("users", []Field{ - {Name: "id", Type: FieldTypeInt64, Indexed: true, Comment: "ID"}, - {Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"}, - }) - - table, err := db.CreateTable("users", schema) - if err != nil { - t.Fatal(err) - } - - // 2. 插入数据 - for i := 0; i < 100; i++ { - err := table.Insert(map[string]any{ - "id": int64(i), - "name": "user" + string(rune(i)), - }) - if err != nil { - t.Fatal(err) - } - } - - // 3. 验证数据存在 - stats := table.Stats() - t.Logf("Before Clean: %d rows", stats.TotalRows) - - if stats.TotalRows == 0 { - t.Error("Expected data in table") - } - - // 4. 清除数据 - err = table.Clean() - if err != nil { - t.Fatal(err) - } - - // 5. 验证数据已清除 - stats = table.Stats() - t.Logf("After Clean: %d rows", stats.TotalRows) - - if stats.TotalRows != 0 { - t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows) - } - - // 6. 验证表仍然可用 - err = table.Insert(map[string]any{ - "id": int64(100), - "name": "new_user", - }) - if err != nil { - t.Fatal(err) - } - - stats = table.Stats() - if stats.TotalRows != 1 { - t.Errorf("Expected 1 row after insert, got %d", stats.TotalRows) - } -} - -func TestTableDestroy(t *testing.T) { - dir := "./test_table_destroy_data" - defer os.RemoveAll(dir) - - // 1. 创建数据库和表 - db, err := Open(dir) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - schema := NewSchema("test", []Field{ - {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, - }) - - table, err := db.CreateTable("test", schema) - if err != nil { - t.Fatal(err) - } - - // 2. 插入数据 - for i := 0; i < 50; i++ { - table.Insert(map[string]any{"id": int64(i)}) - } - - // 3. 验证数据存在 - stats := table.Stats() - t.Logf("Before Destroy: %d rows", stats.TotalRows) - - if stats.TotalRows == 0 { - t.Error("Expected data in table") - } - - // 4. 获取表目录路径 - tableDir := table.dir - - // 5. 销毁表 - err = table.Destroy() - if err != nil { - t.Fatal(err) - } - - // 6. 验证表目录已删除 - if _, err := os.Stat(tableDir); !os.IsNotExist(err) { - t.Error("Table directory should be deleted") - } - - // 7. 注意:Table.Destroy() 只删除文件,不从 Database 中删除 - // 表仍然在 Database 的元数据中,但文件已被删除 - tables := db.ListTables() - found := false - for _, name := range tables { - if name == "test" { - found = true - break - } - } - if !found { - t.Error("Table should still be in database metadata (use Database.DestroyTable to remove from metadata)") - } -} - -func TestTableCleanWithIndex(t *testing.T) { - dir := "./test_table_clean_index_data" - defer os.RemoveAll(dir) - - // 1. 创建数据库和表 - db, err := Open(dir) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - schema := NewSchema("users", []Field{ - {Name: "id", Type: FieldTypeInt64, Indexed: true, Comment: "ID"}, - {Name: "email", Type: FieldTypeString, Indexed: true, Comment: "Email"}, - {Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"}, - }) - - table, err := db.CreateTable("users", schema) - if err != nil { - t.Fatal(err) - } - - // 2. 创建索引 - err = table.CreateIndex("id") - if err != nil { - t.Fatal(err) - } - - err = table.CreateIndex("email") - if err != nil { - t.Fatal(err) - } - - // 3. 插入数据 - for i := 0; i < 50; i++ { - table.Insert(map[string]any{ - "id": int64(i), - "email": "user" + string(rune(i)) + "@example.com", - "name": "User " + string(rune(i)), - }) - } - - // 4. 验证索引存在 - indexes := table.ListIndexes() - if len(indexes) != 2 { - t.Errorf("Expected 2 indexes, got %d", len(indexes)) - } - - // 5. 清除数据 - err = table.Clean() - if err != nil { - t.Fatal(err) - } - - // 6. 验证数据已清除 - stats := table.Stats() - if stats.TotalRows != 0 { - t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows) - } - - // 7. 验证索引已被清除(Clean 会删除索引数据) - indexes = table.ListIndexes() - if len(indexes) != 0 { - t.Logf("Note: Indexes were cleared (expected behavior), got %d", len(indexes)) - } - - // 8. 重新创建索引 - table.CreateIndex("id") - table.CreateIndex("email") - - // 9. 验证可以继续插入数据 - err = table.Insert(map[string]any{ - "id": int64(100), - "email": "new@example.com", - "name": "New User", - }) - if err != nil { - t.Fatal(err) - } - - stats = table.Stats() - if stats.TotalRows != 1 { - t.Errorf("Expected 1 row, got %d", stats.TotalRows) - } -} - -func TestTableCleanAndQuery(t *testing.T) { - dir := "./test_table_clean_query_data" - defer os.RemoveAll(dir) - - // 1. 创建数据库和表 - db, err := Open(dir) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - schema := NewSchema("test", []Field{ - {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, - {Name: "status", Type: FieldTypeString, Indexed: false, Comment: "Status"}, - }) - - table, err := db.CreateTable("test", schema) - if err != nil { - t.Fatal(err) - } - - // 2. 插入数据 - for i := 0; i < 30; i++ { - table.Insert(map[string]any{ - "id": int64(i), - "status": "active", - }) - } - - // 3. 查询数据 - rows, err := table.Query().Eq("status", "active").Rows() - if err != nil { - t.Fatal(err) - } - - count := 0 - for rows.Next() { - count++ - } - rows.Close() - - t.Logf("Before Clean: found %d rows", count) - if count != 30 { - t.Errorf("Expected 30 rows, got %d", count) - } - - // 4. 清除数据 - err = table.Clean() - if err != nil { - t.Fatal(err) - } - - // 5. 再次查询 - rows, err = table.Query().Eq("status", "active").Rows() - if err != nil { - t.Fatal(err) - } - - count = 0 - for rows.Next() { - count++ - } - rows.Close() - - t.Logf("After Clean: found %d rows", count) - if count != 0 { - t.Errorf("Expected 0 rows after clean, got %d", count) - } - - // 6. 插入新数据并查询 - table.Insert(map[string]any{ - "id": int64(100), - "status": "active", - }) - - rows, err = table.Query().Eq("status", "active").Rows() - if err != nil { - t.Fatal(err) - } - - count = 0 - for rows.Next() { - count++ - } - rows.Close() - - if count != 1 { - t.Errorf("Expected 1 row, got %d", count) - } -} diff --git a/engine_test.go b/table_test.go similarity index 62% rename from engine_test.go rename to table_test.go index 9b9a3fa..b15bf64 100644 --- a/engine_test.go +++ b/table_test.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "fmt" "os" + "slices" "strings" "sync" "sync/atomic" @@ -11,20 +12,27 @@ import ( "time" ) -func TestEngine(t *testing.T) { +func TestTable(t *testing.T) { // 1. 创建引擎 dir := "test_db" os.RemoveAll(dir) defer os.RemoveAll(dir) - engine, err := OpenEngine(&EngineOptions{ + schema := NewSchema("test", []Field{ + {Name: "name", Type: FieldTypeString, Indexed: false, Comment: "用户名"}, + {Name: "age", Type: FieldTypeInt64, Indexed: false, Comment: "年龄"}, + }) + + table, err := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 1024, // 1 KB,方便触发 Flush + Name: schema.Name, + Fields: schema.Fields, }) if err != nil { t.Fatal(err) } - defer engine.Close() + defer table.Close() // 2. 插入数据 for i := 1; i <= 100; i++ { @@ -32,7 +40,7 @@ func TestEngine(t *testing.T) { "name": fmt.Sprintf("user_%d", i), "age": 20 + i%50, } - err := engine.Insert(data) + err := table.Insert(data) if err != nil { t.Fatal(err) } @@ -45,7 +53,7 @@ func TestEngine(t *testing.T) { // 3. 查询数据 for i := int64(1); i <= 100; i++ { - row, err := engine.Get(i) + row, err := table.Get(i) if err != nil { t.Errorf("Failed to get key %d: %v", i, err) continue @@ -56,7 +64,7 @@ func TestEngine(t *testing.T) { } // 4. 统计信息 - stats := engine.Stats() + stats := table.Stats() t.Logf("Stats: MemTable=%d rows, SST=%d files, Total=%d rows", stats.MemTableCount, stats.SSTCount, stats.TotalRows) @@ -67,45 +75,53 @@ func TestEngine(t *testing.T) { t.Log("All tests passed!") } -func TestEngineRecover(t *testing.T) { +func TestTableRecover(t *testing.T) { dir := "test_recover" os.RemoveAll(dir) defer os.RemoveAll(dir) + schema := NewSchema("test", []Field{ + {Name: "value", Type: FieldTypeInt64, Indexed: false, Comment: "值"}, + }) + // 1. 创建引擎并插入数据 - engine, err := OpenEngine(&EngineOptions{ + table, err := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 10 * 1024 * 1024, // 10 MB,不会触发 Flush + Name: schema.Name, + Fields: schema.Fields, }) if err != nil { t.Fatal(err) } for i := 1; i <= 50; i++ { - data := map[string]interface{}{ + data := map[string]any{ "value": i, } - engine.Insert(data) + table.Insert(data) } t.Log("Inserted 50 rows") // 2. 关闭引擎 (模拟崩溃前) - engine.Close() + table.Close() // 3. 重新打开引擎 (恢复) - engine2, err := OpenEngine(&EngineOptions{ + table2, err := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 10 * 1024 * 1024, + Name: schema.Name, + Fields: schema.Fields, }) if err != nil { t.Fatal(err) } - defer engine2.Close() + defer table2.Close() // 4. 验证数据 for i := int64(1); i <= 50; i++ { - row, err := engine2.Get(i) + row, err := table2.Get(i) if err != nil { t.Errorf("Failed to get key %d after recover: %v", i, err) } @@ -114,7 +130,7 @@ func TestEngineRecover(t *testing.T) { } } - stats := engine2.Stats() + stats := table2.Stats() if stats.TotalRows != 50 { t.Errorf("Expected 50 rows after recover, got %d", stats.TotalRows) } @@ -122,32 +138,38 @@ func TestEngineRecover(t *testing.T) { t.Log("Recover test passed!") } -func TestEngineFlush(t *testing.T) { +func TestTableFlush(t *testing.T) { dir := "test_flush" os.RemoveAll(dir) defer os.RemoveAll(dir) - engine, err := OpenEngine(&EngineOptions{ + schema := NewSchema("test", []Field{ + {Name: "data", Type: FieldTypeString, Indexed: false, Comment: "数据"}, + }) + + table, err := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 1024, // 1 KB + Name: schema.Name, + Fields: schema.Fields, }) if err != nil { t.Fatal(err) } - defer engine.Close() + defer table.Close() // 插入足够多的数据触发 Flush for i := 1; i <= 200; i++ { data := map[string]any{ "data": fmt.Sprintf("value_%d", i), } - engine.Insert(data) + table.Insert(data) } // 等待 Flush time.Sleep(500 * time.Millisecond) - stats := engine.Stats() + stats := table.Stats() t.Logf("After flush: MemTable=%d, SST=%d, Total=%d", stats.MemTableCount, stats.SSTCount, stats.TotalRows) @@ -157,7 +179,7 @@ func TestEngineFlush(t *testing.T) { // 验证所有数据都能查到 for i := int64(1); i <= 200; i++ { - _, err := engine.Get(i) + _, err := table.Get(i) if err != nil { t.Errorf("Failed to get key %d after flush: %v", i, err) } @@ -166,48 +188,60 @@ func TestEngineFlush(t *testing.T) { t.Log("Flush test passed!") } -func BenchmarkEngineInsert(b *testing.B) { +func BenchmarkTableInsert(b *testing.B) { dir := "bench_insert" os.RemoveAll(dir) defer os.RemoveAll(dir) - engine, _ := OpenEngine(&EngineOptions{ + schema := NewSchema("test", []Field{ + {Name: "value", Type: FieldTypeInt64, Indexed: false, Comment: "值"}, + }) + + table, _ := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 100 * 1024 * 1024, // 100 MB + Name: schema.Name, + Fields: schema.Fields, }) - defer engine.Close() + defer table.Close() data := map[string]any{ "value": 123, } for b.Loop() { - engine.Insert(data) + table.Insert(data) } } -func BenchmarkEngineGet(b *testing.B) { +func BenchmarkTableGet(b *testing.B) { dir := "bench_get" os.RemoveAll(dir) defer os.RemoveAll(dir) - engine, _ := OpenEngine(&EngineOptions{ + schema := NewSchema("test", []Field{ + {Name: "value", Type: FieldTypeInt64, Indexed: false, Comment: "值"}, + }) + + table, _ := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 100 * 1024 * 1024, + Name: schema.Name, + Fields: schema.Fields, }) - defer engine.Close() + defer table.Close() // 预先插入数据 for i := 1; i <= 10000; i++ { data := map[string]any{ "value": i, } - engine.Insert(data) + table.Insert(data) } for i := 0; b.Loop(); i++ { key := int64(i%10000 + 1) - engine.Get(key) + table.Get(key) } } @@ -215,16 +249,24 @@ func BenchmarkEngineGet(b *testing.B) { func TestHighConcurrencyWrite(t *testing.T) { tmpDir := t.TempDir() - opts := &EngineOptions{ + // Note: This test uses []byte payload - we create a minimal schema + // Schema validation accepts []byte as it gets JSON-marshaled + schema := NewSchema("test", []Field{ + {Name: "worker_id", Type: FieldTypeInt64, Indexed: false, Comment: "Worker ID"}, + }) + + opts := &TableOptions{ Dir: tmpDir, MemTableSize: 64 * 1024 * 1024, // 64MB + Name: schema.Name, + Fields: schema.Fields, } - engine, err := OpenEngine(opts) + table, err := OpenTable(opts) if err != nil { t.Fatal(err) } - defer engine.Close() + defer table.Close() // 测试配置 const ( @@ -262,7 +304,7 @@ func TestHighConcurrencyWrite(t *testing.T) { "timestamp": time.Now().Unix(), } - err := engine.Insert(data) + err := table.Insert(data) if err != nil { totalErrors.Add(1) t.Logf("Worker %d, Row %d: Insert failed: %v", workerID, j, err) @@ -308,8 +350,8 @@ func TestHighConcurrencyWrite(t *testing.T) { time.Sleep(2 * time.Second) // 验证数据完整性 - stats := engine.Stats() - t.Logf("\nEngine 状态:") + stats := table.Stats() + t.Logf("\nTable 状态:") t.Logf(" 总行数: %d", stats.TotalRows) t.Logf(" SST 文件数: %d", stats.SSTCount) t.Logf(" MemTable 行数: %d", stats.MemTableCount) @@ -323,16 +365,23 @@ func TestHighConcurrencyWrite(t *testing.T) { func TestConcurrentReadWrite(t *testing.T) { tmpDir := t.TempDir() - opts := &EngineOptions{ + // Note: This test uses []byte data - we create a minimal schema + schema := NewSchema("test", []Field{ + {Name: "writer_id", Type: FieldTypeInt64, Indexed: false, Comment: "Writer ID"}, + }) + + opts := &TableOptions{ Dir: tmpDir, MemTableSize: 32 * 1024 * 1024, // 32MB + Name: schema.Name, + Fields: schema.Fields, } - engine, err := OpenEngine(opts) + table, err := OpenTable(opts) if err != nil { t.Fatal(err) } - defer engine.Close() + defer table.Close() const ( numWriters = 20 @@ -369,7 +418,7 @@ func TestConcurrentReadWrite(t *testing.T) { "timestamp": time.Now().UnixNano(), } - err := engine.Insert(payload) + err := table.Insert(payload) if err == nil { writeCount.Add(1) } @@ -393,7 +442,7 @@ func TestConcurrentReadWrite(t *testing.T) { default: // 随机读取 seq := int64(readerID*100 + 1) - _, err := engine.Get(seq) + _, err := table.Get(seq) if err == nil { readCount.Add(1) } else { @@ -422,8 +471,8 @@ func TestConcurrentReadWrite(t *testing.T) { t.Logf("读取次数: %d (%.2f 次/秒)", reads, float64(reads)/duration.Seconds()) t.Logf("读取失败: %d", errors) - stats := engine.Stats() - t.Logf("\nEngine 状态:") + stats := table.Stats() + t.Logf("\nTable 状态:") t.Logf(" 总行数: %d", stats.TotalRows) t.Logf(" SST 文件数: %d", stats.SSTCount) } @@ -435,12 +484,19 @@ func TestPowerFailureRecovery(t *testing.T) { // 第一阶段:写入数据并模拟崩溃 t.Log("=== 阶段 1: 写入数据 ===") - opts := &EngineOptions{ + // Note: This test uses []byte data - we create a minimal schema + schema := NewSchema("test", []Field{ + {Name: "batch", Type: FieldTypeInt64, Indexed: false, Comment: "Batch number"}, + }) + + opts := &TableOptions{ Dir: tmpDir, MemTableSize: 4 * 1024 * 1024, // 4MB + Name: schema.Name, + Fields: schema.Fields, } - engine, err := OpenEngine(opts) + table, err := OpenTable(opts) if err != nil { t.Fatal(err) } @@ -465,18 +521,18 @@ func TestPowerFailureRecovery(t *testing.T) { "timestamp": time.Now().Unix(), } - err := engine.Insert(payload) + err := table.Insert(payload) if err != nil { t.Fatalf("Insert failed: %v", err) } - seq := engine.seq.Load() + seq := table.seq.Load() insertedSeqs = append(insertedSeqs, seq) } // 每批后触发 Flush if batch%3 == 0 { - engine.switchMemTable() + table.switchMemTable() time.Sleep(100 * time.Millisecond) } @@ -487,27 +543,27 @@ func TestPowerFailureRecovery(t *testing.T) { t.Logf("总共插入: %d 行", totalInserted) // 获取崩溃前的状态 - statsBefore := engine.Stats() + statsBefore := table.Stats() t.Logf("崩溃前状态: 总行数=%d, SST文件=%d, MemTable行数=%d", statsBefore.TotalRows, statsBefore.SSTCount, statsBefore.MemTableCount) // 模拟崩溃:直接关闭(不等待 Flush 完成) t.Log("\n=== 模拟断电崩溃 ===") - engine.Close() + table.Close() // 第二阶段:恢复并验证数据 t.Log("\n=== 阶段 2: 恢复数据 ===") - engineRecovered, err := OpenEngine(opts) + tableRecovered, err := OpenTable(opts) if err != nil { t.Fatalf("恢复失败: %v", err) } - defer engineRecovered.Close() + defer tableRecovered.Close() // 等待恢复完成 time.Sleep(500 * time.Millisecond) - statsAfter := engineRecovered.Stats() + statsAfter := tableRecovered.Stats() t.Logf("恢复后状态: 总行数=%d, SST文件=%d, MemTable行数=%d", statsAfter.TotalRows, statsAfter.SSTCount, statsAfter.MemTableCount) @@ -519,7 +575,7 @@ func TestPowerFailureRecovery(t *testing.T) { corrupted := 0 for i, seq := range insertedSeqs { - row, err := engineRecovered.Get(seq) + row, err := tableRecovered.Get(seq) if err != nil { missing++ if i < len(insertedSeqs)/2 { @@ -564,12 +620,19 @@ func TestPowerFailureRecovery(t *testing.T) { func TestCrashDuringCompaction(t *testing.T) { tmpDir := t.TempDir() - opts := &EngineOptions{ + // Note: This test uses []byte data - we create a minimal schema + schema := NewSchema("test", []Field{ + {Name: "index", Type: FieldTypeInt64, Indexed: false, Comment: "Index"}, + }) + + opts := &TableOptions{ Dir: tmpDir, MemTableSize: 1024, // 很小,快速触发 Flush + Name: schema.Name, + Fields: schema.Fields, } - engine, err := OpenEngine(opts) + table, err := OpenTable(opts) if err != nil { t.Fatal(err) } @@ -588,7 +651,7 @@ func TestCrashDuringCompaction(t *testing.T) { "data": data, } - err := engine.Insert(payload) + err := table.Insert(payload) if err != nil { t.Fatal(err) } @@ -601,7 +664,7 @@ func TestCrashDuringCompaction(t *testing.T) { // 等待一些 Flush 完成 time.Sleep(500 * time.Millisecond) - version := engine.versionSet.GetCurrent() + version := table.versionSet.GetCurrent() l0Count := version.GetLevelFileCount(0) t.Logf("L0 文件数: %d", l0Count) @@ -609,7 +672,7 @@ func TestCrashDuringCompaction(t *testing.T) { if l0Count >= 4 { t.Log("触发 Compaction...") go func() { - engine.compactionManager.TriggerCompaction() + table.compactionManager.TriggerCompaction() }() // 等待 Compaction 开始 @@ -619,18 +682,18 @@ func TestCrashDuringCompaction(t *testing.T) { } // 直接关闭(模拟崩溃) - engine.Close() + table.Close() // 恢复 t.Log("\n=== 恢复数据库 ===") - engineRecovered, err := OpenEngine(opts) + tableRecovered, err := OpenTable(opts) if err != nil { t.Fatalf("恢复失败: %v", err) } - defer engineRecovered.Close() + defer tableRecovered.Close() // 验证数据完整性 - stats := engineRecovered.Stats() + stats := tableRecovered.Stats() t.Logf("恢复后: 总行数=%d, SST文件=%d", stats.TotalRows, stats.SSTCount) // 随机验证一些数据 @@ -638,7 +701,7 @@ func TestCrashDuringCompaction(t *testing.T) { verified := 0 for i := 1; i <= 100; i++ { seq := int64(i) - _, err := engineRecovered.Get(seq) + _, err := tableRecovered.Get(seq) if err == nil { verified++ } @@ -657,16 +720,23 @@ func TestCrashDuringCompaction(t *testing.T) { func TestLargeDataIntegrity(t *testing.T) { tmpDir := t.TempDir() - opts := &EngineOptions{ + // Note: This test uses []byte data - we create a minimal schema + schema := NewSchema("test", []Field{ + {Name: "size", Type: FieldTypeInt64, Indexed: false, Comment: "Size"}, + }) + + opts := &TableOptions{ Dir: tmpDir, MemTableSize: 64 * 1024 * 1024, // 64MB + Name: schema.Name, + Fields: schema.Fields, } - engine, err := OpenEngine(opts) + table, err := OpenTable(opts) if err != nil { t.Fatal(err) } - defer engine.Close() + defer table.Close() // 测试不同大小的数据 testSizes := []int{ @@ -693,12 +763,12 @@ func TestLargeDataIntegrity(t *testing.T) { "data": data, } - err := engine.Insert(payload) + err := table.Insert(payload) if err != nil { t.Fatalf("插入失败 (size=%d, index=%d): %v", size, i, err) } - seq := engine.seq.Load() + seq := table.seq.Load() insertedSeqs = append(insertedSeqs, seq) t.Logf("插入: Seq=%d, Size=%d KB", seq, size/1024) @@ -716,7 +786,7 @@ func TestLargeDataIntegrity(t *testing.T) { successCount := 0 for i, seq := range insertedSeqs { - row, err := engine.Get(seq) + row, err := table.Get(seq) if err != nil { t.Errorf("读取失败 (Seq=%d): %v", seq, err) continue @@ -743,7 +813,7 @@ func TestLargeDataIntegrity(t *testing.T) { successRate := float64(successCount) / float64(totalInserted) * 100 - stats := engine.Stats() + stats := table.Stats() t.Logf("\n=== 测试结果 ===") t.Logf("插入总数: %d", totalInserted) t.Logf("成功读取: %d (%.2f%%)", successCount, successRate) @@ -761,16 +831,23 @@ func TestLargeDataIntegrity(t *testing.T) { func BenchmarkConcurrentWrites(b *testing.B) { tmpDir := b.TempDir() - opts := &EngineOptions{ + // Note: This benchmark uses []byte data - we create a minimal schema + schema := NewSchema("test", []Field{ + {Name: "timestamp", Type: FieldTypeInt64, Indexed: false, Comment: "Timestamp"}, + }) + + opts := &TableOptions{ Dir: tmpDir, MemTableSize: 64 * 1024 * 1024, + Name: schema.Name, + Fields: schema.Fields, } - engine, err := OpenEngine(opts) + table, err := OpenTable(opts) if err != nil { b.Fatal(err) } - defer engine.Close() + defer table.Close() const ( numWorkers = 10 @@ -788,7 +865,7 @@ func BenchmarkConcurrentWrites(b *testing.B) { "timestamp": time.Now().UnixNano(), } - err := engine.Insert(payload) + err := table.Insert(payload) if err != nil { b.Error(err) } @@ -797,26 +874,34 @@ func BenchmarkConcurrentWrites(b *testing.B) { b.StopTimer() - stats := engine.Stats() + stats := table.Stats() b.Logf("总行数: %d, SST 文件数: %d", stats.TotalRows, stats.SSTCount) } -// TestEngineWithCompaction 测试 Engine 的 Compaction 功能 -func TestEngineWithCompaction(t *testing.T) { +// TestTableWithCompaction 测试 Table 的 Compaction 功能 +func TestTableWithCompaction(t *testing.T) { // 创建临时目录 tmpDir := t.TempDir() - // 打开 Engine - opts := &EngineOptions{ + schema := NewSchema("test", []Field{ + {Name: "batch", Type: FieldTypeInt64, Indexed: false, Comment: "批次"}, + {Name: "index", Type: FieldTypeInt64, Indexed: false, Comment: "索引"}, + {Name: "value", Type: FieldTypeString, Indexed: false, Comment: "值"}, + }) + + // 打开 Table + opts := &TableOptions{ Dir: tmpDir, MemTableSize: 1024, // 小的 MemTable 以便快速触发 Flush + Name: schema.Name, + Fields: schema.Fields, } - engine, err := OpenEngine(opts) + table, err := OpenTable(opts) if err != nil { t.Fatal(err) } - defer engine.Close() + defer table.Close() // 插入大量数据,触发多次 Flush const numBatches = 10 @@ -830,14 +915,14 @@ func TestEngineWithCompaction(t *testing.T) { "value": fmt.Sprintf("data-%d-%d", batch, i), } - err := engine.Insert(data) + err := table.Insert(data) if err != nil { t.Fatalf("Insert failed: %v", err) } } // 强制 Flush - err = engine.switchMemTable() + err = table.switchMemTable() if err != nil { t.Fatalf("Switch MemTable failed: %v", err) } @@ -847,12 +932,12 @@ func TestEngineWithCompaction(t *testing.T) { } // 等待所有 Immutable Flush 完成 - for engine.memtableManager.GetImmutableCount() > 0 { + for table.memtableManager.GetImmutableCount() > 0 { time.Sleep(100 * time.Millisecond) } // 检查 Version 状态 - version := engine.versionSet.GetCurrent() + version := table.versionSet.GetCurrent() l0Count := version.GetLevelFileCount(0) t.Logf("L0 files: %d", l0Count) @@ -861,7 +946,7 @@ func TestEngineWithCompaction(t *testing.T) { } // 获取 Level 统计信息 - levelStats := engine.compactionManager.GetLevelStats() + levelStats := table.compactionManager.GetLevelStats() for _, stat := range levelStats { level := stat["level"].(int) fileCount := stat["file_count"].(int) @@ -876,14 +961,14 @@ func TestEngineWithCompaction(t *testing.T) { // 手动触发 Compaction if l0Count >= 4 { t.Log("Triggering manual compaction...") - err = engine.compactionManager.TriggerCompaction() + err = table.compactionManager.TriggerCompaction() if err != nil { t.Logf("Compaction: %v", err) } else { t.Log("Compaction completed") // 检查 Compaction 后的状态 - version = engine.versionSet.GetCurrent() + version = table.versionSet.GetCurrent() newL0Count := version.GetLevelFileCount(0) l1Count := version.GetLevelFileCount(1) @@ -900,40 +985,48 @@ func TestEngineWithCompaction(t *testing.T) { } // 验证数据完整性 - stats := engine.Stats() - t.Logf("Engine stats: %d rows, %d SST files", stats.TotalRows, stats.SSTCount) + stats := table.Stats() + t.Logf("Table stats: %d rows, %d SST files", stats.TotalRows, stats.SSTCount) // 读取一些数据验证 for batch := range 3 { for i := range 10 { seq := int64(batch*rowsPerBatch + i + 1) - row, err := engine.Get(seq) + row, err := table.Get(seq) if err != nil { t.Errorf("Get(%d) failed: %v", seq, err) continue } - if row.Data["batch"].(float64) != float64(batch) { + if row.Data["batch"].(int64) != int64(batch) { t.Errorf("Expected batch %d, got %v", batch, row.Data["batch"]) } } } } -// TestEngineCompactionMerge 测试 Compaction 的合并功能 -func TestEngineCompactionMerge(t *testing.T) { +// TestTableCompactionMerge 测试 Compaction 的合并功能 +func TestTableCompactionMerge(t *testing.T) { tmpDir := t.TempDir() - opts := &EngineOptions{ + schema := NewSchema("test", []Field{ + {Name: "batch", Type: FieldTypeInt64, Indexed: false, Comment: "批次"}, + {Name: "index", Type: FieldTypeInt64, Indexed: false, Comment: "索引"}, + {Name: "value", Type: FieldTypeString, Indexed: false, Comment: "值"}, + }) + + opts := &TableOptions{ Dir: tmpDir, MemTableSize: 512, // 很小的 MemTable + Name: schema.Name, + Fields: schema.Fields, } - engine, err := OpenEngine(opts) + table, err := OpenTable(opts) if err != nil { t.Fatal(err) } - defer engine.Close() + defer table.Close() // 插入数据(Append-Only 模式) const numBatches = 5 @@ -948,7 +1041,7 @@ func TestEngineCompactionMerge(t *testing.T) { "value": fmt.Sprintf("v%d-%d", batch, i), } - err := engine.Insert(data) + err := table.Insert(data) if err != nil { t.Fatal(err) } @@ -956,7 +1049,7 @@ func TestEngineCompactionMerge(t *testing.T) { } // 每批后 Flush - err = engine.switchMemTable() + err = table.switchMemTable() if err != nil { t.Fatal(err) } @@ -965,22 +1058,22 @@ func TestEngineCompactionMerge(t *testing.T) { } // 等待所有 Flush 完成 - for engine.memtableManager.GetImmutableCount() > 0 { + for table.memtableManager.GetImmutableCount() > 0 { time.Sleep(100 * time.Millisecond) } // 记录 Compaction 前的文件数 - version := engine.versionSet.GetCurrent() + version := table.versionSet.GetCurrent() beforeL0 := version.GetLevelFileCount(0) t.Logf("Before compaction: L0 has %d files", beforeL0) // 触发 Compaction if beforeL0 >= 4 { - err = engine.compactionManager.TriggerCompaction() + err = table.compactionManager.TriggerCompaction() if err != nil { t.Logf("Compaction: %v", err) } else { - version = engine.versionSet.GetCurrent() + version = table.versionSet.GetCurrent() afterL0 := version.GetLevelFileCount(0) afterL1 := version.GetLevelFileCount(1) t.Logf("After compaction: L0 has %d files, L1 has %d files", afterL0, afterL1) @@ -991,13 +1084,13 @@ func TestEngineCompactionMerge(t *testing.T) { for batch := range 2 { for i := range 5 { seq := int64(batch*rowsPerBatch + i + 1) - row, err := engine.Get(seq) + row, err := table.Get(seq) if err != nil { t.Errorf("Get(%d) failed: %v", seq, err) continue } - actualBatch := int(row.Data["batch"].(float64)) + actualBatch := int(row.Data["batch"].(int64)) if actualBatch != batch { t.Errorf("Seq %d: expected batch %d, got %d", seq, batch, actualBatch) } @@ -1005,7 +1098,7 @@ func TestEngineCompactionMerge(t *testing.T) { } // 验证总行数 - stats := engine.Stats() + stats := table.Stats() if stats.TotalRows != int64(totalRows) { t.Errorf("Expected %d total rows, got %d", totalRows, stats.TotalRows) } @@ -1013,24 +1106,31 @@ func TestEngineCompactionMerge(t *testing.T) { t.Logf("Data integrity verified: %d rows", totalRows) } -// TestEngineBackgroundCompaction 测试后台自动 Compaction -func TestEngineBackgroundCompaction(t *testing.T) { +// TestTableBackgroundCompaction 测试后台自动 Compaction +func TestTableBackgroundCompaction(t *testing.T) { if testing.Short() { t.Skip("Skipping background compaction test in short mode") } tmpDir := t.TempDir() - opts := &EngineOptions{ + schema := NewSchema("test", []Field{ + {Name: "batch", Type: FieldTypeInt64, Indexed: false, Comment: "批次"}, + {Name: "index", Type: FieldTypeInt64, Indexed: false, Comment: "索引"}, + }) + + opts := &TableOptions{ Dir: tmpDir, MemTableSize: 512, + Name: schema.Name, + Fields: schema.Fields, } - engine, err := OpenEngine(opts) + table, err := OpenTable(opts) if err != nil { t.Fatal(err) } - defer engine.Close() + defer table.Close() // 插入数据触发多次 Flush const numBatches = 8 @@ -1043,13 +1143,13 @@ func TestEngineBackgroundCompaction(t *testing.T) { "index": i, } - err := engine.Insert(data) + err := table.Insert(data) if err != nil { t.Fatal(err) } } - err = engine.switchMemTable() + err = table.switchMemTable() if err != nil { t.Fatal(err) } @@ -1058,12 +1158,12 @@ func TestEngineBackgroundCompaction(t *testing.T) { } // 等待 Flush 完成 - for engine.memtableManager.GetImmutableCount() > 0 { + for table.memtableManager.GetImmutableCount() > 0 { time.Sleep(100 * time.Millisecond) } // 记录初始状态 - version := engine.versionSet.GetCurrent() + version := table.versionSet.GetCurrent() initialL0 := version.GetLevelFileCount(0) t.Logf("Initial L0 files: %d", initialL0) @@ -1076,7 +1176,7 @@ func TestEngineBackgroundCompaction(t *testing.T) { time.Sleep(checkInterval) waited += checkInterval - version = engine.versionSet.GetCurrent() + version = table.versionSet.GetCurrent() currentL0 := version.GetLevelFileCount(0) currentL1 := version.GetLevelFileCount(1) @@ -1087,7 +1187,7 @@ func TestEngineBackgroundCompaction(t *testing.T) { t.Logf("Background compaction detected!") // 获取 Compaction 统计 - stats := engine.compactionManager.GetStats() + stats := table.compactionManager.GetStats() t.Logf("Compaction stats: %v", stats) return @@ -1097,20 +1197,27 @@ func TestEngineBackgroundCompaction(t *testing.T) { t.Log("No background compaction detected within timeout (this is OK if L0 < 4 files)") } -// BenchmarkEngineWithCompaction 性能测试 -func BenchmarkEngineWithCompaction(b *testing.B) { +// BenchmarkTableWithCompaction 性能测试 +func BenchmarkTableWithCompaction(b *testing.B) { tmpDir := b.TempDir() - opts := &EngineOptions{ + schema := NewSchema("test", []Field{ + {Name: "index", Type: FieldTypeInt64, Indexed: false, Comment: "索引"}, + {Name: "value", Type: FieldTypeString, Indexed: false, Comment: "值"}, + }) + + opts := &TableOptions{ Dir: tmpDir, MemTableSize: 64 * 1024, // 64KB + Name: schema.Name, + Fields: schema.Fields, } - engine, err := OpenEngine(opts) + table, err := OpenTable(opts) if err != nil { b.Fatal(err) } - defer engine.Close() + defer table.Close() for i := 0; b.Loop(); i++ { data := map[string]any{ @@ -1118,7 +1225,7 @@ func BenchmarkEngineWithCompaction(b *testing.B) { "value": fmt.Sprintf("benchmark-data-%d", i), } - err := engine.Insert(data) + err := table.Insert(data) if err != nil { b.Fatal(err) } @@ -1127,20 +1234,20 @@ func BenchmarkEngineWithCompaction(b *testing.B) { b.StopTimer() // 等待所有 Flush 完成 - for engine.memtableManager.GetImmutableCount() > 0 { + for table.memtableManager.GetImmutableCount() > 0 { time.Sleep(10 * time.Millisecond) } // 报告统计信息 - version := engine.versionSet.GetCurrent() + version := table.versionSet.GetCurrent() b.Logf("Final state: L0=%d files, L1=%d files, Total=%d files", version.GetLevelFileCount(0), version.GetLevelFileCount(1), version.GetFileCount()) } -// TestEngineSchemaRecover 测试 Schema 恢复 -func TestEngineSchemaRecover(t *testing.T) { +// TestTableSchemaRecover 测试 Schema 恢复 +func TestTableSchemaRecover(t *testing.T) { dir := "test_schema_recover" os.RemoveAll(dir) defer os.RemoveAll(dir) @@ -1153,10 +1260,10 @@ func TestEngineSchemaRecover(t *testing.T) { }) // 1. 创建引擎并插入数据(带 Schema) - engine, err := OpenEngine(&EngineOptions{ + table, err := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 10 * 1024 * 1024, // 10 MB,不会触发 Flush - Schema: s, + Name: s.Name, Fields: s.Fields, }) if err != nil { t.Fatal(err) @@ -1169,7 +1276,7 @@ func TestEngineSchemaRecover(t *testing.T) { "age": 20 + i%50, "email": fmt.Sprintf("user%d@example.com", i), } - err := engine.Insert(data) + err := table.Insert(data) if err != nil { t.Fatalf("Failed to insert valid data: %v", err) } @@ -1178,20 +1285,20 @@ func TestEngineSchemaRecover(t *testing.T) { t.Log("Inserted 50 rows with schema") // 2. 关闭引擎 - engine.Close() + table.Close() // 3. 重新打开引擎(带 Schema,应该成功恢复) - engine2, err := OpenEngine(&EngineOptions{ + table2, err := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 10 * 1024 * 1024, - Schema: s, + Name: s.Name, Fields: s.Fields, }) if err != nil { t.Fatalf("Failed to recover with schema: %v", err) } // 验证数据 - row, err := engine2.Get(1) + row, err := table2.Get(1) if err != nil { t.Fatalf("Failed to get row after recovery: %v", err) } @@ -1207,21 +1314,28 @@ func TestEngineSchemaRecover(t *testing.T) { t.Error("Missing field 'age'") } - engine2.Close() + table2.Close() t.Log("Schema recovery test passed!") } -// TestEngineSchemaRecoverInvalid 测试当 WAL 中有不符合 Schema 的数据时恢复失败 -func TestEngineSchemaRecoverInvalid(t *testing.T) { +// TestTableSchemaRecoverInvalid 测试当 WAL 中有不符合 Schema 的数据时恢复失败 +func TestTableSchemaRecoverInvalid(t *testing.T) { dir := "test_schema_recover_invalid" os.RemoveAll(dir) defer os.RemoveAll(dir) + schema := NewSchema("test", []Field{ + {Name: "name", Type: FieldTypeString, Indexed: false, Comment: "用户名"}, + {Name: "age", Type: FieldTypeString, Indexed: false, Comment: "年龄字符串"}, + }) + // 1. 先不带 Schema 插入一些数据 - engine, err := OpenEngine(&EngineOptions{ + table, err := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 10 * 1024 * 1024, // 大容量,确保不会触发 Flush + Name: schema.Name, + Fields: schema.Fields, }) if err != nil { t.Fatal(err) @@ -1233,25 +1347,25 @@ func TestEngineSchemaRecoverInvalid(t *testing.T) { "name": fmt.Sprintf("user_%d", i), "age": "invalid_age", // 这是字符串,但后续 Schema 要求 int64 } - err := engine.Insert(data) + err := table.Insert(data) if err != nil { t.Fatalf("Failed to insert data: %v", err) } } // 2. 停止后台任务但不 Flush(模拟崩溃) - if engine.compactionManager != nil { - engine.compactionManager.Stop() + if table.compactionManager != nil { + table.compactionManager.Stop() } // 直接关闭资源,但不 Flush MemTable - if engine.walManager != nil { - engine.walManager.Close() + if table.walManager != nil { + table.walManager.Close() } - if engine.versionSet != nil { - engine.versionSet.Close() + if table.versionSet != nil { + table.versionSet.Close() } - if engine.sstManager != nil { - engine.sstManager.Close() + if table.sstManager != nil { + table.sstManager.Close() } // 3. 创建 Schema,age 字段要求 int64 @@ -1261,13 +1375,13 @@ func TestEngineSchemaRecoverInvalid(t *testing.T) { }) // 4. 尝试用 Schema 打开引擎,应该失败 - engine2, err := OpenEngine(&EngineOptions{ + table2, err := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 10 * 1024 * 1024, - Schema: s, + Name: s.Name, Fields: s.Fields, }) if err == nil { - engine2.Close() + table2.Close() t.Fatal("Expected recovery to fail with invalid schema, but it succeeded") } @@ -1279,8 +1393,8 @@ func TestEngineSchemaRecoverInvalid(t *testing.T) { t.Log("Invalid schema recovery test passed!") } -// TestEngineAutoRecoverSchema 测试自动从磁盘恢复 Schema -func TestEngineAutoRecoverSchema(t *testing.T) { +// TestTableAutoRecoverSchema 测试自动从磁盘恢复 Schema +func TestTableAutoRecoverSchema(t *testing.T) { dir := "test_auto_recover_schema" os.RemoveAll(dir) defer os.RemoveAll(dir) @@ -1292,10 +1406,10 @@ func TestEngineAutoRecoverSchema(t *testing.T) { }) // 1. 创建引擎并提供 Schema(会保存到磁盘) - engine1, err := OpenEngine(&EngineOptions{ + table1, err := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 10 * 1024 * 1024, - Schema: s, + Name: s.Name, Fields: s.Fields, }) if err != nil { t.Fatal(err) @@ -1307,16 +1421,16 @@ func TestEngineAutoRecoverSchema(t *testing.T) { "name": fmt.Sprintf("user_%d", i), "age": 20 + i, } - err := engine1.Insert(data) + err := table1.Insert(data) if err != nil { t.Fatalf("Failed to insert: %v", err) } } - engine1.Close() + table1.Close() // 2. 重新打开引擎,不提供 Schema(应该自动从磁盘恢复) - engine2, err := OpenEngine(&EngineOptions{ + table2, err := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 10 * 1024 * 1024, // 不设置 Schema @@ -1326,7 +1440,7 @@ func TestEngineAutoRecoverSchema(t *testing.T) { } // 验证 Schema 已恢复 - recoveredSchema := engine2.GetSchema() + recoveredSchema := table2.GetSchema() if recoveredSchema == nil { t.Fatal("Expected schema to be recovered, but got nil") } @@ -1340,7 +1454,7 @@ func TestEngineAutoRecoverSchema(t *testing.T) { } // 验证数据 - row, err := engine2.Get(1) + row, err := table2.Get(1) if err != nil { t.Fatalf("Failed to get row: %v", err) } @@ -1349,7 +1463,7 @@ func TestEngineAutoRecoverSchema(t *testing.T) { } // 尝试插入新数据(应该符合恢复的 Schema) - err = engine2.Insert(map[string]any{ + err = table2.Insert(map[string]any{ "name": "new_user", "age": 30, }) @@ -1358,7 +1472,7 @@ func TestEngineAutoRecoverSchema(t *testing.T) { } // 尝试插入不符合 Schema 的数据(应该失败) - err = engine2.Insert(map[string]any{ + err = table2.Insert(map[string]any{ "name": "bad_user", "age": "invalid", // 类型错误 }) @@ -1366,13 +1480,13 @@ func TestEngineAutoRecoverSchema(t *testing.T) { t.Fatal("Expected insert to fail with invalid type, but it succeeded") } - engine2.Close() + table2.Close() t.Log("Auto recover schema test passed!") } -// TestEngineSchemaTamperDetection 测试篡改检测 -func TestEngineSchemaTamperDetection(t *testing.T) { +// TestTableSchemaTamperDetection 测试篡改检测 +func TestTableSchemaTamperDetection(t *testing.T) { dir := "test_schema_tamper" os.RemoveAll(dir) defer os.RemoveAll(dir) @@ -1384,15 +1498,15 @@ func TestEngineSchemaTamperDetection(t *testing.T) { }) // 1. 创建引擎并保存 Schema - engine1, err := OpenEngine(&EngineOptions{ + table1, err := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 10 * 1024 * 1024, - Schema: s, + Name: s.Name, Fields: s.Fields, }) if err != nil { t.Fatal(err) } - engine1.Close() + table1.Close() // 2. 篡改 schema.json(修改字段但不更新 checksum) schemaPath := fmt.Sprintf("%s/schema.json", dir) @@ -1410,12 +1524,12 @@ func TestEngineSchemaTamperDetection(t *testing.T) { } // 3. 尝试打开引擎,应该检测到篡改 - engine2, err := OpenEngine(&EngineOptions{ + table2, err := OpenTable(&TableOptions{ Dir: dir, MemTableSize: 10 * 1024 * 1024, }) if err == nil { - engine2.Close() + table2.Close() t.Fatal("Expected to detect schema tampering, but open succeeded") } @@ -1428,3 +1542,305 @@ func TestEngineSchemaTamperDetection(t *testing.T) { t.Logf("Detected tampering as expected: %v", err) t.Log("Schema tamper detection test passed!") } + +func TestTableClean(t *testing.T) { + dir := "./test_table_clean_data" + defer os.RemoveAll(dir) + + // 1. 创建数据库和表 + db, err := Open(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + schema := NewSchema("users", []Field{ + {Name: "id", Type: FieldTypeInt64, Indexed: true, Comment: "ID"}, + {Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"}, + }) + + table, err := db.CreateTable("users", schema) + if err != nil { + t.Fatal(err) + } + + // 2. 插入数据 + for i := range 100 { + err := table.Insert(map[string]any{ + "id": int64(i), + "name": "user" + string(rune(i)), + }) + if err != nil { + t.Fatal(err) + } + } + + // 3. 验证数据存在 + stats := table.Stats() + t.Logf("Before Clean: %d rows", stats.TotalRows) + + if stats.TotalRows == 0 { + t.Error("Expected data in table") + } + + // 4. 清除数据 + err = table.Clean() + if err != nil { + t.Fatal(err) + } + + // 5. 验证数据已清除 + stats = table.Stats() + t.Logf("After Clean: %d rows", stats.TotalRows) + + if stats.TotalRows != 0 { + t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows) + } + + // 6. 验证表仍然可用 + err = table.Insert(map[string]any{ + "id": int64(100), + "name": "new_user", + }) + if err != nil { + t.Fatal(err) + } + + stats = table.Stats() + if stats.TotalRows != 1 { + t.Errorf("Expected 1 row after insert, got %d", stats.TotalRows) + } +} + +func TestTableDestroy(t *testing.T) { + dir := "./test_table_destroy_data" + defer os.RemoveAll(dir) + + // 1. 创建数据库和表 + db, err := Open(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + schema := NewSchema("test", []Field{ + {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, + }) + + table, err := db.CreateTable("test", schema) + if err != nil { + t.Fatal(err) + } + + // 2. 插入数据 + for i := range 50 { + table.Insert(map[string]any{"id": int64(i)}) + } + + // 3. 验证数据存在 + stats := table.Stats() + t.Logf("Before Destroy: %d rows", stats.TotalRows) + + if stats.TotalRows == 0 { + t.Error("Expected data in table") + } + + // 4. 获取表目录路径 + tableDir := table.dir + + // 5. 销毁表 + err = table.Destroy() + if err != nil { + t.Fatal(err) + } + + // 6. 验证表目录已删除 + if _, err := os.Stat(tableDir); !os.IsNotExist(err) { + t.Error("Table directory should be deleted") + } + + // 7. 注意:Table.Destroy() 只删除文件,不从 Database 中删除 + // 表仍然在 Database 的元数据中,但文件已被删除 + tables := db.ListTables() + found := slices.Contains(tables, "test") + if !found { + t.Error("Table should still be in database metadata (use Database.DestroyTable to remove from metadata)") + } +} + +func TestTableCleanWithIndex(t *testing.T) { + dir := "./test_table_clean_index_data" + defer os.RemoveAll(dir) + + // 1. 创建数据库和表 + db, err := Open(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + schema := NewSchema("users", []Field{ + {Name: "id", Type: FieldTypeInt64, Indexed: true, Comment: "ID"}, + {Name: "email", Type: FieldTypeString, Indexed: true, Comment: "Email"}, + {Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"}, + }) + + table, err := db.CreateTable("users", schema) + if err != nil { + t.Fatal(err) + } + + // 2. 创建索引 + err = table.CreateIndex("id") + if err != nil { + t.Fatal(err) + } + + err = table.CreateIndex("email") + if err != nil { + t.Fatal(err) + } + + // 3. 插入数据 + for i := range 50 { + table.Insert(map[string]any{ + "id": int64(i), + "email": "user" + string(rune(i)) + "@example.com", + "name": "User " + string(rune(i)), + }) + } + + // 4. 验证索引存在 + indexes := table.ListIndexes() + if len(indexes) != 2 { + t.Errorf("Expected 2 indexes, got %d", len(indexes)) + } + + // 5. 清除数据 + err = table.Clean() + if err != nil { + t.Fatal(err) + } + + // 6. 验证数据已清除 + stats := table.Stats() + if stats.TotalRows != 0 { + t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows) + } + + // 7. 验证索引已被清除(Clean 会删除索引数据) + indexes = table.ListIndexes() + if len(indexes) != 0 { + t.Logf("Note: Indexes were cleared (expected behavior), got %d", len(indexes)) + } + + // 8. 重新创建索引 + table.CreateIndex("id") + table.CreateIndex("email") + + // 9. 验证可以继续插入数据 + err = table.Insert(map[string]any{ + "id": int64(100), + "email": "new@example.com", + "name": "New User", + }) + if err != nil { + t.Fatal(err) + } + + stats = table.Stats() + if stats.TotalRows != 1 { + t.Errorf("Expected 1 row, got %d", stats.TotalRows) + } +} + +func TestTableCleanAndQuery(t *testing.T) { + dir := "./test_table_clean_query_data" + defer os.RemoveAll(dir) + + // 1. 创建数据库和表 + db, err := Open(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + schema := NewSchema("test", []Field{ + {Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"}, + {Name: "status", Type: FieldTypeString, Indexed: false, Comment: "Status"}, + }) + + table, err := db.CreateTable("test", schema) + if err != nil { + t.Fatal(err) + } + + // 2. 插入数据 + for i := range 30 { + table.Insert(map[string]any{ + "id": int64(i), + "status": "active", + }) + } + + // 3. 查询数据 + rows, err := table.Query().Eq("status", "active").Rows() + if err != nil { + t.Fatal(err) + } + + count := 0 + for rows.Next() { + count++ + } + rows.Close() + + t.Logf("Before Clean: found %d rows", count) + if count != 30 { + t.Errorf("Expected 30 rows, got %d", count) + } + + // 4. 清除数据 + err = table.Clean() + if err != nil { + t.Fatal(err) + } + + // 5. 再次查询 + rows, err = table.Query().Eq("status", "active").Rows() + if err != nil { + t.Fatal(err) + } + + count = 0 + for rows.Next() { + count++ + } + rows.Close() + + t.Logf("After Clean: found %d rows", count) + if count != 0 { + t.Errorf("Expected 0 rows after clean, got %d", count) + } + + // 6. 插入新数据并查询 + table.Insert(map[string]any{ + "id": int64(100), + "status": "active", + }) + + rows, err = table.Query().Eq("status", "active").Rows() + if err != nil { + t.Fatal(err) + } + + count = 0 + for rows.Next() { + count++ + } + rows.Close() + + if count != 1 { + t.Errorf("Expected 1 row, got %d", count) + } +} diff --git a/webui/webui.go b/webui/webui.go index 28bba2a..0bf1dbe 100644 --- a/webui/webui.go +++ b/webui/webui.go @@ -11,7 +11,7 @@ import ( "strconv" "strings" - "code.tczkiot.com/srdb" + "code.tczkiot.com/wlw/srdb" ) //go:embed static @@ -94,7 +94,7 @@ func (ui *WebUI) handleListTables(w http.ResponseWriter, r *http.Request) { tables = append(tables, TableListItem{ Name: name, - CreatedAt: table.GetCreatedAt(), + CreatedAt: 0, // TODO: Table 不再有 createdAt 字段 Fields: fields, }) } @@ -193,8 +193,7 @@ func (ui *WebUI) handleTableManifest(w http.ResponseWriter, r *http.Request, tab return } - engine := table.GetEngine() - versionSet := engine.GetVersionSet() + versionSet := table.GetVersionSet() version := versionSet.GetCurrent() // 构建每层的信息 @@ -216,7 +215,7 @@ func (ui *WebUI) handleTableManifest(w http.ResponseWriter, r *http.Request, tab } // 获取 Compaction Manager 和 Picker - compactionMgr := engine.GetCompactionManager() + compactionMgr := table.GetCompactionManager() picker := compactionMgr.GetPicker() levels := make([]LevelInfo, 0)