文档:更新 DESIGN.md,使用英文注释和调整项目结构说明

This commit is contained in:
2025-10-09 01:33:22 +08:00
parent 6499464f1c
commit 8019f2d794
38 changed files with 4297 additions and 2750 deletions

336
CLAUDE.md
View File

@@ -1,12 +1,12 @@
# CLAUDE.md
本文件为 Claude Code (claude.ai/code) 提供在本仓库中工作的指导。
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## 项目概述
SRDB 是一个用 Go 编写的高性能 Append-Only 时序数据库引擎。它使用简化的 LSM-tree 架构,结合 WAL + MemTable + mmap B+Tree SST 文件针对高并发写入200K+ 写/秒和快速查询1-5ms进行了优化。
**模块**: `code.tczkiot.com/srdb`
**模块**: `code.tczkiot.com/wlw/srdb`
## 构建和测试
@@ -14,22 +14,59 @@ SRDB 是一个用 Go 编写的高性能 Append-Only 时序数据库引擎。它
# 运行所有测试
go test -v ./...
# 运行指定包的测试
go test -v ./engine
go test -v ./compaction
go test -v ./query
# 运行单个测试
go test -v -run TestSSTable
go test -v -run TestTable
# 运行指定的测试
go test -v ./engine -run TestEngineBasic
# 运行性能测试
go test -bench=. -benchmem
# 构建示例程序
go build ./examples/basic
go build ./examples/with_schema
# 运行带超时的测试(某些 compaction 测试需要较长时间)
go test -v -timeout 30s
# 构建 WebUI 工具
cd examples/webui
go build -o webui main.go
./webui serve --db ./data
```
## 架构
### 两层存储模型
### 文件结构(扁平化设计)
所有核心代码都在根目录下,采用扁平化结构:
```
srdb/
├── database.go # 多表数据库管理
├── table.go # 表管理(带 Schema
├── errors.go # 错误定义和处理(统一错误码系统)
├── wal.go # WAL 实现Write-Ahead Log
├── memtable.go # MemTablemap + sorted slice~130 行)
├── sstable.go # SSTable 文件(读写器、管理器、二进制编码)
├── btree.go # B+Tree 索引构建器、读取器4KB 节点)
├── version.go # 版本控制MANIFEST 管理)
├── compaction.go # Compaction 压缩合并
├── schema.go # Schema 定义与验证
├── index.go # 二级索引管理器
├── index_btree.go # 索引 B+Tree 实现
└── query.go # 查询构建器和表达式求值
```
**运行时数据目录**:
```
database_dir/
├── database.meta # 数据库元数据JSON
├── MANIFEST # 全局版本控制
└── table_name/ # 每表一个目录
├── schema.json # 表 Schema 定义
├── MANIFEST # 表级版本控制
├── 000001.wal # WAL 文件
├── 000001.sst # SST 文件B+Tree 索引 + 二进制数据)
└── idx_field.sst # 二级索引文件(可选)
```
### 核心架构:简化的两层模型
与传统的多层 LSM 树不同SRDB 使用简化的两层架构:
@@ -39,12 +76,12 @@ go build ./examples/with_schema
### 核心数据流
**写入路径**:
1. Schema 验证(如果定义了
2. 生成序列号 (`_seq`)
1. Schema 验证(强制要求,如果表有 Schema
2. 生成序列号 (`_seq`,原子递增的 int64
3. 追加写入 WAL顺序写
4. 插入到 Active MemTablemap + 有序 slice
5. 当 MemTable 超过阈值(默认 64MB切换到新的 Active MemTable 并异步 Immutable 刷新到 SST
6. 更新二级索引(如果已创建
5. 当 MemTable 超过阈值(默认 64MB切换到新的 Active MemTable 并异步刷新 Immutable 到 SST
6. 更新二级索引(如果字段标记为 Indexed
**读取路径**:
1. 检查 Active MemTableO(1) map 查找)
@@ -56,11 +93,11 @@ go build ./examples/with_schema
1. 如果是带 `=` 操作符的索引字段:使用二级索引 → 通过 seq 获取
2. 否则带过滤条件的全表扫描MemTable + SST
### 关键设计选择
### 关键设计决策
**MemTable: `map[int64][]byte + sorted []int64`**
- 为什么不用 SkipList实现更简单130 行Put 和 Get 都是 O(1) vs O(log N)
- 权衡:插入时需要重新排序 keys slice但实际上仍然更快
- 为什么不用 SkipList实现更简单~130 行Put 和 Get 都是 O(1) vs O(log N)
- 权衡:插入新 key 时需要重新排序 keys slice但实际上仍然更快
- Active MemTable + 多个 Immutable MemTables正在刷新中
**SST 格式: 4KB 节点的 B+Tree**
@@ -68,7 +105,13 @@ go build ./examples/with_schema
- 支持高效的 mmap 访问和零拷贝读取
- 内部节点keys + 子节点指针
- 叶子节点keys + 数据偏移量/大小
- 数据块:Snappy 压缩的 JSON 行
- 数据块:二进制编码(使用 Schema 时)或 JSON无 Schema 时)
**二进制编码格式**:
- Magic Number: `0x524F5731` ("ROW1")
- 格式:`[Magic: 4B][Seq: 8B][Time: 8B][FieldCount: 2B][FieldOffsetTable][FieldData]`
- 按字段分别编码,支持部分字段读取(`GetPartial`
- 无压缩(优先查询性能,保持 mmap 零拷贝)
**mmap 而非 read() 系统调用**
- 对 SST 文件的零拷贝访问
@@ -80,125 +123,70 @@ go build ./examples/with_schema
- 相同 seq 的新记录覆盖旧记录
- Compaction 合并文件并按 seq 去重保留最新的按时间戳
## 目录结构
## 常见开发模式
```
srdb/
├── database.go # 多表数据库管理
├── table.go # 带 schema 的表
├── engine/ # 核心存储引擎583 行)
│ └── engine.go
├── wal/ # 预写日志
│ ├── wal.go # WAL 实现208 行)
│ └── manager.go # 多 WAL 管理
├── memtable/ # 内存表
│ ├── memtable.go # MemTable130 行)
│ └── manager.go # Active + Immutable 管理
├── sst/ # SSTable 文件
│ ├── format.go # 文件格式定义
│ ├── writer.go # SST 写入器
│ ├── reader.go # mmap 读取器147 行)
│ ├── manager.go # SST 文件管理
│ └── encoding.go # Snappy 压缩
├── btree/ # B+Tree 索引
│ ├── node.go # 4KB 节点结构
│ ├── builder.go # B+Tree 构建器125 行)
│ └── reader.go # B+Tree 读取器
├── manifest/ # 版本控制
│ ├── version_set.go # 版本管理
│ ├── version_edit.go # 原子更新
│ ├── version.go # 文件元数据
│ ├── manifest_writer.go
│ └── manifest_reader.go
├── compaction/ # 后台压缩
│ ├── manager.go # Compaction 调度器
│ ├── compactor.go # 合并执行器
│ └── picker.go # 文件选择策略
├── index/ # 二级索引
│ ├── index.go # 字段级索引
│ └── manager.go # 索引生命周期
├── query/ # 查询系统
│ ├── builder.go # 流式查询 API
│ └── expr.go # 表达式求值
└── schema/ # Schema 验证
├── schema.go # 类型定义和验证
└── examples.go # Schema 示例
```
### Schema 系统(强制要求)
**运行时数据目录**例如 `./mydb/`:
```
database_dir/
├── database.meta # 数据库元数据JSON
├── MANIFEST # 全局版本控制
└── table_name/ # 每表目录
├── schema.json # 表 schema
├── MANIFEST # 表级版本控制
├── wal/ # WAL 文件(*.wal
├── sst/ # SST 文件(*.sst
└── index/ # 二级索引idx_*.sst
```
## 常见模式
### 使用 Engine
`Engine` 是核心存储层修改引擎行为时
- 所有写入都经过 `Insert()` WAL MemTable 异步刷新到 SST
- 读取经过 `Get(seq)` 检查 MemTable 检查 SST 文件
- `switchMemTable()` 创建新的 Active MemTable 并异步刷新旧的
- `flushImmutable()` MemTable 写入 SST 并更新 MANIFEST
- 后台 compaction 通过 `compactionManager` 运行
### Schema 和验证
Schema 是可选的但建议在生产环境使用
从最近的重构开始Schema **强制**不再支持无 Schema 模式
```go
schema := schema.NewSchema("users").
AddField("name", schema.FieldTypeString, false, "用户名").
AddField("age", schema.FieldTypeInt64, false, "用户年龄").
AddField("email", schema.FieldTypeString, true, "邮箱(索引)")
schema := NewSchema("users", []Field{
{Name: "name", Type: FieldTypeString, Indexed: false, Comment: "用户名"},
{Name: "age", Type: FieldTypeInt64, Indexed: false, Comment: "年龄"},
{Name: "email", Type: FieldTypeString, Indexed: true, Comment: "邮箱(索引)"},
})
table, _ := db.CreateTable("users", schema)
```
- Schema `Insert()` 时验证类型和必填字段
- Schema `Insert()` 强制验证类型和必填字段
- 索引字段`Indexed: true`自动创建二级索引
- Schema 持久化到 `table_dir/schema.json`
- Schema 持久化到 `table_dir/schema.json`包含校验和防篡改
- 支持的类型`FieldTypeString`, `FieldTypeInt64`, `FieldTypeBool`, `FieldTypeFloat`
### Query Builder
对于带条件的查询始终使用 `QueryBuilder`
对于带条件的查询使用链式 API
```go
qb := query.NewQueryBuilder()
qb.Where("age", query.OpGreater, 18).
Where("city", query.OpEqual, "Beijing")
rows, _ := table.Query(qb)
// 简单查询
rows, _ := table.Query().Eq("name", "Alice").Rows()
// 复合条件
rows, _ := table.Query().
Eq("status", "active").
Gte("age", 18).
Rows()
// 字段选择(性能优化)
rows, _ := table.Query().
Select("id", "name", "email").
Eq("status", "active").
Rows()
// 游标模式
rows, _ := table.Query().Rows()
defer rows.Close()
for rows.Next() {
row := rows.Row()
fmt.Println(row.Data())
}
```
- 支持操作符`OpEqual``OpNotEqual``OpGreater``OpLess``OpPrefix``OpSuffix``OpContains`
- 支持 `WhereNot()` 进行否定
- 支持 `And()` `Or()` 逻辑
- 当可用时自动使用二级索引对于 `=` 条件
- 如果没有索引则回退到全表扫描
支持操作符`Eq`, `NotEq`, `Lt`, `Gt`, `Lte`, `Gte`, `In`, `NotIn`, `Between`, `Contains`, `StartsWith`, `EndsWith`, `IsNull`, `NotNull`
### Compaction
Compaction 在后台自动运行
- **触发条件**: L0 文件数 > 阈值(默认 10
- **触发条件**: L0 文件数 > 阈值(默认 4-10根据层级
- **策略**: 合并重叠文件,从 L0 → L1、L1 → L2 等
- **Score 计算**: `size / max_size``file_count / max_files`
- **安全性**: 删除前验证文件是否存在,以防止数据丢失
- **去重**: 对于重复的 seq保留最新记录按时间戳
- **文件大小**: L0=2MB、L1=10MB、L2=50MB、L3=100MB、L4+=200MB
修改 compaction 逻辑时
- `picker.go`: 选择要压缩的文件
- `compactor.go`: 执行合并操作
- `manager.go`: 调度和协调 compaction
- 删除前始终验证输入/输出文件是否存在(参见 `DoCompaction`
修改 compaction 逻辑时,注意 `compaction.go` 中的文件选择和合并逻辑。
### 版本控制MANIFEST
@@ -212,57 +200,55 @@ MANIFEST 跟踪跨版本的 SST 文件元数据:
1. 分配文件编号:`versionSet.AllocateFileNumber()`
2. 创建带变更的 `VersionEdit`
3. 应用:`versionSet.LogAndApply(edit)`
4. 清理旧文件`compactionManager.CleanupOrphanFiles()`
4. 清理旧文件(通过 GC 机制)
### 错误处理
使用统一的错误码系统(`errors.go`
```go
// 创建错误
err := NewError(ErrCodeTableNotFound, nil)
// 带上下文包装错误
err := WrapError(baseErr, "failed to get table %s", "users")
// 错误判断
if IsNotFound(err) { ... }
if IsCorrupted(err) { ... }
if IsClosed(err) { ... }
// 获取错误码
code := GetErrorCode(err)
```
- 错误码范围1000-1999通用、2000-2999数据库、3000-3999、4000-4999Schema
- 所有 panic 已替换为错误返回
- 使用 `fmt.Errorf``%w` 进行错误链包装
### 错误恢复
- **WAL 重放**: 启动时,所有 `*.wal` 文件被重放到 Active MemTable
- **孤儿文件清理**: 不在 MANIFEST 中的文件在启动时删除
- **索引修复**: `verifyAndRepairIndexes()` 重建损坏的索引
- **孤儿文件清理**: 不在 MANIFEST 中的文件在启动时删除(有年龄保护,避免误删最近写入的文件)
- **索引修复**: 自动验证和重建损坏的索引
- **优雅降级**: 表恢复失败会被记录但不会使数据库崩溃
## 测试模式
测试按组件组织:
- `engine/engine_test.go`: 基本引擎操作
- `engine/engine_compaction_test.go`: Compaction 场景
- `engine/engine_stress_test.go`: 并发压力测试
- `compaction/compaction_test.go`: Compaction 正确性
- `query/builder_test.go`: Query builder 功能
- `schema/schema_test.go`: Schema 验证
为多线程操作编写测试时,使用 `sync.WaitGroup` 并用多个 goroutine 测试(参见 `engine_stress_test.go`)。
## 性能特性
- **写入吞吐量**: 200K+ 写/秒多线程50K 写/秒(单线程)
- **写入延迟**: < 1msp99
- **查询延迟**: < 0.1msMemTable1-5msSST 热数据3-5ms冷数据
- **内存使用**: < 150MB64MB MemTable + 开销
- **压缩率**: Snappy 50%
优化时
- 批量写入以减少 WAL 同步开销
- 对经常查询的字段创建索引
- 监控 MemTable 刷新频率不应太频繁
- 根据写入模式调整 compaction 阈值
## 重要实现细节
### 序列号
### 序列号系统
- `_seq` 是单调递增的 int64原子操作
- 充当主键和时间戳排序
- 永不重用append-only
- compaction 期间相同 seq 值的较新记录优先
- Compaction 期间,相同 seq 值的较新记录优先(按 `_time` 排序)
### 并发
### 并发控制
- `Engine.mu`: 保护元数据和 SST reader 列表
- `Engine.flushMu`: 确保一次只有一个 flush
- `Table.mu`: 保护表级元数据
- `SSTableManager.mu`: RWMutex保护 SST reader 列表
- `MemTable.mu`: RWMutex支持并发读、独占写
- `VersionSet.mu`: 保护版本状态
- 无全局锁,细粒度锁设计
### 文件格式
@@ -273,7 +259,7 @@ CRC32 (4B) | Length (4B) | Type (1B) | Seq (8B) | DataLen (4B) | Data (N bytes)
**SST 文件**:
```
Header (256B) | B+Tree Index | Data Blocks (Snappy compressed)
Header (256B) | B+Tree Index (4KB nodes) | Data Blocks (Binary format)
```
**B+Tree 节点**4KB 固定):
@@ -281,11 +267,51 @@ Header (256B) | B+Tree Index | Data Blocks (Snappy compressed)
Header (32B) | Keys (8B each) | Pointers/Offsets (8B each) | Padding
```
**二进制行格式** (ROW1):
```
Magic (4B) | Seq (8B) | Time (8B) | FieldCount (2B) |
[FieldOffset, FieldSize] × N | FieldData × N
```
## 性能特性
- **写入吞吐量**: 200K+ 写/秒多线程50K 写/秒(单线程)
- **写入延迟**: < 1msp99
- **查询延迟**: < 0.1msMemTable1-5msSST 热数据3-5ms冷数据
- **内存使用**: < 150MB64MB MemTable + 开销
- **压缩**: 未使用优先查询性能
优化建议
- 批量写入以减少 WAL 同步开销
- 对经常查询的字段创建索引
- 使用 `Select()` 只查询需要的字段
- 监控 MemTable 刷新频率不应太频繁
- 根据写入模式调整 Compaction 阈值
## 常见陷阱
- Schema 验证仅在向 `Engine.Open()` 提供 schema 时才应用
- 索引必须通过 `CreateIndex(field)` 显式创建非自动
- schema QueryBuilder 需要调用 `WithSchema()` 或让引擎设置它
- Compaction 可能会暂时增加磁盘使合并期间旧文件和新文件共存
- MemTable flush 异步关闭时可能需要等待 immutable flush 完成
- mmap 文件可能显示较大的虚拟内存使用这是正常的不是实际 RAM
- **Schema 是强制的**: 所有表必须定义 Schema不再支持无 Schema 模式
- **索引非自动创建**: 需要在 Schema 中显式标记 `Indexed: true`
- **类型严格**: Schema 验证严格int int64 需要正确匹配
- **Compaction 磁盘**: 合并期间旧文件和新文件共存会暂时增加磁盘使用
- **MemTable flush 异步**: 关闭时需要等待 immutable flush 完成
- **mmap 虚拟内存**: 可能显示较大的虚拟内存使用正常OS 管理不是实际 RAM
- ** panic**: 所有 panic 已替换为错误返回需要正确处理错误
- **废弃代码**: `SSTableCompressionNone` 等常量已删除
## Web UI
项目包含功能完善的 Web 管理界面
```bash
cd examples/webui
go run main.go serve --db /path/to/database --port 8080
```
功能
- 表管理和数据浏览
- Manifest 可视化LSM-Tree 结构
- 实时 Compaction 监控
- 深色/浅色主题
详见 `examples/webui/README.md`

218
DESIGN.md
View File

@@ -1,6 +1,6 @@
# SRDB 设计文档WAL + mmap B+Tree
> 模块名:`code.tczkiot.com/srdb`
> 模块名:`code.tczkiot.com/wlw/srdb`
> 一个高性能的 Append-Only 时序数据库引擎
## 🎯 设计目标
@@ -19,10 +19,10 @@
│ SRDB Architecture │
├─────────────────────────────────────────────────────────────┤
│ Application Layer │
│ ┌───────────────┐ ┌──────────┐ ┌───────────┐ │
│ │ Database │->│ Table │->│ Engine │ │
│ │ (Multi-Table) │ │ (Schema) │ │ (Storage) │ │
│ └───────────────┘ └──────────┘ └───────────┘ │
│ ┌───────────────┐ ┌──────────────────────────┐ │
│ │ Database │->│ Table │ │
│ │ (Multi-Table) │ │ (Schema + Storage) │ │
│ └───────────────┘ └──────────────────────────┘ │
├─────────────────────────────────────────────────────────────┤
│ Write Path (High Concurrency) │
│ ┌─────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
@@ -38,10 +38,10 @@
├─────────────────────────────────────────────────────────────┤
│ Storage Layer (Persistent) │
│ ┌─────────────────────────────────────────────────┐ │
│ │ SST Files (B+Tree Format + Compression) │ │
│ │ SST Files (B+Tree Format + Binary Encoding) │ │
│ │ ┌─────────────────────────────────────────┐ │ │
│ │ │ File Header (256 bytes) │ │ │
│ │ │ - Magic, Version, Compression │ │ │
│ │ │ - Magic, Version, Metadata │ │ │
│ │ │ - MinKey, MaxKey, RowCount │ │ │
│ │ ├─────────────────────────────────────────┤ │ │
│ │ │ B+Tree Index (4 KB nodes) │ │ │
@@ -49,8 +49,9 @@
│ │ │ - Internal Nodes (Order=200) │ │ │
│ │ │ - Leaf Nodes → Data Offset │ │ │
│ │ ├─────────────────────────────────────────┤ │ │
│ │ │ Data Blocks (Snappy Compressed) │ │ │
│ │ │ - JSON serialized rows │ │ │
│ │ │ Data Blocks (Binary Format) │ │ │
│ │ │ - 有 Schema: 二进制编码 │ │ │
│ │ │ - 无 Schema: JSON 格式 │ │ │
│ │ └─────────────────────────────────────────┘ │ │
│ │ │ │
│ │ Secondary Indexes (Optional) │ │
@@ -88,59 +89,32 @@
```
srdb/ ← 项目根目录
├── go.mod ← 模块定义: code.tczkiot.com/srdb
├── go.mod ← 模块定义: code.tczkiot.com/wlw/srdb
├── DESIGN.md ← 本设计文档
├── CLAUDE.md ← Claude Code 指导文档
├── database.go ← 数据库管理 (多表)
├── table.go ← 表管理
├── table.go ← 表管理 (带 Schema)
├── errors.go ← 错误定义和处理
├── engine/存储引擎
│ └── engine.go ← 核心引擎实现 (583 行)
├── wal.go WAL 实现 (Write-Ahead Log)
├── memtable.go ← MemTable 实现 (map + sorted slice)
├── sstable.go ← SSTable 文件 (读写器、管理器、编码)
├── btree.go ← B+Tree 索引 (构建器、读取器)
├── version.go ← 版本控制 (MANIFEST 管理)
├── compaction.go ← Compaction 压缩合并
├── wal/ Write-Ahead Log
│ ├── wal.go ← WAL 实现 (208 行)
│ └── manager.go ← WAL 管理器
├── schema.goSchema 定义与验证
├── index.go ← 二级索引管理器
├── index_btree.go ← 索引 B+Tree 实现
├── query.go ← 查询构建器和表达式求值
├── memtable/ ← 内存表
│ ├── memtable.goMemTable 实现 (130 行)
│ └── manager.go ← MemTable 管理器 (多版本)
├── examples/ ← 示例程序目录
│ ├── webui/ Web UI 管理工具
│ └── ... (其他示例)
── sst/ SSTable 文件
── format.go ← 文件格式定义
│ ├── writer.go ← SST 写入器
│ ├── reader.go ← SST 读取器 (mmap, 147 行)
│ ├── manager.go ← SST 管理器
│ └── encoding.go ← 序列化/压缩
├── btree/ ← B+Tree 索引
│ ├── node.go ← 节点定义 (4KB)
│ ├── builder.go ← B+Tree 构建器 (125 行)
│ └── reader.go ← B+Tree 读取器
├── manifest/ ← 版本控制
│ ├── version_set.go ← 版本集合
│ ├── version_edit.go ← 版本变更
│ ├── version.go ← 版本信息
│ ├── manifest_writer.go ← MANIFEST 写入
│ └── manifest_reader.go ← MANIFEST 读取
├── compaction/ ← 压缩合并
│ ├── manager.go ← Compaction 管理器
│ ├── compactor.go ← 压缩执行器
│ └── picker.go ← 文件选择策略
├── index/ ← 二级索引 (新增)
│ ├── index.go ← 索引实现
│ ├── manager.go ← 索引管理器
│ └── README.md ← 索引使用文档
├── query/ ← 查询系统 (新增)
│ ├── builder.go ← 查询构建器
│ └── expr.go ← 表达式求值
└── schema/ ← Schema 系统 (新增)
├── schema.go ← Schema 定义与验证
├── examples.go ← Schema 示例
└── README.md ← Schema 使用文档
── webui/Web UI 静态资源
── ...
```
### 运行时数据目录结构
@@ -162,7 +136,7 @@ database_dir/ ← 数据库目录
│ ├── 000002.sst
│ └── 000003.sst
└── index/ ← 索引目录 (可选)
└── idx/ ← 索引目录 (可选)
├── idx_name.sst ← 字段 name 的索引
└── idx_email.sst ← 字段 email 的索引
```
@@ -283,14 +257,20 @@ B+Tree 节点格式:
│ ├─ Child Pointer 2 (8 bytes) │
│ └─ ... │
│ │
│ Leaf Node:
│ ├─ Data Offset 1 (8 bytes)
│ ├─ Data Size 1 (4 bytes)
─ Data Offset 2 (8 bytes)
│ ├─ Data Size 2 (4 bytes)
│ Leaf Node (interleaved storage):
│ ├─ (Offset, Size) Pair 1
├─ Data Offset 1 (8 bytes) │
│ └─ Data Size 1 (4 bytes) │
│ ├─ (Offset, Size) Pair 2
│ │ ├─ Data Offset 2 (8 bytes) │
│ │ └─ Data Size 2 (4 bytes) │
│ └─ ... │
└─────────────────────────────────────┘
解释:
- interleaved storage: 交叉存储
优势:
✅ 固定大小 (4 KB) - 对齐页面
✅ 可以直接 mmap 访问
@@ -513,7 +493,6 @@ type Table struct {
name string
dir string
schema *schema.Schema
engine *engine.Engine
}
使用示例:
@@ -571,7 +550,7 @@ Flush 流程 (后台):
4. 构建 B+Tree 索引
5. 写入数据块 (Snappy 压缩)
5. 写入数据块 (二进制格式)
6. 写入 B+Tree 索引
@@ -630,22 +609,23 @@ Flush 流程 (后台):
### 代码规模
```
核心代码: 5399 行 (不含测试和示例)
├── engine: 583 行
├── wal: 208 行
├── memtable: 130 行
├── sst: 147 行 (reader)
├── btree: 125 行 (builder)
├── manifest: ~500 行
├── compaction: ~400 行
├── index: ~400 行
├── query: ~300 行
├── schema: ~200 行
── database: ~300 行
核心代码: ~13,000 行 (不含测试和示例)
├── table.go: 表管理和存储引擎
├── wal.go: WAL 实现
├── memtable.go: MemTable 实现
├── sstable.go: SSTable 文件读写
├── btree.go: B+Tree 索引
├── version.go: 版本控制 (MANIFEST)
├── compaction.go: Compaction 压缩
├── index.go: 二级索引
├── query.go: 查询构建器
├── schema.go: Schema 验证
── errors.go: 错误处理
└── database.go: 数据库管理
测试代码: ~2000+ 行
示例代码: ~1000+ 行
总计: 8000+ 行
总计: 16,000+ 行
```
### 写入性能
@@ -692,10 +672,10 @@ Flush 流程 (后台):
```
示例 (100 万条记录,每条 200 bytes):
- 原始数据: 200 MB
- Snappy 压缩: 100 MB (50% 压缩率)
- 二进制编码: ~180 MB (紧凑格式)
- B+Tree 索引: 20 MB (10%)
- 二级索引: 10 MB (可选)
- 总计: 130 MB (65% 压缩)
- 总计: ~210 MB (压缩)
```
## 🔧 实现状态
@@ -704,14 +684,14 @@ Flush 流程 (后台):
```
核心存储引擎:
- [✅] Schema 定义和解析
- [✅] WAL 实现 (wal/)
- [✅] MemTable 实现 (memtable/,使用 map+slice)
- [✅] 基础的 Insert 和 Get
- [✅] SST 文件格式定义 (sst/format.go)
- [✅] B+Tree 构建器 (btree/)
- [✅] Schema 定义和解析 (schema.go)
- [✅] WAL 实现 (wal.go)
- [✅] MemTable 实现 (memtable.go,使用 map+slice)
- [✅] 基础的 Insert 和 Get (table.go)
- [✅] SST 文件格式定义 (sstable.go)
- [✅] B+Tree 构建器 (btree.go)
- [✅] Flush 流程 (异步)
- [✅] mmap 查询 (sst/reader.go)
- [✅] mmap 查询 (sstable.go)
```
### Phase 2: 优化和稳定 ✅ 已完成
@@ -721,9 +701,9 @@ Flush 流程 (后台):
- [✅] 批量写入优化
- [✅] 并发控制优化
- [✅] 崩溃恢复 (WAL 重放)
- [✅] MANIFEST 管理 (manifest/)
- [✅] Compaction 实现 (compaction/)
- [✅] MemTable Manager (多版本管理)
- [✅] MANIFEST 管理 (version.go)
- [✅] Compaction 实现 (compaction.go)
- [✅] MemTable Manager (多版本管理table.go)
- [✅] 性能测试 (各种 *_test.go)
- [✅] 文档完善 (README.md, DESIGN.md)
```
@@ -733,14 +713,15 @@ Flush 流程 (后台):
```
高级功能:
- [✅] 数据库和表管理 (database.go, table.go)
- [✅] Schema 系统 (schema/)
- [✅] 二级索引 (index/)
- [✅] 查询构建器 (query/)
- [✅] Schema 系统 (schema.go强制要求)
- [✅] 二级索引 (index.go, index_btree.go)
- [✅] 查询构建器 (query.go)
- [✅] 条件查询 (AND/OR/NOT)
- [✅] 字符串匹配 (Contains/StartsWith/EndsWith)
- [✅] 版本控制和自动修复
- [✅] 统计信息 (engine.Stats())
- [✅] 压缩和编码 (Snappy)
- [✅] 统计信息 (table.Stats())
- [✅] 二进制编码和序列化 (ROW1 格式)
- [✅] 统一错误处理 (errors.go)
```
### Phase 4: 示例和文档 ✅ 已完成
@@ -817,15 +798,15 @@ Flush 流程 (后台):
- 优势: 列裁剪,压缩率高
- 劣势: 实现复杂Flush 慢
最终实现 (V3): 行式存储 + Snappy
- 优势: 实现简单Flush 快
- 劣势: 压缩率稍低
最终实现 (V3): 行式存储 + 二进制格式
- 优势: 实现简单Flush 快,紧凑高效
- 劣势: 相比列式压缩率稍低
权衡:
- 追求简单和快速实现
- 行式 + Snappy 已经有 50% 压缩率
- 二进制格式已经足够紧凑
- 满足大多数时序数据场景
- 如果未来需要,可以演进到列式
- 如果未来需要,可以演进到列式或添加压缩
```
### 为什么用 B+Tree 而不是 LSM Tree
@@ -868,6 +849,33 @@ mmap 方式:
✅ OS 自动优化
```
### 为什么不使用压缩Snappy/LZ4
```
压缩的优势:
- 减少磁盘空间
- 可能减少 I/O
压缩的劣势:
- CPU 开销(压缩/解压)
- 查询延迟增加
- mmap 零拷贝失效(需要先解压)
- 实现复杂度增加
最终决策: 不使用压缩
- 优先考虑查询性能
- 保持 mmap 零拷贝优势
- 二进制格式已经足够紧凑
- 现代存储成本较低
- 如果真需要压缩,可以在应用层或文件系统层实现
权衡:
✅ 查询延迟更低
✅ 实现更简单
✅ mmap 零拷贝有效
❌ 磁盘占用稍大
```
## 🎯 总结
SRDB 是一个功能完善的高性能 Append-Only 数据库引擎
@@ -883,7 +891,7 @@ SRDB 是一个功能完善的高性能 Append-Only 数据库引擎:
**技术亮点:**
- 简洁的 MemTable 实现 (map + sorted slice)
- B+Tree 索引4KB 节点对齐
- Snappy 压缩50% 压缩率
- 高效的二进制编码格式
- 多版本 MemTable 管理
- 后台 Compaction
- 版本控制和自动修复
@@ -904,12 +912,8 @@ SRDB 是一个功能完善的高性能 Append-Only 数据库引擎:
- 传统 OLTP 系统
**项目成果:**
- 核心代码: 5399
- 测试代码: 2000+
- 示例程序: 13 个完整示例
- 核心代码: ~13,000
- 测试代码: ~2,000+
- 示例程序: 13+ 个完整示例
- 文档: 完善的设计和使用文档
- 性能: 达到设计目标
---
**项目已完成并可用于生产环境!** 🎉

View File

@@ -1,4 +1,4 @@
.PHONY: help test test-verbose test-coverage test-race test-bench test-engine test-compaction test-btree test-memtable test-sstable test-wal test-version test-schema test-index test-database fmt fmt-check vet tidy verify clean build run-webui install-webui
.PHONY: help test test-verbose test-coverage test-race test-bench test-table test-compaction test-btree test-memtable test-sstable test-wal test-version test-schema test-index test-database fmt fmt-check vet tidy verify clean build run-webui install-webui
# 默认目标
.DEFAULT_GOAL := help
@@ -40,9 +40,9 @@ test-bench: ## 运行基准测试
@echo "$(GREEN)运行基准测试...$(RESET)"
@go test -bench=. -benchmem $$(go list ./... | grep -v /examples/)
test-engine: ## 只运行 engine 测试
@echo "$(GREEN)运行 engine 测试...$(RESET)"
@go test -v -run TestEngine
test-table: ## 只运行 table 测试
@echo "$(GREEN)运行 table 测试...$(RESET)"
@go test -v -run TestTable
test-compaction: ## 只运行 compaction 测试
@echo "$(GREEN)运行 compaction 测试...$(RESET)"

View File

@@ -46,7 +46,7 @@
### 安装
```bash
go get code.tczkiot.com/srdb
go get code.tczkiot.com/wlw/srdb
```
### 基本示例
@@ -57,7 +57,7 @@ package main
import (
"fmt"
"log"
"code.tczkiot.com/srdb"
"code.tczkiot.com/wlw/srdb"
)
func main() {
@@ -376,21 +376,19 @@ go run main.go serve --auto-insert
```
Database
├── Table
│ ├── Schema表结构
└── Engine存储引擎
── MemTable Manager
│ ├── Active MemTable
│ └── Immutable MemTables
├── SSTable Manager
│ └── SST Files (Level 0-6)
├── WAL Manager
│ └── Write-Ahead Log
├── Version Manager
│ └── MVCC Versions
│ └── Compaction Manager
│ ├── Picker选择策略
│ └── Worker执行合并
├── Table (Schema + Storage)
│ ├── MemTable Manager
│ ├── Active MemTable
── Immutable MemTables
├── SSTable Manager
│ │ └── SST Files (Level 0-6)
├── WAL Manager
│ │ └── Write-Ahead Log
├── Version Manager
│ │ └── MVCC Versions
└── Compaction Manager
├── Picker选择策略
│ └── Worker执行合并
└── Query Builder
└── Expression Engine
```
@@ -454,13 +452,14 @@ srdb/
├── btree.go # B-Tree 索引实现
├── compaction.go # Compaction 管理器
├── database.go # 数据库管理
├── engine.go # 存储引擎核心
├── errors.go # 错误定义和处理
├── index.go # 索引管理
├── index_btree.go # 索引 B+Tree
├── memtable.go # 内存表
├── query.go # 查询构建器
├── schema.go # Schema 定义
├── sstable.go # SSTable 文件
├── table.go # 表管理
├── table.go # 表管理(含存储引擎)
├── version.go # 版本管理MVCC
├── wal.go # Write-Ahead Log
├── webui/ # Web UI
@@ -477,7 +476,7 @@ srdb/
go test ./...
# 运行特定测试
go test -v -run TestEngine
go test -v -run TestTable
# 性能测试
go test -bench=. -benchmem
@@ -500,7 +499,7 @@ go build -o webui main.go
- [设计文档](DESIGN.md) - 详细的架构设计和实现原理
- [WebUI 文档](examples/webui/README.md) - Web 管理界面使用指南
- [API 文档](https://pkg.go.dev/code.tczkiot.com/srdb) - Go API 参考
- [API 文档](https://pkg.go.dev/code.tczkiot.com/wlw/srdb) - Go API 参考
---
@@ -541,8 +540,8 @@ MIT License - 详见 [LICENSE](LICENSE) 文件
## 📧 联系方式
- 项目主页https://code.tczkiot.com/srdb
- Issue 跟踪https://code.tczkiot.com/srdb/issues
- 项目主页https://code.tczkiot.com/wlw/srdb
- Issue 跟踪https://code.tczkiot.com/wlw/srdb/issues
---

189
btree.go
View File

@@ -2,6 +2,7 @@ package srdb
import (
"encoding/binary"
"fmt"
"os"
"slices"
"sort"
@@ -9,6 +10,88 @@ import (
"github.com/edsrzf/mmap-go"
)
/*
B+Tree 存储格式
B+Tree 用于索引 SSTable Index 文件提供 O(log n) 查询性能
节点结构 (4096 bytes):
Node Header (32 bytes)
NodeType (1 byte): 0=Internal, 1=Leaf
KeyCount (2 bytes): 节点中的 key 数量
Level (1 byte): 层级 (0=叶子层)
Reserved (28 bytes): 预留空间
Keys Array (variable)
Key[0..KeyCount-1]: int64 (8 bytes each)
Values (variable, 取决于节点类型)
内部节点 (Internal Node):
Children[0..KeyCount]: int64 (8 bytes each)
- 子节点的文件偏移量
- Children[i] 包含 < Key[i] 的所有 key
- Children[KeyCount] 包含 >= Key[KeyCount-1] key
叶子节点 (Leaf Node):
Data Pairs[0..KeyCount-1]: 交错存储 (12 bytes each)
DataOffset: int64 (8 bytes) - 数据块的文件偏移量
DataSize: int32 (4 bytes) - 数据块的大小
节点头格式 (32 bytes):
Offset | Size | Field | Description
-------|------|------------|----------------------------------
0 | 1 | NodeType | 0=Internal, 1=Leaf
1 | 2 | KeyCount | key 数量 (0 ~ BTreeOrder)
3 | 1 | Level | 层级 (0=叶子层, 1+=内部层)
4 | 28 | Reserved | 预留空间
内部节点布局 (示例: KeyCount=3):
[Header: 32B]
[Keys: Key0(8B), Key1(8B), Key2(8B)]
[Children: Child0(8B), Child1(8B), Child2(8B), Child3(8B)]
查询规则:
- key < Key0 Child0
- Key0 key < Key1 Child1
- Key1 key < Key2 Child2
- key Key2 Child3
叶子节点布局 (示例: KeyCount=3):
[Header: 32B]
[Keys: Key0(8B), Key1(8B), Key2(8B)]
[Data: (Offset0, Size0), (Offset1, Size1), (Offset2, Size2)]
- 交错存储: Offset0(8B), Size0(4B), Offset1(8B), Size1(4B), Offset2(8B), Size2(4B)
查询规则:
- 找到 key == Key[i]
- 返回 (DataOffsets[i], DataSizes[i])
B+Tree 特性:
- 阶数 (Order): 200 (每个节点最多 200 key)
- 节点大小: 4096 bytes (4 KB对齐页大小)
- 高度: log(N) (100万条数据约 3 )
- 查询复杂度: O(log n)
- 范围查询: 支持叶子节点有序
文件布局示例:
SSTable/Index 文件:
[Header: 256B]
[B+Tree Nodes: 4KB each]
Root Node (Internal)
Level 1 Nodes (Internal)
Leaf Nodes
[Data Blocks: variable]
性能优化:
- mmap 零拷贝: 直接从内存映射读取节点
- 节点对齐: 4KB 对齐利用操作系统页缓存
- 有序存储: 叶子节点有序支持范围查询
- 紧凑编码: 最小化节点大小提高缓存命中率
*/
const (
BTreeNodeSize = 4096 // 节点大小 (4 KB)
BTreeOrder = 200 // B+Tree 阶数 (保守估计叶子节点每个entry 20 bytes)
@@ -59,6 +142,29 @@ func NewLeafNode() *BTreeNode {
}
// Marshal 序列化节点到 4 KB
//
// 布局:
// [Header: 32B]
// [Keys: KeyCount * 8B]
// [Values: 取决于节点类型]
// - Internal: Children (KeyCount+1) * 8B
// - Leaf: 交错存储 (Offset, Size) 对,每对 12B共 KeyCount * 12B
//
// 示例叶子节点KeyCount=3
// Offset | Size | Content
// -------|------|----------------------------------
// 0 | 1 | NodeType = 1 (Leaf)
// 1 | 2 | KeyCount = 3
// 3 | 1 | Level = 0
// 4 | 28 | Reserved
// 32 | 24 | Keys [100, 200, 300]
// 56 | 8 | DataOffset0 = 1000
// 64 | 4 | DataSize0 = 50
// 68 | 8 | DataOffset1 = 2000
// 76 | 4 | DataSize1 = 60
// 80 | 8 | DataOffset2 = 3000
// 88 | 4 | DataSize2 = 70
// 92 | 4004 | Padding (unused)
func (n *BTreeNode) Marshal() []byte {
buf := make([]byte, BTreeNodeSize)
@@ -105,6 +211,16 @@ func (n *BTreeNode) Marshal() []byte {
}
// UnmarshalBTree 从字节数组反序列化节点
//
// 参数:
// data: 4KB 节点数据(通常来自 mmap
//
// 返回:
// *BTreeNode: 反序列化后的节点
//
// 零拷贝优化:
// - 直接从 mmap 数据读取,不复制整个节点
// - 只复制必要的字段Keys, Children, DataOffsets, DataSizes
func UnmarshalBTree(data []byte) *BTreeNode {
if len(data) < BTreeNodeSize {
return nil
@@ -171,25 +287,47 @@ func (n *BTreeNode) AddKey(key int64) {
}
// AddChild 添加子节点 (仅用于内部节点)
func (n *BTreeNode) AddChild(offset int64) {
func (n *BTreeNode) AddChild(offset int64) error {
if n.NodeType != BTreeNodeTypeInternal {
panic("AddChild called on leaf node")
return fmt.Errorf("AddChild called on leaf node")
}
n.Children = append(n.Children, offset)
return nil
}
// AddData 添加数据位置 (仅用于叶子节点)
func (n *BTreeNode) AddData(key int64, offset int64, size int32) {
func (n *BTreeNode) AddData(key int64, offset int64, size int32) error {
if n.NodeType != BTreeNodeTypeLeaf {
panic("AddData called on internal node")
return fmt.Errorf("AddData called on internal node")
}
n.Keys = append(n.Keys, key)
n.DataOffsets = append(n.DataOffsets, offset)
n.DataSizes = append(n.DataSizes, size)
n.KeyCount = uint16(len(n.Keys))
return nil
}
// BTreeBuilder 从下往上构建 B+Tree
//
// 构建流程:
// 1. Add(): 添加所有 (key, offset, size) 到叶子节点
// - 当叶子节点满时,创建新的叶子节点
// - 所有叶子节点按 key 有序
//
// 2. Build(): 从叶子层向上构建
// - Level 0: 叶子节点(已创建)
// - Level 1: 为叶子节点创建父节点(内部节点)
// - Level 2+: 递归创建更高层级
// - 最终返回根节点偏移量
//
// 示例100 个 keyOrder=200
// - 叶子层: 1 个叶子节点100 个 key
// - 根节点: 叶子节点本身
//
// 示例500 个 keyOrder=200
// - 叶子层: 3 个叶子节点200, 200, 100 个 key
// - Level 1: 1 个内部节点3 个子节点)
// - 根节点: Level 1 的内部节点
type BTreeBuilder struct {
order int // B+Tree 阶数
file *os.File // 输出文件
@@ -220,7 +358,9 @@ func (b *BTreeBuilder) Add(key int64, dataOffset int64, dataSize int32) error {
}
// 添加到叶子节点
leaf.AddData(key, dataOffset, dataSize)
if err := leaf.AddData(key, dataOffset, dataSize); err != nil {
return err
}
return nil
}
@@ -280,14 +420,18 @@ func (b *BTreeBuilder) buildLevel(children []*BTreeNode, childOffsets []int64, l
parent := NewInternalNode(byte(level))
// 添加第一个子节点 (没有对应的 key)
parent.AddChild(childOffsets[i])
if err := parent.AddChild(childOffsets[i]); err != nil {
return nil, nil, err
}
// 添加剩余的子节点和分隔 key
for j := i + 1; j < end; j++ {
// 分隔 key 是子节点的第一个 key
separatorKey := children[j].Keys[0]
parent.AddKey(separatorKey)
parent.AddChild(childOffsets[j])
if err := parent.AddChild(childOffsets[j]); err != nil {
return nil, nil, err
}
}
// 写入父节点
@@ -307,6 +451,24 @@ func (b *BTreeBuilder) buildLevel(children []*BTreeNode, childOffsets []int64, l
}
// BTreeReader 用于查询 B+Tree (mmap)
//
// 查询流程:
// 1. 从根节点开始
// 2. 如果是内部节点:
// - 二分查找确定子节点
// - 跳转到子节点继续查找
// 3. 如果是叶子节点:
// - 二分查找 key
// - 返回 (dataOffset, dataSize)
//
// 性能优化:
// - mmap 零拷贝:直接从内存映射读取节点
// - 二分查找O(log KeyCount) 在节点内查找
// - 总复杂度O(log n) = O(height * log Order)
//
// 示例100万条数据Order=200
// - 高度: log₂₀₀(1000000) ≈ 3
// - 查询次数: 3 次节点读取 + 3 次二分查找
type BTreeReader struct {
mmap mmap.MMap
rootOffset int64
@@ -321,6 +483,19 @@ func NewBTreeReader(mmap mmap.MMap, rootOffset int64) *BTreeReader {
}
// Get 查询 key返回数据位置
//
// 参数:
// key: 要查询的 key
//
// 返回:
// dataOffset: 数据块的文件偏移量
// dataSize: 数据块的大小
// found: 是否找到
//
// 查询流程:
// 1. 从根节点开始遍历
// 2. 内部节点:二分查找确定子节点,跳转
// 3. 叶子节点:二分查找 key返回数据位置
func (r *BTreeReader) Get(key int64) (dataOffset int64, dataSize int32, found bool) {
if r.rootOffset == 0 {
return 0, 0, false

View File

@@ -87,9 +87,15 @@ func TestBTree(t *testing.T) {
func TestBTreeSerialization(t *testing.T) {
// 测试节点序列化
leaf := NewLeafNode()
leaf.AddData(1, 1000, 100)
leaf.AddData(2, 2000, 200)
leaf.AddData(3, 3000, 300)
if err := leaf.AddData(1, 1000, 100); err != nil {
t.Fatal(err)
}
if err := leaf.AddData(2, 2000, 200); err != nil {
t.Fatal(err)
}
if err := leaf.AddData(3, 3000, 300); err != nil {
t.Fatal(err)
}
// 序列化
data := leaf.Marshal()

View File

@@ -19,6 +19,11 @@ func TestCompactionBasic(t *testing.T) {
t.Fatal(err)
}
// 创建 Schema
schema := NewSchema("test", []Field{
{Name: "value", Type: FieldTypeInt64},
})
// 创建 VersionSet
versionSet, err := NewVersionSet(manifestDir)
if err != nil {
@@ -33,6 +38,9 @@ func TestCompactionBasic(t *testing.T) {
}
defer sstMgr.Close()
// 设置 Schema
sstMgr.SetSchema(schema)
// 创建测试数据
rows1 := make([]*SSTableRow, 100)
for i := range 100 {
@@ -75,6 +83,7 @@ func TestCompactionBasic(t *testing.T) {
// 创建 Compaction Manager
compactionMgr := NewCompactionManager(sstDir, versionSet, sstMgr)
compactionMgr.SetSchema(schema)
// 创建更多文件触发 Compaction
for i := 1; i < 5; i++ {
@@ -204,6 +213,11 @@ func TestCompactionMerge(t *testing.T) {
t.Fatal(err)
}
// 创建 Schema
schema := NewSchema("test", []Field{
{Name: "value", Type: FieldTypeString},
})
// 创建 VersionSet
versionSet, err := NewVersionSet(manifestDir)
if err != nil {
@@ -218,6 +232,9 @@ func TestCompactionMerge(t *testing.T) {
}
defer sstMgr.Close()
// 设置 Schema
sstMgr.SetSchema(schema)
// 创建两个有重叠 key 的 SST 文件
rows1 := []*SSTableRow{
{Seq: 1, Time: 1000, Data: map[string]any{"value": "old"}},
@@ -269,6 +286,7 @@ func TestCompactionMerge(t *testing.T) {
// 创建 Compactor
compactor := NewCompactor(sstDir, versionSet)
compactor.SetSchema(schema)
// 创建 Compaction 任务
version := versionSet.GetCurrent()
@@ -316,6 +334,11 @@ func BenchmarkCompaction(b *testing.B) {
b.Fatal(err)
}
// 创建 Schema
schema := NewSchema("test", []Field{
{Name: "value", Type: FieldTypeString},
})
// 创建 VersionSet
versionSet, err := NewVersionSet(manifestDir)
if err != nil {
@@ -330,6 +353,9 @@ func BenchmarkCompaction(b *testing.B) {
}
defer sstMgr.Close()
// 设置 Schema
sstMgr.SetSchema(schema)
// 创建测试数据
const numFiles = 5
const rowsPerFile = 1000
@@ -372,6 +398,7 @@ func BenchmarkCompaction(b *testing.B) {
// 创建 Compactor
compactor := NewCompactor(sstDir, versionSet)
compactor.SetSchema(schema)
version := versionSet.GetCurrent()
task := &CompactionTask{
@@ -401,16 +428,16 @@ func TestCompactionQueryOrder(t *testing.T) {
{Name: "timestamp", Type: FieldTypeInt64},
})
// 打开 Engine (使用较小的 MemTable 触发频繁 flush)
engine, err := OpenEngine(&EngineOptions{
// 打开 Table (使用较小的 MemTable 触发频繁 flush)
table, err := OpenTable(&TableOptions{
Dir: tmpDir,
MemTableSize: 2 * 1024 * 1024, // 2MB MemTable
Schema: schema,
Name: schema.Name, Fields: schema.Fields,
})
if err != nil {
t.Fatal(err)
}
defer engine.Close()
defer table.Close()
t.Logf("开始插入 4000 条数据...")
@@ -423,7 +450,7 @@ func TestCompactionQueryOrder(t *testing.T) {
largeData[j] = byte('A' + (j % 26))
}
err := engine.Insert(map[string]any{
err := table.Insert(map[string]any{
"id": int64(i),
"name": fmt.Sprintf("user_%d", i),
"data": string(largeData),
@@ -447,7 +474,7 @@ func TestCompactionQueryOrder(t *testing.T) {
t.Logf("开始查询所有数据...")
// 查询所有数据
rows, err := engine.Query().Rows()
rows, err := table.Query().Rows()
if err != nil {
t.Fatal(err)
}
@@ -514,7 +541,7 @@ func TestCompactionQueryOrder(t *testing.T) {
t.Logf("✓ 所有数据完整性验证通过")
// 输出 compaction 统计信息
stats := engine.GetCompactionManager().GetLevelStats()
stats := table.GetCompactionManager().GetLevelStats()
t.Logf("Compaction 统计:")
for _, levelStat := range stats {
level := levelStat["level"].(int)

View File

@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"sync"
"time"
)
// Database 数据库,管理多个表
@@ -108,8 +109,11 @@ func (db *Database) recoverTables() error {
var failedTables []string
for _, tableInfo := range db.metadata.Tables {
// FIXME: 是否需要校验 tableInfo.Dir ?
table, err := openTable(tableInfo.Name, db)
tableDir := filepath.Join(db.dir, tableInfo.Name)
table, err := OpenTable(&TableOptions{
Dir: tableDir,
MemTableSize: DefaultMemTableSize,
})
if err != nil {
// 记录失败的表,但继续恢复其他表
failedTables = append(failedTables, tableInfo.Name)
@@ -137,12 +141,25 @@ func (db *Database) CreateTable(name string, schema *Schema) (*Table, error) {
// 检查表是否已存在
if _, exists := db.tables[name]; exists {
return nil, fmt.Errorf("table %s already exists", name)
return nil, NewErrorf(ErrCodeTableExists, "table %s already exists", name)
}
// 创建表目录
tableDir := filepath.Join(db.dir, name)
err := os.MkdirAll(tableDir, 0755)
if err != nil {
return nil, err
}
// 创建表
table, err := createTable(name, schema, db)
table, err := OpenTable(&TableOptions{
Dir: tableDir,
MemTableSize: DefaultMemTableSize,
Name: schema.Name,
Fields: schema.Fields,
})
if err != nil {
os.RemoveAll(tableDir)
return nil, err
}
@@ -153,7 +170,7 @@ func (db *Database) CreateTable(name string, schema *Schema) (*Table, error) {
db.metadata.Tables = append(db.metadata.Tables, TableInfo{
Name: name,
Dir: name,
CreatedAt: table.createdAt,
CreatedAt: time.Now().Unix(),
})
err = db.saveMetadata()
@@ -171,7 +188,7 @@ func (db *Database) GetTable(name string) (*Table, error) {
table, exists := db.tables[name]
if !exists {
return nil, fmt.Errorf("table %s not found", name)
return nil, NewErrorf(ErrCodeTableNotFound, "table %s not found", name)
}
return table, nil
@@ -185,7 +202,7 @@ func (db *Database) DropTable(name string) error {
// 检查表是否存在
table, exists := db.tables[name]
if !exists {
return fmt.Errorf("table %s not found", name)
return NewErrorf(ErrCodeTableNotFound, "table %s not found", name)
}
// 关闭表

View File

@@ -1,296 +0,0 @@
package srdb
import (
"fmt"
"os"
"testing"
)
func TestDatabaseClean(t *testing.T) {
dir := "./test_db_clean_data"
defer os.RemoveAll(dir)
// 1. 创建数据库
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
// 2. 创建多个表并插入数据
// 表 1: users
usersSchema := NewSchema("users", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: true, Comment: "User ID"},
{Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"},
})
usersTable, err := db.CreateTable("users", usersSchema)
if err != nil {
t.Fatal(err)
}
for i := 0; i < 50; i++ {
usersTable.Insert(map[string]any{
"id": int64(i),
"name": "user" + string(rune(i)),
})
}
// 表 2: orders
ordersSchema := NewSchema("orders", []Field{
{Name: "order_id", Type: FieldTypeInt64, Indexed: true, Comment: "Order ID"},
{Name: "amount", Type: FieldTypeInt64, Indexed: false, Comment: "Amount"},
})
ordersTable, err := db.CreateTable("orders", ordersSchema)
if err != nil {
t.Fatal(err)
}
for i := 0; i < 30; i++ {
ordersTable.Insert(map[string]any{
"order_id": int64(i),
"amount": int64(i * 100),
})
}
// 3. 验证数据存在
usersStats := usersTable.Stats()
ordersStats := ordersTable.Stats()
t.Logf("Before Clean - Users: %d rows, Orders: %d rows",
usersStats.TotalRows, ordersStats.TotalRows)
if usersStats.TotalRows == 0 || ordersStats.TotalRows == 0 {
t.Error("Expected data in tables")
}
// 4. 清除所有表的数据
err = db.Clean()
if err != nil {
t.Fatal(err)
}
// 5. 验证数据已清除
usersStats = usersTable.Stats()
ordersStats = ordersTable.Stats()
t.Logf("After Clean - Users: %d rows, Orders: %d rows",
usersStats.TotalRows, ordersStats.TotalRows)
if usersStats.TotalRows != 0 {
t.Errorf("Expected 0 rows in users, got %d", usersStats.TotalRows)
}
if ordersStats.TotalRows != 0 {
t.Errorf("Expected 0 rows in orders, got %d", ordersStats.TotalRows)
}
// 6. 验证表结构仍然存在
tables := db.ListTables()
if len(tables) != 2 {
t.Errorf("Expected 2 tables, got %d", len(tables))
}
// 7. 验证可以继续插入数据
err = usersTable.Insert(map[string]any{
"id": int64(100),
"name": "new_user",
})
if err != nil {
t.Fatal(err)
}
usersStats = usersTable.Stats()
if usersStats.TotalRows != 1 {
t.Errorf("Expected 1 row after insert, got %d", usersStats.TotalRows)
}
db.Close()
}
func TestDatabaseDestroy(t *testing.T) {
dir := "./test_db_destroy_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
})
table, err := db.CreateTable("test", schema)
if err != nil {
t.Fatal(err)
}
// 插入数据
for i := 0; i < 20; i++ {
table.Insert(map[string]any{"id": int64(i)})
}
// 2. 验证数据存在
stats := table.Stats()
t.Logf("Before Destroy: %d rows", stats.TotalRows)
if stats.TotalRows == 0 {
t.Error("Expected data in table")
}
// 3. 销毁数据库
err = db.Destroy()
if err != nil {
t.Fatal(err)
}
// 4. 验证数据目录已删除
if _, err := os.Stat(dir); !os.IsNotExist(err) {
t.Error("Database directory should be deleted")
}
// 5. 验证数据库不可用
tables := db.ListTables()
if len(tables) != 0 {
t.Errorf("Expected 0 tables after destroy, got %d", len(tables))
}
}
func TestDatabaseCleanMultipleTables(t *testing.T) {
dir := "./test_db_clean_multi_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和多个表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db.Close()
// 创建 5 个表
for i := 0; i < 5; i++ {
tableName := fmt.Sprintf("table%d", i)
schema := NewSchema(tableName, []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
{Name: "value", Type: FieldTypeString, Indexed: false, Comment: "Value"},
})
table, err := db.CreateTable(tableName, schema)
if err != nil {
t.Fatal(err)
}
// 每个表插入 10 条数据
for j := 0; j < 10; j++ {
table.Insert(map[string]any{
"id": int64(j),
"value": fmt.Sprintf("value_%d_%d", i, j),
})
}
}
// 2. 验证所有表都有数据
tables := db.ListTables()
if len(tables) != 5 {
t.Fatalf("Expected 5 tables, got %d", len(tables))
}
totalRows := 0
for _, tableName := range tables {
table, _ := db.GetTable(tableName)
stats := table.Stats()
totalRows += int(stats.TotalRows)
}
t.Logf("Total rows before clean: %d", totalRows)
if totalRows == 0 {
t.Error("Expected data in tables")
}
// 3. 清除所有表
err = db.Clean()
if err != nil {
t.Fatal(err)
}
// 4. 验证所有表数据已清除
totalRows = 0
for _, tableName := range tables {
table, _ := db.GetTable(tableName)
stats := table.Stats()
totalRows += int(stats.TotalRows)
if stats.TotalRows != 0 {
t.Errorf("Table %s should have 0 rows, got %d", tableName, stats.TotalRows)
}
}
t.Logf("Total rows after clean: %d", totalRows)
// 5. 验证表结构仍然存在
tables = db.ListTables()
if len(tables) != 5 {
t.Errorf("Expected 5 tables after clean, got %d", len(tables))
}
}
func TestDatabaseCleanAndReopen(t *testing.T) {
dir := "./test_db_clean_reopen_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
})
table, err := db.CreateTable("test", schema)
if err != nil {
t.Fatal(err)
}
// 插入数据
for i := 0; i < 50; i++ {
table.Insert(map[string]any{"id": int64(i)})
}
// 2. 清除数据
err = db.Clean()
if err != nil {
t.Fatal(err)
}
// 3. 关闭并重新打开
db.Close()
db2, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db2.Close()
// 4. 验证表存在但数据为空
tables := db2.ListTables()
if len(tables) != 1 {
t.Errorf("Expected 1 table, got %d", len(tables))
}
table2, err := db2.GetTable("test")
if err != nil {
t.Fatal(err)
}
stats := table2.Stats()
if stats.TotalRows != 0 {
t.Errorf("Expected 0 rows after reopen, got %d", stats.TotalRows)
}
// 5. 验证可以插入新数据
err = table2.Insert(map[string]any{"id": int64(100)})
if err != nil {
t.Fatal(err)
}
stats = table2.Stats()
if stats.TotalRows != 1 {
t.Errorf("Expected 1 row, got %d", stats.TotalRows)
}
}

View File

@@ -1,195 +0,0 @@
package srdb
import (
"os"
"testing"
)
func TestDatabaseCleanTable(t *testing.T) {
dir := "./test_db_clean_table_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db.Close()
schema := NewSchema("users", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
{Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"},
})
table, err := db.CreateTable("users", schema)
if err != nil {
t.Fatal(err)
}
// 2. 插入数据
for i := 0; i < 50; i++ {
table.Insert(map[string]any{
"id": int64(i),
"name": "user",
})
}
// 3. 验证数据存在
stats := table.Stats()
if stats.TotalRows == 0 {
t.Error("Expected data in table")
}
// 4. 清除表数据
err = db.CleanTable("users")
if err != nil {
t.Fatal(err)
}
// 5. 验证数据已清除
stats = table.Stats()
if stats.TotalRows != 0 {
t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows)
}
// 6. 验证表仍然存在
tables := db.ListTables()
found := false
for _, name := range tables {
if name == "users" {
found = true
break
}
}
if !found {
t.Error("Table should still exist after clean")
}
// 7. 验证可以继续插入
err = table.Insert(map[string]any{
"id": int64(100),
"name": "new_user",
})
if err != nil {
t.Fatal(err)
}
stats = table.Stats()
if stats.TotalRows != 1 {
t.Errorf("Expected 1 row, got %d", stats.TotalRows)
}
}
func TestDatabaseDestroyTable(t *testing.T) {
dir := "./test_db_destroy_table_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db.Close()
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
})
table, err := db.CreateTable("test", schema)
if err != nil {
t.Fatal(err)
}
// 2. 插入数据
for i := 0; i < 30; i++ {
table.Insert(map[string]any{"id": int64(i)})
}
// 3. 验证数据存在
stats := table.Stats()
if stats.TotalRows == 0 {
t.Error("Expected data in table")
}
// 4. 销毁表
err = db.DestroyTable("test")
if err != nil {
t.Fatal(err)
}
// 5. 验证表已从 Database 中删除
tables := db.ListTables()
for _, name := range tables {
if name == "test" {
t.Error("Table should be removed from database")
}
}
// 6. 验证无法再获取该表
_, err = db.GetTable("test")
if err == nil {
t.Error("Should not be able to get table after destroy")
}
}
func TestDatabaseDestroyTableMultiple(t *testing.T) {
dir := "./test_db_destroy_multi_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和多个表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db.Close()
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
})
// 创建 3 个表
for i := 1; i <= 3; i++ {
tableName := "table" + string(rune('0'+i))
_, err := db.CreateTable(tableName, schema)
if err != nil {
t.Fatal(err)
}
}
// 2. 验证有 3 个表
tables := db.ListTables()
if len(tables) != 3 {
t.Fatalf("Expected 3 tables, got %d", len(tables))
}
// 3. 销毁中间的表
err = db.DestroyTable("table2")
if err != nil {
t.Fatal(err)
}
// 4. 验证只剩 2 个表
tables = db.ListTables()
if len(tables) != 2 {
t.Errorf("Expected 2 tables, got %d", len(tables))
}
// 5. 验证剩余的表是正确的
hasTable1 := false
hasTable3 := false
for _, name := range tables {
if name == "table1" {
hasTable1 = true
}
if name == "table3" {
hasTable3 = true
}
if name == "table2" {
t.Error("table2 should be destroyed")
}
}
if !hasTable1 || !hasTable3 {
t.Error("table1 and table3 should still exist")
}
}

View File

@@ -1,7 +1,9 @@
package srdb
import (
"fmt"
"os"
"slices"
"testing"
)
@@ -257,3 +259,475 @@ func TestDatabaseRecover(t *testing.T) {
t.Log("Database recover test passed!")
}
func TestDatabaseClean(t *testing.T) {
dir := "./test_db_clean_data"
defer os.RemoveAll(dir)
// 1. 创建数据库
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
// 2. 创建多个表并插入数据
// 表 1: users
usersSchema := NewSchema("users", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: true, Comment: "User ID"},
{Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"},
})
usersTable, err := db.CreateTable("users", usersSchema)
if err != nil {
t.Fatal(err)
}
for i := range 50 {
usersTable.Insert(map[string]any{
"id": int64(i),
"name": "user" + string(rune(i)),
})
}
// 表 2: orders
ordersSchema := NewSchema("orders", []Field{
{Name: "order_id", Type: FieldTypeInt64, Indexed: true, Comment: "Order ID"},
{Name: "amount", Type: FieldTypeInt64, Indexed: false, Comment: "Amount"},
})
ordersTable, err := db.CreateTable("orders", ordersSchema)
if err != nil {
t.Fatal(err)
}
for i := range 30 {
ordersTable.Insert(map[string]any{
"order_id": int64(i),
"amount": int64(i * 100),
})
}
// 3. 验证数据存在
usersStats := usersTable.Stats()
ordersStats := ordersTable.Stats()
t.Logf("Before Clean - Users: %d rows, Orders: %d rows",
usersStats.TotalRows, ordersStats.TotalRows)
if usersStats.TotalRows == 0 || ordersStats.TotalRows == 0 {
t.Error("Expected data in tables")
}
// 4. 清除所有表的数据
err = db.Clean()
if err != nil {
t.Fatal(err)
}
// 5. 验证数据已清除
usersStats = usersTable.Stats()
ordersStats = ordersTable.Stats()
t.Logf("After Clean - Users: %d rows, Orders: %d rows",
usersStats.TotalRows, ordersStats.TotalRows)
if usersStats.TotalRows != 0 {
t.Errorf("Expected 0 rows in users, got %d", usersStats.TotalRows)
}
if ordersStats.TotalRows != 0 {
t.Errorf("Expected 0 rows in orders, got %d", ordersStats.TotalRows)
}
// 6. 验证表结构仍然存在
tables := db.ListTables()
if len(tables) != 2 {
t.Errorf("Expected 2 tables, got %d", len(tables))
}
// 7. 验证可以继续插入数据
err = usersTable.Insert(map[string]any{
"id": int64(100),
"name": "new_user",
})
if err != nil {
t.Fatal(err)
}
usersStats = usersTable.Stats()
if usersStats.TotalRows != 1 {
t.Errorf("Expected 1 row after insert, got %d", usersStats.TotalRows)
}
db.Close()
}
func TestDatabaseDestroy(t *testing.T) {
dir := "./test_db_destroy_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
})
table, err := db.CreateTable("test", schema)
if err != nil {
t.Fatal(err)
}
// 插入数据
for i := range 20 {
table.Insert(map[string]any{"id": int64(i)})
}
// 2. 验证数据存在
stats := table.Stats()
t.Logf("Before Destroy: %d rows", stats.TotalRows)
if stats.TotalRows == 0 {
t.Error("Expected data in table")
}
// 3. 销毁数据库
err = db.Destroy()
if err != nil {
t.Fatal(err)
}
// 4. 验证数据目录已删除
if _, err := os.Stat(dir); !os.IsNotExist(err) {
t.Error("Database directory should be deleted")
}
// 5. 验证数据库不可用
tables := db.ListTables()
if len(tables) != 0 {
t.Errorf("Expected 0 tables after destroy, got %d", len(tables))
}
}
func TestDatabaseCleanMultipleTables(t *testing.T) {
dir := "./test_db_clean_multi_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和多个表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db.Close()
// 创建 5 个表
for i := range 5 {
tableName := fmt.Sprintf("table%d", i)
schema := NewSchema(tableName, []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
{Name: "value", Type: FieldTypeString, Indexed: false, Comment: "Value"},
})
table, err := db.CreateTable(tableName, schema)
if err != nil {
t.Fatal(err)
}
// 每个表插入 10 条数据
for j := range 10 {
table.Insert(map[string]any{
"id": int64(j),
"value": fmt.Sprintf("value_%d_%d", i, j),
})
}
}
// 2. 验证所有表都有数据
tables := db.ListTables()
if len(tables) != 5 {
t.Fatalf("Expected 5 tables, got %d", len(tables))
}
totalRows := 0
for _, tableName := range tables {
table, _ := db.GetTable(tableName)
stats := table.Stats()
totalRows += int(stats.TotalRows)
}
t.Logf("Total rows before clean: %d", totalRows)
if totalRows == 0 {
t.Error("Expected data in tables")
}
// 3. 清除所有表
err = db.Clean()
if err != nil {
t.Fatal(err)
}
// 4. 验证所有表数据已清除
totalRows = 0
for _, tableName := range tables {
table, _ := db.GetTable(tableName)
stats := table.Stats()
totalRows += int(stats.TotalRows)
if stats.TotalRows != 0 {
t.Errorf("Table %s should have 0 rows, got %d", tableName, stats.TotalRows)
}
}
t.Logf("Total rows after clean: %d", totalRows)
// 5. 验证表结构仍然存在
tables = db.ListTables()
if len(tables) != 5 {
t.Errorf("Expected 5 tables after clean, got %d", len(tables))
}
}
func TestDatabaseCleanAndReopen(t *testing.T) {
dir := "./test_db_clean_reopen_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
})
table, err := db.CreateTable("test", schema)
if err != nil {
t.Fatal(err)
}
// 插入数据
for i := range 50 {
table.Insert(map[string]any{"id": int64(i)})
}
// 2. 清除数据
err = db.Clean()
if err != nil {
t.Fatal(err)
}
// 3. 关闭并重新打开
db.Close()
db2, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db2.Close()
// 4. 验证表存在但数据为空
tables := db2.ListTables()
if len(tables) != 1 {
t.Errorf("Expected 1 table, got %d", len(tables))
}
table2, err := db2.GetTable("test")
if err != nil {
t.Fatal(err)
}
stats := table2.Stats()
if stats.TotalRows != 0 {
t.Errorf("Expected 0 rows after reopen, got %d", stats.TotalRows)
}
// 5. 验证可以插入新数据
err = table2.Insert(map[string]any{"id": int64(100)})
if err != nil {
t.Fatal(err)
}
stats = table2.Stats()
if stats.TotalRows != 1 {
t.Errorf("Expected 1 row, got %d", stats.TotalRows)
}
}
func TestDatabaseCleanTable(t *testing.T) {
dir := "./test_db_clean_table_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db.Close()
schema := NewSchema("users", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
{Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"},
})
table, err := db.CreateTable("users", schema)
if err != nil {
t.Fatal(err)
}
// 2. 插入数据
for i := range 50 {
table.Insert(map[string]any{
"id": int64(i),
"name": "user",
})
}
// 3. 验证数据存在
stats := table.Stats()
if stats.TotalRows == 0 {
t.Error("Expected data in table")
}
// 4. 清除表数据
err = db.CleanTable("users")
if err != nil {
t.Fatal(err)
}
// 5. 验证数据已清除
stats = table.Stats()
if stats.TotalRows != 0 {
t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows)
}
// 6. 验证表仍然存在
tables := db.ListTables()
found := slices.Contains(tables, "users")
if !found {
t.Error("Table should still exist after clean")
}
// 7. 验证可以继续插入
err = table.Insert(map[string]any{
"id": int64(100),
"name": "new_user",
})
if err != nil {
t.Fatal(err)
}
stats = table.Stats()
if stats.TotalRows != 1 {
t.Errorf("Expected 1 row, got %d", stats.TotalRows)
}
}
func TestDatabaseDestroyTable(t *testing.T) {
dir := "./test_db_destroy_table_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db.Close()
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
})
table, err := db.CreateTable("test", schema)
if err != nil {
t.Fatal(err)
}
// 2. 插入数据
for i := range 30 {
table.Insert(map[string]any{"id": int64(i)})
}
// 3. 验证数据存在
stats := table.Stats()
if stats.TotalRows == 0 {
t.Error("Expected data in table")
}
// 4. 销毁表
err = db.DestroyTable("test")
if err != nil {
t.Fatal(err)
}
// 5. 验证表已从 Database 中删除
tables := db.ListTables()
for _, name := range tables {
if name == "test" {
t.Error("Table should be removed from database")
}
}
// 6. 验证无法再获取该表
_, err = db.GetTable("test")
if err == nil {
t.Error("Should not be able to get table after destroy")
}
}
func TestDatabaseDestroyTableMultiple(t *testing.T) {
dir := "./test_db_destroy_multi_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和多个表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db.Close()
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
})
// 创建 3 个表
for i := 1; i <= 3; i++ {
tableName := "table" + string(rune('0'+i))
_, err := db.CreateTable(tableName, schema)
if err != nil {
t.Fatal(err)
}
}
// 2. 验证有 3 个表
tables := db.ListTables()
if len(tables) != 3 {
t.Fatalf("Expected 3 tables, got %d", len(tables))
}
// 3. 销毁中间的表
err = db.DestroyTable("table2")
if err != nil {
t.Fatal(err)
}
// 4. 验证只剩 2 个表
tables = db.ListTables()
if len(tables) != 2 {
t.Errorf("Expected 2 tables, got %d", len(tables))
}
// 5. 验证剩余的表是正确的
hasTable1 := false
hasTable3 := false
for _, name := range tables {
if name == "table1" {
hasTable1 = true
}
if name == "table3" {
hasTable3 = true
}
if name == "table2" {
t.Error("table2 should be destroyed")
}
}
if !hasTable1 || !hasTable3 {
t.Error("table1 and table3 should still exist")
}
}

814
engine.go
View File

@@ -1,814 +0,0 @@
package srdb
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"sync"
"sync/atomic"
"time"
)
const (
DefaultMemTableSize = 64 * 1024 * 1024 // 64 MB
DefaultAutoFlushTimeout = 30 * time.Second // 30 秒无写入自动 flush
)
// Engine 存储引擎
type Engine struct {
dir string
schema *Schema
indexManager *IndexManager
walManager *WALManager // WAL 管理器
sstManager *SSTableManager // SST 管理器
memtableManager *MemTableManager // MemTable 管理器
versionSet *VersionSet // MANIFEST 管理器
compactionManager *CompactionManager // Compaction 管理器
seq atomic.Int64
flushMu sync.Mutex
// 自动 flush 相关
autoFlushTimeout time.Duration
lastWriteTime atomic.Int64 // 最后写入时间UnixNano
stopAutoFlush chan struct{}
}
// EngineOptions 配置选项
type EngineOptions struct {
Dir string
MemTableSize int64
Schema *Schema // 可选的 Schema 定义
AutoFlushTimeout time.Duration // 自动 flush 超时时间0 表示禁用
}
// OpenEngine 打开数据库
func OpenEngine(opts *EngineOptions) (*Engine, error) {
if opts.MemTableSize == 0 {
opts.MemTableSize = DefaultMemTableSize
}
// 创建主目录
err := os.MkdirAll(opts.Dir, 0755)
if err != nil {
return nil, err
}
// 创建子目录
walDir := filepath.Join(opts.Dir, "wal")
sstDir := filepath.Join(opts.Dir, "sst")
idxDir := filepath.Join(opts.Dir, "index")
err = os.MkdirAll(walDir, 0755)
if err != nil {
return nil, err
}
err = os.MkdirAll(sstDir, 0755)
if err != nil {
return nil, err
}
err = os.MkdirAll(idxDir, 0755)
if err != nil {
return nil, err
}
// 尝试从磁盘恢复 Schema如果 Options 中没有提供)
var sch *Schema
if opts.Schema != nil {
// 使用提供的 Schema
sch = opts.Schema
// 保存到磁盘(带校验和)
schemaPath := filepath.Join(opts.Dir, "schema.json")
schemaFile, err := NewSchemaFile(sch)
if err != nil {
return nil, fmt.Errorf("create schema file: %w", err)
}
schemaData, err := json.MarshalIndent(schemaFile, "", " ")
if err != nil {
return nil, fmt.Errorf("marshal schema: %w", err)
}
err = os.WriteFile(schemaPath, schemaData, 0644)
if err != nil {
return nil, fmt.Errorf("write schema: %w", err)
}
} else {
// 尝试从磁盘恢复
schemaPath := filepath.Join(opts.Dir, "schema.json")
schemaData, err := os.ReadFile(schemaPath)
if err == nil {
// 文件存在,尝试解析
schemaFile := &SchemaFile{}
err = json.Unmarshal(schemaData, schemaFile)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal schema from %s: %w", schemaPath, err)
}
// 验证校验和
err = schemaFile.Verify()
if err != nil {
return nil, fmt.Errorf("failed to verify schema from %s: %w", schemaPath, err)
}
sch = schemaFile.Schema
} else if !os.IsNotExist(err) {
// 其他读取错误
return nil, fmt.Errorf("failed to read schema file %s: %w", schemaPath, err)
}
// 如果文件不存在sch 保持为 nil可选
}
// 创建索引管理器
var indexMgr *IndexManager
if sch != nil {
indexMgr = NewIndexManager(idxDir, sch)
}
// 创建 SST Manager
sstMgr, err := NewSSTableManager(sstDir)
if err != nil {
return nil, err
}
// 设置 Schema用于优化编解码
if sch != nil {
sstMgr.SetSchema(sch)
}
// 创建 MemTable Manager
memMgr := NewMemTableManager(opts.MemTableSize)
// 创建/恢复 MANIFEST
manifestDir := opts.Dir
versionSet, err := NewVersionSet(manifestDir)
if err != nil {
return nil, fmt.Errorf("create version set: %w", err)
}
// 创建 Engine暂时不设置 WAL Manager
engine := &Engine{
dir: opts.Dir,
schema: sch,
indexManager: indexMgr,
walManager: nil, // 先不设置,恢复后再创建
sstManager: sstMgr,
memtableManager: memMgr,
versionSet: versionSet,
}
// 先恢复数据(包括从 WAL 恢复)
err = engine.recover()
if err != nil {
return nil, err
}
// 恢复完成后,创建 WAL Manager 用于后续写入
walMgr, err := NewWALManager(walDir)
if err != nil {
return nil, err
}
engine.walManager = walMgr
engine.memtableManager.SetActiveWAL(walMgr.GetCurrentNumber())
// 创建 Compaction Manager
engine.compactionManager = NewCompactionManager(sstDir, versionSet, sstMgr)
// 设置 Schema如果有
if sch != nil {
engine.compactionManager.SetSchema(sch)
}
// 启动时清理孤儿文件(崩溃恢复后的清理)
engine.compactionManager.CleanupOrphanFiles()
// 启动后台 Compaction 和垃圾回收
engine.compactionManager.Start()
// 验证并修复索引
if engine.indexManager != nil {
engine.verifyAndRepairIndexes()
}
// 设置自动 flush 超时时间
if opts.AutoFlushTimeout > 0 {
engine.autoFlushTimeout = opts.AutoFlushTimeout
} else {
engine.autoFlushTimeout = DefaultAutoFlushTimeout
}
engine.stopAutoFlush = make(chan struct{})
engine.lastWriteTime.Store(time.Now().UnixNano())
// 启动自动 flush 监控
go engine.autoFlushMonitor()
return engine, nil
}
// Insert 插入数据
func (e *Engine) Insert(data map[string]any) error {
// 1. 验证 Schema (如果定义了)
if e.schema != nil {
if err := e.schema.Validate(data); err != nil {
return fmt.Errorf("schema validation failed: %v", err)
}
}
// 2. 生成 _seq
seq := e.seq.Add(1)
// 2. 添加系统字段
row := &SSTableRow{
Seq: seq,
Time: time.Now().UnixNano(),
Data: data,
}
// 3. 序列化
rowData, err := json.Marshal(row)
if err != nil {
return err
}
// 4. 写入 WAL
entry := &WALEntry{
Type: WALEntryTypePut,
Seq: seq,
Data: rowData,
}
err = e.walManager.Append(entry)
if err != nil {
return err
}
// 5. 写入 MemTable Manager
e.memtableManager.Put(seq, rowData)
// 6. 添加到索引
if e.indexManager != nil {
e.indexManager.AddToIndexes(data, seq)
}
// 7. 更新最后写入时间
e.lastWriteTime.Store(time.Now().UnixNano())
// 8. 检查是否需要切换 MemTable
if e.memtableManager.ShouldSwitch() {
go e.switchMemTable()
}
return nil
}
// Get 查询数据
func (e *Engine) Get(seq int64) (*SSTableRow, error) {
// 1. 先查 MemTable Manager (Active + Immutables)
data, found := e.memtableManager.Get(seq)
if found {
var row SSTableRow
err := json.Unmarshal(data, &row)
if err != nil {
return nil, err
}
return &row, nil
}
// 2. 查询 SST 文件
return e.sstManager.Get(seq)
}
// GetPartial 按需查询数据(只读取指定字段)
func (e *Engine) GetPartial(seq int64, fields []string) (*SSTableRow, error) {
// 1. 先查 MemTable Manager (Active + Immutables)
data, found := e.memtableManager.Get(seq)
if found {
var row SSTableRow
err := json.Unmarshal(data, &row)
if err != nil {
return nil, err
}
// MemTable 中的数据已经完全解析,需要手动过滤字段
if len(fields) > 0 {
filteredData := make(map[string]any)
for _, field := range fields {
if val, ok := row.Data[field]; ok {
filteredData[field] = val
}
}
row.Data = filteredData
}
return &row, nil
}
// 2. 查询 SST 文件(按需解码)
return e.sstManager.GetPartial(seq, fields)
}
// switchMemTable 切换 MemTable
func (e *Engine) switchMemTable() error {
e.flushMu.Lock()
defer e.flushMu.Unlock()
// 1. 切换到新的 WAL
oldWALNumber, err := e.walManager.Rotate()
if err != nil {
return err
}
newWALNumber := e.walManager.GetCurrentNumber()
// 2. 切换 MemTable (Active → Immutable)
_, immutable := e.memtableManager.Switch(newWALNumber)
// 3. 异步 Flush Immutable
go e.flushImmutable(immutable, oldWALNumber)
return nil
}
// flushImmutable 将 Immutable MemTable 刷新到 SST
func (e *Engine) flushImmutable(imm *ImmutableMemTable, walNumber int64) error {
// 1. 收集所有行
var rows []*SSTableRow
iter := imm.NewIterator()
for iter.Next() {
var row SSTableRow
err := json.Unmarshal(iter.Value(), &row)
if err == nil {
rows = append(rows, &row)
}
}
if len(rows) == 0 {
// 没有数据,直接清理
e.walManager.Delete(walNumber)
e.memtableManager.RemoveImmutable(imm)
return nil
}
// 2. 从 VersionSet 分配文件编号
fileNumber := e.versionSet.AllocateFileNumber()
// 3. 创建 SST 文件到 L0
reader, err := e.sstManager.CreateSST(fileNumber, rows)
if err != nil {
return err
}
// 4. 创建 FileMetadata
header := reader.GetHeader()
// 获取文件大小
sstPath := reader.GetPath()
fileInfo, err := os.Stat(sstPath)
if err != nil {
return fmt.Errorf("stat sst file: %w", err)
}
fileMeta := &FileMetadata{
FileNumber: fileNumber,
Level: 0, // Flush 到 L0
FileSize: fileInfo.Size(),
MinKey: header.MinKey,
MaxKey: header.MaxKey,
RowCount: header.RowCount,
}
// 5. 更新 MANIFEST
edit := NewVersionEdit()
edit.AddFile(fileMeta)
// 持久化当前的文件编号计数器(关键修复:防止重启后文件编号重用)
// 使用 fileNumber + 1 确保并发安全,避免竞态条件
edit.SetNextFileNumber(fileNumber + 1)
err = e.versionSet.LogAndApply(edit)
if err != nil {
return fmt.Errorf("log and apply version edit: %w", err)
}
// 6. 删除对应的 WAL
e.walManager.Delete(walNumber)
// 7. 从 Immutable 列表中移除
e.memtableManager.RemoveImmutable(imm)
// 8. 持久化索引(防止崩溃丢失索引数据)
if e.indexManager != nil {
e.indexManager.BuildAll()
}
// 9. Compaction 由后台线程负责,不在 flush 路径中触发
// 避免同步 compaction 导致刚创建的文件立即被删除
// e.compactionManager.MaybeCompact()
return nil
}
// recover 恢复数据
func (e *Engine) recover() error {
// 1. 恢复 SST 文件SST Manager 已经在 NewManager 中恢复了)
// 只需要获取最大 seq
maxSeq := e.sstManager.GetMaxSeq()
if maxSeq > e.seq.Load() {
e.seq.Store(maxSeq)
}
// 2. 恢复所有 WAL 文件到 MemTable Manager
walDir := filepath.Join(e.dir, "wal")
pattern := filepath.Join(walDir, "*.wal")
walFiles, err := filepath.Glob(pattern)
if err == nil && len(walFiles) > 0 {
// 按文件名排序
sort.Strings(walFiles)
// 依次读取每个 WAL
for _, walPath := range walFiles {
reader, err := NewWALReader(walPath)
if err != nil {
continue
}
entries, err := reader.Read()
reader.Close()
if err != nil {
continue
}
// 重放 WAL 到 Active MemTable
for _, entry := range entries {
// 如果定义了 Schema验证数据
if e.schema != nil {
var row SSTableRow
if err := json.Unmarshal(entry.Data, &row); err != nil {
return fmt.Errorf("failed to unmarshal row during recovery (seq=%d): %w", entry.Seq, err)
}
// 验证 Schema
if err := e.schema.Validate(row.Data); err != nil {
return fmt.Errorf("schema validation failed during recovery (seq=%d): %w", entry.Seq, err)
}
}
e.memtableManager.Put(entry.Seq, entry.Data)
if entry.Seq > e.seq.Load() {
e.seq.Store(entry.Seq)
}
}
}
}
return nil
}
// autoFlushMonitor 自动 flush 监控
func (e *Engine) autoFlushMonitor() {
ticker := time.NewTicker(e.autoFlushTimeout / 2) // 每半个超时时间检查一次
defer ticker.Stop()
for {
select {
case <-ticker.C:
// 检查是否超时
lastWrite := time.Unix(0, e.lastWriteTime.Load())
if time.Since(lastWrite) >= e.autoFlushTimeout {
// 检查 MemTable 是否有数据
active := e.memtableManager.GetActive()
if active != nil && active.Size() > 0 {
// 触发 flush
e.Flush()
}
}
case <-e.stopAutoFlush:
return
}
}
}
// Flush 手动刷新 Active MemTable 到磁盘
func (e *Engine) Flush() error {
// 检查 Active MemTable 是否有数据
active := e.memtableManager.GetActive()
if active == nil || active.Size() == 0 {
return nil // 没有数据,无需 flush
}
// 强制切换 MemTableswitchMemTable 内部有锁)
return e.switchMemTable()
}
// Close 关闭引擎
func (e *Engine) Close() error {
// 1. 停止自动 flush 监控(如果还在运行)
if e.stopAutoFlush != nil {
select {
case <-e.stopAutoFlush:
// 已经关闭,跳过
default:
close(e.stopAutoFlush)
}
}
// 2. 停止 Compaction Manager
if e.compactionManager != nil {
e.compactionManager.Stop()
}
// 3. 刷新 Active MemTable确保所有数据都写入磁盘
// 检查 memtableManager 是否存在(可能已被 Destroy
if e.memtableManager != nil {
e.Flush()
}
// 3. 关闭 WAL Manager
if e.walManager != nil {
e.walManager.Close()
}
// 4. 等待所有 Immutable Flush 完成
// TODO: 添加更优雅的等待机制
if e.memtableManager != nil {
for e.memtableManager.GetImmutableCount() > 0 {
time.Sleep(100 * time.Millisecond)
}
}
// 5. 保存所有索引
if e.indexManager != nil {
e.indexManager.BuildAll()
e.indexManager.Close()
}
// 6. 关闭 VersionSet
if e.versionSet != nil {
e.versionSet.Close()
}
// 7. 关闭 WAL Manager
if e.walManager != nil {
e.walManager.Close()
}
// 6. 关闭 SST Manager
if e.sstManager != nil {
e.sstManager.Close()
}
return nil
}
// Clean 清除所有数据(保留 Engine 可用)
func (e *Engine) Clean() error {
e.flushMu.Lock()
defer e.flushMu.Unlock()
// 0. 停止自动 flush 监控(临时)
if e.stopAutoFlush != nil {
close(e.stopAutoFlush)
}
// 1. 停止 Compaction Manager
if e.compactionManager != nil {
e.compactionManager.Stop()
}
// 2. 等待所有 Immutable Flush 完成
for e.memtableManager.GetImmutableCount() > 0 {
time.Sleep(100 * time.Millisecond)
}
// 3. 清空 MemTable
e.memtableManager = NewMemTableManager(DefaultMemTableSize)
// 2. 删除所有 WAL 文件
if e.walManager != nil {
e.walManager.Close()
walDir := filepath.Join(e.dir, "wal")
os.RemoveAll(walDir)
os.MkdirAll(walDir, 0755)
// 重新创建 WAL Manager
walMgr, err := NewWALManager(walDir)
if err != nil {
return fmt.Errorf("recreate wal manager: %w", err)
}
e.walManager = walMgr
e.memtableManager.SetActiveWAL(walMgr.GetCurrentNumber())
}
// 3. 删除所有 SST 文件
if e.sstManager != nil {
e.sstManager.Close()
sstDir := filepath.Join(e.dir, "sst")
os.RemoveAll(sstDir)
os.MkdirAll(sstDir, 0755)
// 重新创建 SST Manager
sstMgr, err := NewSSTableManager(sstDir)
if err != nil {
return fmt.Errorf("recreate sst manager: %w", err)
}
e.sstManager = sstMgr
}
// 4. 删除所有索引文件
if e.indexManager != nil {
e.indexManager.Close()
indexFiles, _ := filepath.Glob(filepath.Join(e.dir, "idx_*.sst"))
for _, f := range indexFiles {
os.Remove(f)
}
// 重新创建 Index Manager
if e.schema != nil {
e.indexManager = NewIndexManager(e.dir, e.schema)
}
}
// 5. 重置 MANIFEST
if e.versionSet != nil {
e.versionSet.Close()
manifestDir := e.dir
os.Remove(filepath.Join(manifestDir, "MANIFEST"))
os.Remove(filepath.Join(manifestDir, "CURRENT"))
// 重新创建 VersionSet
versionSet, err := NewVersionSet(manifestDir)
if err != nil {
return fmt.Errorf("recreate version set: %w", err)
}
e.versionSet = versionSet
}
// 6. 重新创建 Compaction Manager
sstDir := filepath.Join(e.dir, "sst")
e.compactionManager = NewCompactionManager(sstDir, e.versionSet, e.sstManager)
if e.schema != nil {
e.compactionManager.SetSchema(e.schema)
}
e.compactionManager.Start()
// 7. 重置序列号
e.seq.Store(0)
// 8. 更新最后写入时间
e.lastWriteTime.Store(time.Now().UnixNano())
// 9. 重启自动 flush 监控
e.stopAutoFlush = make(chan struct{})
go e.autoFlushMonitor()
return nil
}
// Destroy 销毁 Engine 并删除所有数据文件
func (e *Engine) Destroy() error {
// 1. 先关闭 Engine
if err := e.Close(); err != nil {
return fmt.Errorf("close engine: %w", err)
}
// 2. 删除整个数据目录
if err := os.RemoveAll(e.dir); err != nil {
return fmt.Errorf("remove data directory: %w", err)
}
// 3. 标记 Engine 为不可用(将所有管理器设为 nil
e.walManager = nil
e.sstManager = nil
e.memtableManager = nil
e.versionSet = nil
e.compactionManager = nil
e.indexManager = nil
return nil
}
// TableStats 统计信息
type TableStats struct {
MemTableSize int64
MemTableCount int
SSTCount int
TotalRows int64
}
// GetVersionSet 获取 VersionSet用于高级操作
func (e *Engine) GetVersionSet() *VersionSet {
return e.versionSet
}
// GetCompactionManager 获取 Compaction Manager用于高级操作
func (e *Engine) GetCompactionManager() *CompactionManager {
return e.compactionManager
}
// GetMemtableManager 获取 Memtable Manager
func (e *Engine) GetMemtableManager() *MemTableManager {
return e.memtableManager
}
// GetSSTManager 获取 SST Manager
func (e *Engine) GetSSTManager() *SSTableManager {
return e.sstManager
}
// GetMaxSeq 获取当前最大的 seq 号
func (e *Engine) GetMaxSeq() int64 {
return e.seq.Load() - 1 // seq 是下一个要分配的,所以最大的是 seq - 1
}
// GetSchema 获取 Schema
func (e *Engine) GetSchema() *Schema {
return e.schema
}
// Stats 获取统计信息
func (e *Engine) Stats() *TableStats {
memStats := e.memtableManager.GetStats()
sstStats := e.sstManager.GetStats()
stats := &TableStats{
MemTableSize: memStats.TotalSize,
MemTableCount: memStats.TotalCount,
SSTCount: sstStats.FileCount,
}
// 计算总行数
stats.TotalRows = int64(memStats.TotalCount)
readers := e.sstManager.GetReaders()
for _, reader := range readers {
header := reader.GetHeader()
stats.TotalRows += header.RowCount
}
return stats
}
// CreateIndex 创建索引
func (e *Engine) CreateIndex(field string) error {
if e.indexManager == nil {
return fmt.Errorf("no schema defined, cannot create index")
}
return e.indexManager.CreateIndex(field)
}
// DropIndex 删除索引
func (e *Engine) DropIndex(field string) error {
if e.indexManager == nil {
return fmt.Errorf("no schema defined, cannot drop index")
}
return e.indexManager.DropIndex(field)
}
// ListIndexes 列出所有索引
func (e *Engine) ListIndexes() []string {
if e.indexManager == nil {
return nil
}
return e.indexManager.ListIndexes()
}
// GetIndexMetadata 获取索引元数据
func (e *Engine) GetIndexMetadata() map[string]IndexMetadata {
if e.indexManager == nil {
return nil
}
return e.indexManager.GetIndexMetadata()
}
// RepairIndexes 手动修复索引
func (e *Engine) RepairIndexes() error {
return e.verifyAndRepairIndexes()
}
// Query 创建查询构建器
func (e *Engine) Query() *QueryBuilder {
return newQueryBuilder(e)
}
// verifyAndRepairIndexes 验证并修复索引
func (e *Engine) verifyAndRepairIndexes() error {
if e.indexManager == nil {
return nil
}
// 获取当前最大 seq
currentMaxSeq := e.seq.Load()
// 创建 getData 函数
getData := func(seq int64) (map[string]any, error) {
row, err := e.Get(seq)
if err != nil {
return nil, err
}
return row.Data, nil
}
// 验证并修复
return e.indexManager.VerifyAndRepair(currentMaxSeq, getData)
}

View File

@@ -1,244 +0,0 @@
package srdb
import (
"os"
"testing"
"time"
)
func TestEngineClean(t *testing.T) {
dir := "./test_clean_data"
defer os.RemoveAll(dir)
// 1. 创建 Engine 并插入数据
engine, err := OpenEngine(&EngineOptions{
Dir: dir,
})
if err != nil {
t.Fatal(err)
}
// 插入一些数据
for i := 0; i < 100; i++ {
err := engine.Insert(map[string]any{
"id": i,
"name": "test",
})
if err != nil {
t.Fatal(err)
}
}
// 强制 flush
engine.Flush()
time.Sleep(500 * time.Millisecond)
// 验证数据存在
stats := engine.Stats()
t.Logf("Before Clean: MemTable=%d, SST=%d, Total=%d",
stats.MemTableCount, stats.SSTCount, stats.TotalRows)
if stats.TotalRows == 0 {
t.Errorf("Expected some rows, got 0")
}
// 2. 清除数据
err = engine.Clean()
if err != nil {
t.Fatal(err)
}
// 3. 验证数据已清除
stats = engine.Stats()
t.Logf("After Clean: MemTable=%d, SST=%d, Total=%d",
stats.MemTableCount, stats.SSTCount, stats.TotalRows)
if stats.TotalRows != 0 {
t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows)
}
// 4. 验证 Engine 仍然可用
err = engine.Insert(map[string]any{
"id": 1,
"name": "after_clean",
})
if err != nil {
t.Fatal(err)
}
stats = engine.Stats()
if stats.TotalRows != 1 {
t.Errorf("Expected 1 row after insert, got %d", stats.TotalRows)
}
engine.Close()
}
func TestEngineDestroy(t *testing.T) {
dir := "./test_destroy_data"
defer os.RemoveAll(dir)
// 1. 创建 Engine 并插入数据
engine, err := OpenEngine(&EngineOptions{
Dir: dir,
})
if err != nil {
t.Fatal(err)
}
// 插入一些数据
for i := 0; i < 50; i++ {
err := engine.Insert(map[string]any{
"id": i,
"name": "test",
})
if err != nil {
t.Fatal(err)
}
}
// 验证数据存在
stats := engine.Stats()
t.Logf("Before Destroy: MemTable=%d, SST=%d, Total=%d",
stats.MemTableCount, stats.SSTCount, stats.TotalRows)
// 2. 销毁 Engine
err = engine.Destroy()
if err != nil {
t.Fatal(err)
}
// 3. 验证数据目录已删除
if _, err := os.Stat(dir); !os.IsNotExist(err) {
t.Errorf("Data directory should be deleted")
}
// 4. 验证 Engine 不可用(尝试插入会失败)
err = engine.Insert(map[string]any{
"id": 1,
"name": "after_destroy",
})
if err == nil {
t.Errorf("Insert should fail after destroy")
}
}
func TestEngineCleanWithSchema(t *testing.T) {
dir := "./test_clean_schema_data"
defer os.RemoveAll(dir)
// 定义 Schema
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: true, Comment: "ID"},
{Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"},
})
// 1. 创建 Engine 并插入数据
engine, err := OpenEngine(&EngineOptions{
Dir: dir,
Schema: schema,
})
if err != nil {
t.Fatal(err)
}
// 创建索引
err = engine.CreateIndex("id")
if err != nil {
t.Fatal(err)
}
// 插入数据
for i := 0; i < 50; i++ {
err := engine.Insert(map[string]any{
"id": int64(i),
"name": "test",
})
if err != nil {
t.Fatal(err)
}
}
// 验证索引存在
indexes := engine.ListIndexes()
if len(indexes) != 1 {
t.Errorf("Expected 1 index, got %d", len(indexes))
}
// 2. 清除数据
err = engine.Clean()
if err != nil {
t.Fatal(err)
}
// 3. 验证数据已清除但 Schema 和索引结构保留
stats := engine.Stats()
if stats.TotalRows != 0 {
t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows)
}
// 验证可以继续插入Schema 仍然有效)
err = engine.Insert(map[string]any{
"id": int64(100),
"name": "after_clean",
})
if err != nil {
t.Fatal(err)
}
engine.Close()
}
func TestEngineCleanAndReopen(t *testing.T) {
dir := "./test_clean_reopen_data"
defer os.RemoveAll(dir)
// 1. 创建 Engine 并插入数据
engine, err := OpenEngine(&EngineOptions{
Dir: dir,
})
if err != nil {
t.Fatal(err)
}
for i := 0; i < 100; i++ {
engine.Insert(map[string]any{
"id": i,
"name": "test",
})
}
// 2. 清除数据
engine.Clean()
// 3. 关闭并重新打开
engine.Close()
engine2, err := OpenEngine(&EngineOptions{
Dir: dir,
})
if err != nil {
t.Fatal(err)
}
defer engine2.Close()
// 4. 验证数据为空
stats := engine2.Stats()
if stats.TotalRows != 0 {
t.Errorf("Expected 0 rows after reopen, got %d", stats.TotalRows)
}
// 5. 验证可以插入新数据
err = engine2.Insert(map[string]any{
"id": 1,
"name": "new_data",
})
if err != nil {
t.Fatal(err)
}
stats = engine2.Stats()
if stats.TotalRows != 1 {
t.Errorf("Expected 1 row, got %d", stats.TotalRows)
}
}

336
errors.go Normal file
View File

@@ -0,0 +1,336 @@
package srdb
import (
"errors"
"fmt"
)
// ErrCode 错误码类型
type ErrCode int
// 错误码定义
const (
// 通用错误 (1000-1999)
ErrCodeNotFound ErrCode = 1000 // 数据未找到
ErrCodeClosed ErrCode = 1001 // 对象已关闭
ErrCodeInvalidData ErrCode = 1002 // 无效数据
ErrCodeCorrupted ErrCode = 1003 // 数据损坏
ErrCodeExists ErrCode = 1004 // 对象已存在
ErrCodeInvalidParam ErrCode = 1005 // 无效参数
// 数据库错误 (2000-2999)
ErrCodeDatabaseNotFound ErrCode = 2000 // 数据库不存在
ErrCodeDatabaseExists ErrCode = 2001 // 数据库已存在
ErrCodeDatabaseClosed ErrCode = 2002 // 数据库已关闭
// 表错误 (3000-3999)
ErrCodeTableNotFound ErrCode = 3000 // 表不存在
ErrCodeTableExists ErrCode = 3001 // 表已存在
ErrCodeTableClosed ErrCode = 3002 // 表已关闭
// Schema 错误 (4000-4999)
ErrCodeSchemaNotFound ErrCode = 4000 // Schema 不存在
ErrCodeSchemaInvalid ErrCode = 4001 // Schema 无效
ErrCodeSchemaMismatch ErrCode = 4002 // Schema 不匹配
ErrCodeSchemaValidationFailed ErrCode = 4003 // Schema 验证失败
ErrCodeSchemaChecksumMismatch ErrCode = 4004 // Schema 校验和不匹配
// 字段错误 (5000-5999)
ErrCodeFieldNotFound ErrCode = 5000 // 字段不存在
ErrCodeFieldTypeMismatch ErrCode = 5001 // 字段类型不匹配
ErrCodeFieldRequired ErrCode = 5002 // 必填字段缺失
// 索引错误 (6000-6999)
ErrCodeIndexNotFound ErrCode = 6000 // 索引不存在
ErrCodeIndexExists ErrCode = 6001 // 索引已存在
ErrCodeIndexNotReady ErrCode = 6002 // 索引未就绪
ErrCodeIndexCorrupted ErrCode = 6003 // 索引损坏
// 文件格式错误 (7000-7999)
ErrCodeInvalidFormat ErrCode = 7000 // 无效的文件格式
ErrCodeUnsupportedVersion ErrCode = 7001 // 不支持的版本
ErrCodeInvalidMagicNumber ErrCode = 7002 // 无效的魔数
ErrCodeChecksumMismatch ErrCode = 7003 // 校验和不匹配
// WAL 错误 (8000-8999)
ErrCodeWALCorrupted ErrCode = 8000 // WAL 文件损坏
ErrCodeWALClosed ErrCode = 8001 // WAL 已关闭
// SSTable 错误 (9000-9999)
ErrCodeSSTableNotFound ErrCode = 9000 // SSTable 文件不存在
ErrCodeSSTableCorrupted ErrCode = 9001 // SSTable 文件损坏
// Compaction 错误 (10000-10999)
ErrCodeCompactionInProgress ErrCode = 10000 // Compaction 正在进行
ErrCodeNoCompactionNeeded ErrCode = 10001 // 不需要 Compaction
// 编解码错误 (11000-11999)
ErrCodeEncodeFailed ErrCode = 11000 // 编码失败
ErrCodeDecodeFailed ErrCode = 11001 // 解码失败
)
// 错误码消息映射
var errCodeMessages = map[ErrCode]string{
// 通用错误
ErrCodeNotFound: "not found",
ErrCodeClosed: "already closed",
ErrCodeInvalidData: "invalid data",
ErrCodeCorrupted: "data corrupted",
ErrCodeExists: "already exists",
ErrCodeInvalidParam: "invalid parameter",
// 数据库错误
ErrCodeDatabaseNotFound: "database not found",
ErrCodeDatabaseExists: "database already exists",
ErrCodeDatabaseClosed: "database closed",
// 表错误
ErrCodeTableNotFound: "table not found",
ErrCodeTableExists: "table already exists",
ErrCodeTableClosed: "table closed",
// Schema 错误
ErrCodeSchemaNotFound: "schema not found",
ErrCodeSchemaInvalid: "schema invalid",
ErrCodeSchemaMismatch: "schema mismatch",
ErrCodeSchemaValidationFailed: "schema validation failed",
ErrCodeSchemaChecksumMismatch: "schema checksum mismatch",
// 字段错误
ErrCodeFieldNotFound: "field not found",
ErrCodeFieldTypeMismatch: "field type mismatch",
ErrCodeFieldRequired: "required field missing",
// 索引错误
ErrCodeIndexNotFound: "index not found",
ErrCodeIndexExists: "index already exists",
ErrCodeIndexNotReady: "index not ready",
ErrCodeIndexCorrupted: "index corrupted",
// 文件格式错误
ErrCodeInvalidFormat: "invalid file format",
ErrCodeUnsupportedVersion: "unsupported version",
ErrCodeInvalidMagicNumber: "invalid magic number",
ErrCodeChecksumMismatch: "checksum mismatch",
// WAL 错误
ErrCodeWALCorrupted: "wal corrupted",
ErrCodeWALClosed: "wal closed",
// SSTable 错误
ErrCodeSSTableNotFound: "sstable not found",
ErrCodeSSTableCorrupted: "sstable corrupted",
// Compaction 错误
ErrCodeCompactionInProgress: "compaction in progress",
ErrCodeNoCompactionNeeded: "no compaction needed",
// 编解码错误
ErrCodeEncodeFailed: "encode failed",
ErrCodeDecodeFailed: "decode failed",
}
// Error 错误类型
type Error struct {
Code ErrCode // 错误码
Message string // 错误消息
Cause error // 原始错误
}
// Error 实现 error 接口
func (e *Error) Error() string {
if e.Cause != nil {
return fmt.Sprintf("[%d] %s: %v", e.Code, e.Message, e.Cause)
}
return fmt.Sprintf("[%d] %s", e.Code, e.Message)
}
// Unwrap 支持 errors.Is 和 errors.As
func (e *Error) Unwrap() error {
return e.Cause
}
// Is 判断错误码是否相同
func (e *Error) Is(target error) bool {
t, ok := target.(*Error)
if !ok {
return false
}
return e.Code == t.Code
}
// NewError 创建新错误
func NewError(code ErrCode, cause error) *Error {
msg, ok := errCodeMessages[code]
if !ok {
msg = "unknown error"
}
return &Error{
Code: code,
Message: msg,
Cause: cause,
}
}
// NewErrorf 创建带格式化消息的错误
// 注意:如果 args 中最后一个参数是 error 类型,它会被设置为 Cause
func NewErrorf(code ErrCode, format string, args ...any) *Error {
var cause error
// 检查最后一个参数是否为 error
if len(args) > 0 {
if err, ok := args[len(args)-1].(error); ok {
cause = err
// 从 args 中移除最后一个 error 参数
args = args[:len(args)-1]
}
}
return &Error{
Code: code,
Message: fmt.Sprintf(format, args...),
Cause: cause,
}
}
// 预定义的常用错误(向后兼容)
var (
// ErrNotFound 数据未找到
ErrNotFound = NewError(ErrCodeNotFound, nil)
// ErrClosed 对象已关闭
ErrClosed = NewError(ErrCodeClosed, nil)
// ErrInvalidData 无效数据
ErrInvalidData = NewError(ErrCodeInvalidData, nil)
// ErrCorrupted 数据损坏
ErrCorrupted = NewError(ErrCodeCorrupted, nil)
)
// 数据库错误(向后兼容)
var (
ErrDatabaseNotFound = NewError(ErrCodeDatabaseNotFound, nil)
ErrDatabaseExists = NewError(ErrCodeDatabaseExists, nil)
ErrDatabaseClosed = NewError(ErrCodeDatabaseClosed, nil)
)
// 表错误(向后兼容)
var (
ErrTableNotFound = NewError(ErrCodeTableNotFound, nil)
ErrTableExists = NewError(ErrCodeTableExists, nil)
ErrTableClosed = NewError(ErrCodeTableClosed, nil)
)
// Schema 错误(向后兼容)
var (
ErrSchemaNotFound = NewError(ErrCodeSchemaNotFound, nil)
ErrSchemaInvalid = NewError(ErrCodeSchemaInvalid, nil)
ErrSchemaMismatch = NewError(ErrCodeSchemaMismatch, nil)
ErrSchemaValidationFailed = NewError(ErrCodeSchemaValidationFailed, nil)
ErrSchemaChecksumMismatch = NewError(ErrCodeSchemaChecksumMismatch, nil)
)
// 字段错误(向后兼容)
var (
ErrFieldNotFound = NewError(ErrCodeFieldNotFound, nil)
ErrFieldTypeMismatch = NewError(ErrCodeFieldTypeMismatch, nil)
ErrFieldRequired = NewError(ErrCodeFieldRequired, nil)
)
// 索引错误(向后兼容)
var (
ErrIndexNotFound = NewError(ErrCodeIndexNotFound, nil)
ErrIndexExists = NewError(ErrCodeIndexExists, nil)
ErrIndexNotReady = NewError(ErrCodeIndexNotReady, nil)
ErrIndexCorrupted = NewError(ErrCodeIndexCorrupted, nil)
)
// 文件格式错误(向后兼容)
var (
ErrInvalidFormat = NewError(ErrCodeInvalidFormat, nil)
ErrUnsupportedVersion = NewError(ErrCodeUnsupportedVersion, nil)
ErrInvalidMagicNumber = NewError(ErrCodeInvalidMagicNumber, nil)
ErrChecksumMismatch = NewError(ErrCodeChecksumMismatch, nil)
)
// WAL 错误(向后兼容)
var (
ErrWALCorrupted = NewError(ErrCodeWALCorrupted, nil)
ErrWALClosed = NewError(ErrCodeWALClosed, nil)
)
// SSTable 错误(向后兼容)
var (
ErrSSTableNotFound = NewError(ErrCodeSSTableNotFound, nil)
ErrSSTableCorrupted = NewError(ErrCodeSSTableCorrupted, nil)
)
// Compaction 错误(向后兼容)
var (
ErrCompactionInProgress = NewError(ErrCodeCompactionInProgress, nil)
ErrNoCompactionNeeded = NewError(ErrCodeNoCompactionNeeded, nil)
)
// 编解码错误(向后兼容)
var (
ErrEncodeFailed = NewError(ErrCodeEncodeFailed, nil)
ErrDecodeFailed = NewError(ErrCodeDecodeFailed, nil)
)
// 辅助函数
// GetErrorCode 获取错误码
func GetErrorCode(err error) ErrCode {
var e *Error
if errors.As(err, &e) {
return e.Code
}
return 0
}
// IsError 判断错误是否匹配指定的错误码
func IsError(err error, code ErrCode) bool {
return GetErrorCode(err) == code
}
// IsNotFound 判断是否是 NotFound 错误
func IsNotFound(err error) bool {
code := GetErrorCode(err)
return code == ErrCodeNotFound ||
code == ErrCodeTableNotFound ||
code == ErrCodeDatabaseNotFound ||
code == ErrCodeIndexNotFound ||
code == ErrCodeFieldNotFound ||
code == ErrCodeSchemaNotFound ||
code == ErrCodeSSTableNotFound
}
// IsCorrupted 判断是否是数据损坏错误
func IsCorrupted(err error) bool {
code := GetErrorCode(err)
return code == ErrCodeCorrupted ||
code == ErrCodeWALCorrupted ||
code == ErrCodeSSTableCorrupted ||
code == ErrCodeIndexCorrupted ||
code == ErrCodeChecksumMismatch ||
code == ErrCodeSchemaChecksumMismatch
}
// IsClosed 判断是否是已关闭错误
func IsClosed(err error) bool {
code := GetErrorCode(err)
return code == ErrCodeClosed ||
code == ErrCodeDatabaseClosed ||
code == ErrCodeTableClosed ||
code == ErrCodeWALClosed
}
// WrapError 包装错误并添加上下文
func WrapError(err error, format string, args ...any) error {
if err == nil {
return nil
}
msg := fmt.Sprintf(format, args...)
return fmt.Errorf("%s: %w", msg, err)
}

185
errors_test.go Normal file
View File

@@ -0,0 +1,185 @@
package srdb
import (
"errors"
"fmt"
"testing"
)
func TestIsNotFound(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{"ErrNotFound", ErrNotFound, true},
{"ErrTableNotFound", ErrTableNotFound, true},
{"ErrDatabaseNotFound", ErrDatabaseNotFound, true},
{"ErrIndexNotFound", ErrIndexNotFound, true},
{"ErrFieldNotFound", ErrFieldNotFound, true},
{"ErrSchemaNotFound", ErrSchemaNotFound, true},
{"ErrSSTableNotFound", ErrSSTableNotFound, true},
{"ErrTableExists", ErrTableExists, false},
{"ErrCorrupted", ErrCorrupted, false},
{"nil", nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsNotFound(tt.err)
if result != tt.expected {
t.Errorf("IsNotFound(%v) = %v, want %v", tt.err, result, tt.expected)
}
})
}
}
func TestIsCorrupted(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{"ErrCorrupted", ErrCorrupted, true},
{"ErrWALCorrupted", ErrWALCorrupted, true},
{"ErrSSTableCorrupted", ErrSSTableCorrupted, true},
{"ErrIndexCorrupted", ErrIndexCorrupted, true},
{"ErrChecksumMismatch", ErrChecksumMismatch, true},
{"ErrSchemaChecksumMismatch", ErrSchemaChecksumMismatch, true},
{"ErrNotFound", ErrNotFound, false},
{"nil", nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsCorrupted(tt.err)
if result != tt.expected {
t.Errorf("IsCorrupted(%v) = %v, want %v", tt.err, result, tt.expected)
}
})
}
}
func TestIsClosed(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{"ErrClosed", ErrClosed, true},
{"ErrDatabaseClosed", ErrDatabaseClosed, true},
{"ErrTableClosed", ErrTableClosed, true},
{"ErrWALClosed", ErrWALClosed, true},
{"ErrNotFound", ErrNotFound, false},
{"nil", nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsClosed(tt.err)
if result != tt.expected {
t.Errorf("IsClosed(%v) = %v, want %v", tt.err, result, tt.expected)
}
})
}
}
func TestWrapError(t *testing.T) {
baseErr := ErrNotFound
wrapped := WrapError(baseErr, "failed to get table %s", "users")
if wrapped == nil {
t.Fatal("WrapError returned nil")
}
// 验证包装后的错误仍然是原始错误
if !errors.Is(wrapped, ErrNotFound) {
t.Error("Wrapped error should be ErrNotFound")
}
// 验证错误消息包含上下文
errMsg := wrapped.Error()
if errMsg != "failed to get table users: [1000] not found" {
t.Errorf("Expected error message to contain context, got: %s", errMsg)
}
}
func TestWrapErrorNil(t *testing.T) {
wrapped := WrapError(nil, "some context")
if wrapped != nil {
t.Error("WrapError(nil) should return nil")
}
}
func TestNewError(t *testing.T) {
err := NewErrorf(ErrCodeInvalidData, "custom error: %s", "test")
if err == nil {
t.Fatal("NewErrorf returned nil")
}
// 验证错误码
if err.Code != ErrCodeInvalidData {
t.Errorf("Expected code %d, got %d", ErrCodeInvalidData, err.Code)
}
// 验证错误消息
expected := "custom error: test"
if err.Message != expected {
t.Errorf("Expected message %q, got %q", expected, err.Message)
}
}
func TestGetErrorCode(t *testing.T) {
err := NewError(ErrCodeTableNotFound, nil)
code := GetErrorCode(err)
if code != ErrCodeTableNotFound {
t.Errorf("Expected code %d, got %d", ErrCodeTableNotFound, code)
}
// 测试非 Error 类型
stdErr := fmt.Errorf("standard error")
code = GetErrorCode(stdErr)
if code != 0 {
t.Errorf("Expected code 0 for standard error, got %d", code)
}
}
func TestIsError(t *testing.T) {
tests := []struct {
name string
err error
code ErrCode
expected bool
}{
{"exact match", NewError(ErrCodeTableNotFound, nil), ErrCodeTableNotFound, true},
{"no match", NewError(ErrCodeTableNotFound, nil), ErrCodeDatabaseNotFound, false},
{"wrapped error", WrapError(ErrTableNotFound, "context"), ErrCodeTableNotFound, true},
{"standard error", fmt.Errorf("standard error"), ErrCodeTableNotFound, false},
{"nil error", nil, ErrCodeTableNotFound, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsError(tt.err, tt.code)
if result != tt.expected {
t.Errorf("IsError(%v, %d) = %v, want %v", tt.err, tt.code, result, tt.expected)
}
})
}
}
func TestErrorWrapping(t *testing.T) {
// 测试多层包装
err1 := ErrTableNotFound
err2 := WrapError(err1, "database %s", "mydb")
err3 := WrapError(err2, "operation failed")
// 验证仍然能识别原始错误
if !IsNotFound(err3) {
t.Error("Should recognize wrapped error as NotFound")
}
if !errors.Is(err3, ErrTableNotFound) {
t.Error("Should be able to unwrap to ErrTableNotFound")
}
}

View File

@@ -327,8 +327,8 @@ package main
import (
"net/http"
"code.tczkiot.com/srdb"
"code.tczkiot.com/srdb/webui"
"code.tczkiot.com/wlw/srdb"
"code.tczkiot.com/wlw/srdb/webui"
)
func main() {
@@ -362,7 +362,7 @@ http.ListenAndServe(":8080", mux)
将 webui 工具的命令集成到你的应用:
```go
import "code.tczkiot.com/srdb/examples/webui/commands"
import "code.tczkiot.com/wlw/srdb/examples/webui/commands"
// 检查数据
commands.CheckData("./mydb")

View File

@@ -4,7 +4,7 @@ import (
"fmt"
"log"
"code.tczkiot.com/srdb"
"code.tczkiot.com/wlw/srdb"
)
// CheckData 检查数据库中的数据

View File

@@ -4,7 +4,7 @@ import (
"fmt"
"log"
"code.tczkiot.com/srdb"
"code.tczkiot.com/wlw/srdb"
)
// CheckSeq 检查特定序列号的数据

View File

@@ -4,7 +4,7 @@ import (
"fmt"
"log"
"code.tczkiot.com/srdb"
"code.tczkiot.com/wlw/srdb"
)
// DumpManifest 导出 manifest 信息
@@ -20,12 +20,11 @@ func DumpManifest(dbPath string) {
log.Fatal(err)
}
engine := table.GetEngine()
versionSet := engine.GetVersionSet()
versionSet := table.GetVersionSet()
version := versionSet.GetCurrent()
// Check for duplicates in each level
for level := 0; level < 7; level++ {
for level := range 7 {
files := version.GetLevel(level)
if len(files) == 0 {
continue

View File

@@ -9,7 +9,7 @@ import (
"strconv"
"strings"
"code.tczkiot.com/srdb"
"code.tczkiot.com/wlw/srdb"
)
// InspectAllSST 检查所有 SST 文件

View File

@@ -5,7 +5,7 @@ import (
"log"
"os"
"code.tczkiot.com/srdb"
"code.tczkiot.com/wlw/srdb"
)
// InspectSST 检查特定 SST 文件

View File

@@ -4,7 +4,7 @@ import (
"fmt"
"log"
"code.tczkiot.com/srdb"
"code.tczkiot.com/wlw/srdb"
)
// TestFix 测试修复

View File

@@ -4,7 +4,7 @@ import (
"fmt"
"log"
"code.tczkiot.com/srdb"
"code.tczkiot.com/wlw/srdb"
)
// TestKeys 测试键

View File

@@ -9,8 +9,8 @@ import (
"slices"
"time"
"code.tczkiot.com/srdb"
"code.tczkiot.com/srdb/webui"
"code.tczkiot.com/wlw/srdb"
"code.tczkiot.com/wlw/srdb/webui"
)
// StartWebUI 启动 WebUI 服务器

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"os"
"code.tczkiot.com/srdb/examples/webui/commands"
"code.tczkiot.com/wlw/srdb/examples/webui/commands"
)
func main() {

7
go.mod
View File

@@ -1,10 +1,7 @@
module code.tczkiot.com/srdb
module code.tczkiot.com/wlw/srdb
go 1.24.0
require (
github.com/edsrzf/mmap-go v1.1.0
github.com/golang/snappy v1.0.0
)
require github.com/edsrzf/mmap-go v1.1.0
require golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect

2
go.sum
View File

@@ -1,6 +1,4 @@
github.com/edsrzf/mmap-go v1.1.0 h1:6EUwBLQ/Mcr1EYLE4Tn1VdW1A4ckqCQWZBw8Hr0kjpQ=
github.com/edsrzf/mmap-go v1.1.0/go.mod h1:19H/e8pUPLicwkyNgOykDXkJ9F0MHE+Z52B8EIth78Q=
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

241
index.go
View File

@@ -4,6 +4,7 @@ import (
"encoding/binary"
"encoding/json"
"fmt"
"maps"
"os"
"path/filepath"
"sync"
@@ -22,16 +23,16 @@ type IndexMetadata struct {
// SecondaryIndex 二级索引
type SecondaryIndex struct {
name string // 索引名称
field string // 字段名
fieldType FieldType // 字段类型
file *os.File // 索引文件
builder *BTreeBuilder // B+Tree 构建
reader *BTreeReader // B+Tree 读取器
valueToSeq map[string][]int64 // 值 → seq 列表 (构建时使用)
metadata IndexMetadata // 元数据
mu sync.RWMutex
ready bool // 索引是否就绪
name string // 索引名称
field string // 字段名
fieldType FieldType // 字段类型
file *os.File // 索引文件
btreeReader *IndexBTreeReader // B+Tree 读取
valueToSeq map[string][]int64 // 值 → seq 列表 (构建时使用)
metadata IndexMetadata // 元数据
mu sync.RWMutex
ready bool // 索引是否就绪
useBTree bool // 是否使用 B+Tree 存储(新格式)
}
// NewSecondaryIndex 创建二级索引
@@ -52,7 +53,7 @@ func NewSecondaryIndex(dir, field string, fieldType FieldType) (*SecondaryIndex,
}, nil
}
// Add 添加索引条目
// Add 添加索引条目(增量更新元数据)
func (idx *SecondaryIndex) Add(value any, seq int64) error {
idx.mu.Lock()
defer idx.mu.Unlock()
@@ -61,97 +62,59 @@ func (idx *SecondaryIndex) Add(value any, seq int64) error {
key := fmt.Sprintf("%v", value)
idx.valueToSeq[key] = append(idx.valueToSeq[key], seq)
// 增量更新元数据 O(1)
if idx.metadata.MinSeq == 0 || seq < idx.metadata.MinSeq {
idx.metadata.MinSeq = seq
}
if seq > idx.metadata.MaxSeq {
idx.metadata.MaxSeq = seq
}
idx.metadata.RowCount++
idx.metadata.UpdatedAt = time.Now().UnixNano()
// 首次添加时设置 CreatedAt
if idx.metadata.CreatedAt == 0 {
idx.metadata.CreatedAt = time.Now().UnixNano()
}
return nil
}
// Build 构建索引并持久化
// Build 构建索引并持久化B+Tree 格式)
func (idx *SecondaryIndex) Build() error {
idx.mu.Lock()
defer idx.mu.Unlock()
// 持久化索引数据到 JSON 文件
return idx.save()
}
// save 保存索引到磁盘
func (idx *SecondaryIndex) save() error {
// 更新元数据
idx.updateMetadata()
// 创建包含元数据的数据结构
indexData := struct {
Metadata IndexMetadata `json:"metadata"`
ValueToSeq map[string][]int64 `json:"data"`
}{
Metadata: idx.metadata,
ValueToSeq: idx.valueToSeq,
}
// 序列化索引数据
data, err := json.Marshal(indexData)
if err != nil {
return err
}
// 元数据已在 Add 时增量更新,这里只更新版本号
idx.metadata.Version++
idx.metadata.UpdatedAt = time.Now().UnixNano()
// Truncate 文件
err = idx.file.Truncate(0)
err := idx.file.Truncate(0)
if err != nil {
return err
}
// 写入文件
_, err = idx.file.Seek(0, 0)
if err != nil {
return err
// 使用 B+Tree 写入器
writer := NewIndexBTreeWriter(idx.file, idx.metadata)
// 添加所有条目
for value, seqs := range idx.valueToSeq {
writer.Add(value, seqs)
}
_, err = idx.file.Write(data)
// 构建并写入
err = writer.Build()
if err != nil {
return err
}
// Sync 到磁盘
err = idx.file.Sync()
if err != nil {
return err
return fmt.Errorf("failed to build btree index: %w", err)
}
idx.useBTree = true
idx.ready = true
return nil
}
// updateMetadata 更新元数据
func (idx *SecondaryIndex) updateMetadata() {
now := time.Now().UnixNano()
if idx.metadata.CreatedAt == 0 {
idx.metadata.CreatedAt = now
}
idx.metadata.UpdatedAt = now
idx.metadata.Version++
// 计算 MinSeq, MaxSeq, RowCount
var minSeq, maxSeq int64 = -1, -1
rowCount := int64(0)
for _, seqs := range idx.valueToSeq {
for _, seq := range seqs {
if minSeq == -1 || seq < minSeq {
minSeq = seq
}
if maxSeq == -1 || seq > maxSeq {
maxSeq = seq
}
rowCount++
}
}
idx.metadata.MinSeq = minSeq
idx.metadata.MaxSeq = maxSeq
idx.metadata.RowCount = rowCount
}
// load 从磁盘加载索引
// load 从磁盘加载索引(支持 B+Tree 和 JSON 格式)
func (idx *SecondaryIndex) load() error {
// 获取文件大小
stat, err := idx.file.Stat()
@@ -164,6 +127,47 @@ func (idx *SecondaryIndex) load() error {
return nil
}
// 读取文件头,判断格式
headerData := make([]byte, min(int(stat.Size()), IndexHeaderSize))
_, err = idx.file.ReadAt(headerData, 0)
if err != nil {
return err
}
// 检查是否为 B+Tree 格式
if len(headerData) >= 4 {
magic := binary.LittleEndian.Uint32(headerData[0:4])
if magic == IndexMagic {
// B+Tree 格式
return idx.loadBTree()
}
}
// 回退到 JSON 格式(向后兼容)
return idx.loadJSON()
}
// loadBTree 加载 B+Tree 格式的索引
func (idx *SecondaryIndex) loadBTree() error {
reader, err := NewIndexBTreeReader(idx.file)
if err != nil {
return fmt.Errorf("failed to create btree reader: %w", err)
}
idx.btreeReader = reader
idx.metadata = reader.GetMetadata()
idx.useBTree = true
idx.ready = true
return nil
}
// loadJSON 加载 JSON 格式的索引(向后兼容)
func (idx *SecondaryIndex) loadJSON() error {
stat, err := idx.file.Stat()
if err != nil {
return err
}
// 读取文件内容
data := make([]byte, stat.Size())
_, err = idx.file.ReadAt(data, 0)
@@ -171,27 +175,24 @@ func (idx *SecondaryIndex) load() error {
return err
}
// 尝试加载新格式(带元数据)
// 加载 JSON 格式
var indexData struct {
Metadata IndexMetadata `json:"metadata"`
ValueToSeq map[string][]int64 `json:"data"`
}
err = json.Unmarshal(data, &indexData)
if err == nil && indexData.ValueToSeq != nil {
// 新格式
idx.metadata = indexData.Metadata
idx.valueToSeq = indexData.ValueToSeq
} else {
// 旧格式(兼容性)
err = json.Unmarshal(data, &idx.valueToSeq)
if err != nil {
return err
}
// 初始化元数据
idx.updateMetadata()
if err != nil {
return fmt.Errorf("failed to unmarshal index data: %w", err)
}
if indexData.ValueToSeq == nil {
return fmt.Errorf("invalid index data: missing data field")
}
idx.metadata = indexData.Metadata
idx.valueToSeq = indexData.ValueToSeq
idx.useBTree = false
idx.ready = true
return nil
}
@@ -206,6 +207,13 @@ func (idx *SecondaryIndex) Get(value any) ([]int64, error) {
}
key := fmt.Sprintf("%v", value)
// 如果使用 B+Tree从 B+Tree 读取
if idx.useBTree && idx.btreeReader != nil {
return idx.btreeReader.Get(key)
}
// 否则从内存 map 读取
seqs, exists := idx.valueToSeq[key]
if !exists {
return nil, nil
@@ -238,8 +246,8 @@ func (idx *SecondaryIndex) NeedsUpdate(currentMaxSeq int64) bool {
// IncrementalUpdate 增量更新索引
func (idx *SecondaryIndex) IncrementalUpdate(getData func(int64) (map[string]any, error), fromSeq, toSeq int64) error {
idx.mu.Lock()
defer idx.mu.Unlock()
addedCount := int64(0)
// 遍历缺失的 seq 范围
for seq := fromSeq; seq <= toSeq; seq++ {
// 获取数据
@@ -257,39 +265,40 @@ func (idx *SecondaryIndex) IncrementalUpdate(getData func(int64) (map[string]any
// 添加到索引
key := fmt.Sprintf("%v", value)
idx.valueToSeq[key] = append(idx.valueToSeq[key], seq)
// 更新元数据
if idx.metadata.MinSeq == 0 || seq < idx.metadata.MinSeq {
idx.metadata.MinSeq = seq
}
if seq > idx.metadata.MaxSeq {
idx.metadata.MaxSeq = seq
}
addedCount++
}
idx.metadata.RowCount += addedCount
idx.metadata.UpdatedAt = time.Now().UnixNano()
// 释放锁,然后调用 BuildBuild 会重新获取锁)
idx.mu.Unlock()
// 保存更新后的索引
return idx.save()
return idx.Build()
}
// Close 关闭索引
func (idx *SecondaryIndex) Close() error {
// 关闭 B+Tree reader
if idx.btreeReader != nil {
idx.btreeReader.Close()
}
// 关闭文件
if idx.file != nil {
return idx.file.Close()
}
return nil
}
// encodeSeqList 编码 seq 列表
func encodeSeqList(seqs []int64) []byte {
buf := make([]byte, 8*len(seqs))
for i, seq := range seqs {
binary.LittleEndian.PutUint64(buf[i*8:], uint64(seq))
}
return buf
}
// decodeSeqList 解码 seq 列表
func decodeSeqList(data []byte) []int64 {
count := len(data) / 8
seqs := make([]int64, count)
for i := range count {
seqs[i] = int64(binary.LittleEndian.Uint64(data[i*8:]))
}
return seqs
}
// IndexManager 索引管理器
type IndexManager struct {
dir string
@@ -478,9 +487,7 @@ func (m *IndexManager) ListIndexes() []string {
func (m *IndexManager) VerifyAndRepair(currentMaxSeq int64, getData func(int64) (map[string]any, error)) error {
m.mu.RLock()
indexes := make(map[string]*SecondaryIndex)
for k, v := range m.indexes {
indexes[k] = v
}
maps.Copy(indexes, m.indexes)
m.mu.RUnlock()
for field, idx := range indexes {

539
index_btree.go Normal file
View File

@@ -0,0 +1,539 @@
package srdb
import (
"crypto/md5"
"encoding/binary"
"fmt"
"os"
"sort"
"github.com/edsrzf/mmap-go"
)
/*
索引文件存储格式 (B+Tree)
文件结构
Header (256 bytes)
B+Tree 索引区
内部节点 (BTreeNodeSize = 4096 bytes)
NodeType (1 byte): 0=Leaf, 1=Internal
KeyCount (2 bytes): 节点中的 key 数量
Keys[]: int64 数组 (8 bytes each)
Children[]: int64 偏移量数组 (8 bytes each)
叶子节点 (BTreeNodeSize = 4096 bytes)
NodeType (1 byte): 0
KeyCount (2 bytes): 节点中的 key 数量
Keys[]: int64 数组 (8 bytes each)
Offsets[]: int64 数据偏移量 (8 bytes each)
Sizes[]: int32 数据大小 (4 bytes each)
数据块区
Entry 1:
ValueLen (4 bytes): 字段值长度
Value (N bytes): 字段值 (原始字符串)
SeqCount (4 bytes): seq 数量
Seqs (8 bytes each): seq 列表
Entry 2: ...
Entry N: ...
Header 格式 (256 bytes):
Offset | Size | Field | Description
-------|------|----------------|----------------------------------
0 | 4 | Magic | 0x49445842 ("IDXB")
4 | 4 | FormatVersion | 文件格式版本 (1)
8 | 8 | IndexVersion | 索引版本号 (对应 Metadata.Version)
16 | 8 | RootOffset | B+Tree 根节点偏移
24 | 8 | DataStart | 数据块起始位置
32 | 8 | MinSeq | 最小 seq
40 | 8 | MaxSeq | 最大 seq
48 | 8 | RowCount | 总行数
56 | 8 | CreatedAt | 创建时间 (UnixNano)
64 | 8 | UpdatedAt | 更新时间 (UnixNano)
72 | 184 | Reserved | 预留空间
索引条目格式 (变长):
Offset | Size | Field | Description
-------|-------------|------------|----------------------------------
0 | 4 | ValueLen | 字段值长度 (N)
4 | N | Value | 字段值 (原始字符串用于验证哈希冲突)
4+N | 4 | SeqCount | seq 数量 (M)
8+N | M * 8 | Seqs | seq 列表 (int64 数组)
Key 生成规则:
- 使用 MD5 哈希将字符串值转为 int64
- key = MD5(value)[0:8] LittleEndian uint64
- 存储原始 value 用于验证哈希冲突
查询流程:
1. value key (MD5 哈希)
2. B+Tree.Get(key) (offset, size)
3. 读取数据块 mmap[offset:offset+size]
4. 解码并验证原始 value (处理哈希冲突)
5. 返回 seqs 列表
性能特点:
- 加载: O(1) - 只需 mmap 映射
- 查询: O(log n) - B+Tree 查找
- 内存: 几乎为 0 - 零拷贝 mmap
- 支持范围查询 (未来可扩展)
*/
const (
IndexHeaderSize = 256 // 索引文件头大小
IndexMagic = 0x49445842 // "IDXB" - Index B-Tree
IndexVersion = 1 // 文件格式版本
)
// IndexHeader 索引文件头
type IndexHeader struct {
Magic uint32 // 魔数 "IDXB"
FormatVersion uint32 // 文件格式版本号
IndexVersion int64 // 索引版本号(对应 IndexMetadata.Version
RootOffset int64 // B+Tree 根节点偏移
DataStart int64 // 数据块起始位置
MinSeq int64 // 最小 seq
MaxSeq int64 // 最大 seq
RowCount int64 // 总行数
CreatedAt int64 // 创建时间
UpdatedAt int64 // 更新时间
Reserved [184]byte // 预留空间(减少 8 字节给 IndexVersion减少 8 字节调整对齐)
}
// Marshal 序列化 Header
func (h *IndexHeader) Marshal() []byte {
buf := make([]byte, IndexHeaderSize)
binary.LittleEndian.PutUint32(buf[0:4], h.Magic)
binary.LittleEndian.PutUint32(buf[4:8], h.FormatVersion)
binary.LittleEndian.PutUint64(buf[8:16], uint64(h.IndexVersion))
binary.LittleEndian.PutUint64(buf[16:24], uint64(h.RootOffset))
binary.LittleEndian.PutUint64(buf[24:32], uint64(h.DataStart))
binary.LittleEndian.PutUint64(buf[32:40], uint64(h.MinSeq))
binary.LittleEndian.PutUint64(buf[40:48], uint64(h.MaxSeq))
binary.LittleEndian.PutUint64(buf[48:56], uint64(h.RowCount))
binary.LittleEndian.PutUint64(buf[56:64], uint64(h.CreatedAt))
binary.LittleEndian.PutUint64(buf[64:72], uint64(h.UpdatedAt))
copy(buf[72:], h.Reserved[:])
return buf
}
// UnmarshalIndexHeader 反序列化 Header
func UnmarshalIndexHeader(data []byte) *IndexHeader {
if len(data) < IndexHeaderSize {
return nil
}
h := &IndexHeader{}
h.Magic = binary.LittleEndian.Uint32(data[0:4])
h.FormatVersion = binary.LittleEndian.Uint32(data[4:8])
h.IndexVersion = int64(binary.LittleEndian.Uint64(data[8:16]))
h.RootOffset = int64(binary.LittleEndian.Uint64(data[16:24]))
h.DataStart = int64(binary.LittleEndian.Uint64(data[24:32]))
h.MinSeq = int64(binary.LittleEndian.Uint64(data[32:40]))
h.MaxSeq = int64(binary.LittleEndian.Uint64(data[40:48]))
h.RowCount = int64(binary.LittleEndian.Uint64(data[48:56]))
h.CreatedAt = int64(binary.LittleEndian.Uint64(data[56:64]))
h.UpdatedAt = int64(binary.LittleEndian.Uint64(data[64:72]))
copy(h.Reserved[:], data[72:IndexHeaderSize])
return h
}
// valueToKey 将字段值转换为 B+Tree key使用哈希
//
// 原理:
// - 字符串无法直接用作 B+Tree key (需要 int64)
// - 使用 MD5 哈希将任意字符串映射为 int64
// - 取 MD5 的前 8 字节作为 key
//
// 哈希冲突处理:
// - 存储原始 value 在数据块中
// - 查询时验证原始 value 是否匹配
// - 冲突时返回 nil极低概率
//
// 示例:
// "Alice" → MD5 → 0x3bc15c8aae3e4124... → key = 0x3bc15c8aae3e4124
func valueToKey(value string) int64 {
// 使用 MD5 的前 8 字节作为 int64 key
hash := md5.Sum([]byte(value))
return int64(binary.LittleEndian.Uint64(hash[:8]))
}
// encodeIndexEntry 将索引条目编码为二进制格式(零拷贝友好)
//
// 格式:[ValueLen(4B)][Value(N bytes)][SeqCount(4B)][Seq1(8B)][Seq2(8B)]...
//
// 示例:
// value = "Alice", seqs = [1, 5, 10]
// 编码结果:
// [0x05, 0x00, 0x00, 0x00] // ValueLen = 5
// [0x41, 0x6c, 0x69, 0x63, 0x65] // "Alice"
// [0x03, 0x00, 0x00, 0x00] // SeqCount = 3
// [0x01, 0x00, 0x00, 0x00, ...] // Seq1 = 1
// [0x05, 0x00, 0x00, 0x00, ...] // Seq2 = 5
// [0x0a, 0x00, 0x00, 0x00, ...] // Seq3 = 10
//
// 总大小4 + 5 + 4 + 3*8 = 37 bytes
func encodeIndexEntry(value string, seqs []int64) []byte {
valueBytes := []byte(value)
size := 4 + len(valueBytes) + 4 + len(seqs)*8
buf := make([]byte, size)
// 写入 ValueLen
binary.LittleEndian.PutUint32(buf[0:4], uint32(len(valueBytes)))
// 写入 Value
copy(buf[4:], valueBytes)
// 写入 SeqCount
offset := 4 + len(valueBytes)
binary.LittleEndian.PutUint32(buf[offset:offset+4], uint32(len(seqs)))
// 写入 Seqs
offset += 4
for i, seq := range seqs {
binary.LittleEndian.PutUint64(buf[offset+i*8:offset+(i+1)*8], uint64(seq))
}
return buf
}
// decodeIndexEntry 从二进制格式解码索引条目(零拷贝)
//
// 参数:
// data: 编码后的二进制数据(来自 mmap零拷贝
//
// 返回:
// value: 原始字段值(用于验证哈希冲突)
// seqs: seq 列表
// err: 解码错误
//
// 零拷贝优化:
// - 直接从 mmap 数据中读取,不复制
// - string(data[4:4+valueLen]) 会复制,但无法避免
// - seqs 数组需要分配,但只复制指针大小的数据
func decodeIndexEntry(data []byte) (value string, seqs []int64, err error) {
if len(data) < 8 {
return "", nil, fmt.Errorf("data too short: %d bytes", len(data))
}
// 读取 ValueLen
valueLen := binary.LittleEndian.Uint32(data[0:4])
if len(data) < int(4+valueLen+4) {
return "", nil, fmt.Errorf("data too short for value: expected %d, got %d", 4+valueLen+4, len(data))
}
// 读取 Value
value = string(data[4 : 4+valueLen])
// 读取 SeqCount
offset := 4 + int(valueLen)
seqCount := binary.LittleEndian.Uint32(data[offset : offset+4])
// 验证数据长度
expectedSize := offset + 4 + int(seqCount)*8
if len(data) < expectedSize {
return "", nil, fmt.Errorf("data too short for seqs: expected %d, got %d", expectedSize, len(data))
}
// 读取 Seqs
seqs = make([]int64, seqCount)
offset += 4
for i := 0; i < int(seqCount); i++ {
seqs[i] = int64(binary.LittleEndian.Uint64(data[offset+i*8 : offset+(i+1)*8]))
}
return value, seqs, nil
}
// IndexBTreeWriter 使用 B+Tree 写入索引
//
// 写入流程:
// 1. Add(): 收集所有 (value, seqs) 到内存
// 2. Build():
// a. 计算所有 value 的 key (MD5 哈希)
// b. 按 key 排序B+Tree 要求有序)
// c. 编码所有条目为二进制格式
// d. 构建 B+Tree 索引
// e. 写入 Header + B+Tree + 数据块
//
// 文件布局:
// [Header] → [B+Tree] → [Data Blocks]
type IndexBTreeWriter struct {
file *os.File
header IndexHeader
entries map[string][]int64 // value -> seqs
dataOffset int64
}
// NewIndexBTreeWriter 创建索引写入器
func NewIndexBTreeWriter(file *os.File, metadata IndexMetadata) *IndexBTreeWriter {
return &IndexBTreeWriter{
file: file,
header: IndexHeader{
Magic: IndexMagic,
FormatVersion: IndexVersion,
IndexVersion: metadata.Version,
MinSeq: metadata.MinSeq,
MaxSeq: metadata.MaxSeq,
RowCount: metadata.RowCount,
CreatedAt: metadata.CreatedAt,
UpdatedAt: metadata.UpdatedAt,
},
entries: make(map[string][]int64),
dataOffset: IndexHeaderSize,
}
}
// Add 添加索引条目
func (w *IndexBTreeWriter) Add(value string, seqs []int64) {
w.entries[value] = seqs
}
// Build 构建并写入索引文件
func (w *IndexBTreeWriter) Build() error {
// 1. 计算所有 key 并按 key 排序(确保 B+Tree 构建有序)
type valueKey struct {
value string
key int64
}
var valueKeys []valueKey
for value := range w.entries {
valueKeys = append(valueKeys, valueKey{
value: value,
key: valueToKey(value),
})
}
// 按 key 排序(而不是按字符串)
sort.Slice(valueKeys, func(i, j int) bool {
return valueKeys[i].key < valueKeys[j].key
})
// 2. 先写入数据块并记录位置
type keyOffset struct {
key int64
offset int64
size int32
}
var keyOffsets []keyOffset
// 预留 Header 空间
currentOffset := int64(IndexHeaderSize)
// 构建数据块(使用二进制格式,无压缩)
var dataBlocks [][]byte
for _, vk := range valueKeys {
value := vk.value
seqs := w.entries[value]
// 编码为二进制格式
binaryData := encodeIndexEntry(value, seqs)
// 记录 key 和数据位置key 已经在 vk 中)
key := vk.key
dataBlocks = append(dataBlocks, binaryData)
// 暂时不知道确切的 offset先占位
keyOffsets = append(keyOffsets, keyOffset{
key: key,
offset: 0, // 稍后填充
size: int32(len(binaryData)),
})
currentOffset += int64(len(binaryData))
}
// 3. 计算 B+Tree 起始位置(紧接 Header
btreeStart := int64(IndexHeaderSize)
// 估算 B+Tree 大小(每个叶子节点最多 BTreeOrder 个条目)
numEntries := len(keyOffsets)
numLeafNodes := (numEntries + BTreeOrder - 1) / BTreeOrder
// 计算所有层级的节点总数
totalNodes := numLeafNodes
nodesAtCurrentLevel := numLeafNodes
for nodesAtCurrentLevel > 1 {
nodesAtCurrentLevel = (nodesAtCurrentLevel + BTreeOrder - 1) / BTreeOrder
totalNodes += nodesAtCurrentLevel
}
btreeSize := int64(totalNodes * BTreeNodeSize)
dataStart := btreeStart + btreeSize
// 4. 更新数据块的实际偏移量
currentDataOffset := dataStart
for i := range keyOffsets {
keyOffsets[i].offset = currentDataOffset
currentDataOffset += int64(keyOffsets[i].size)
}
// 5. 写入 Header预留位置
w.header.DataStart = dataStart
w.file.WriteAt(w.header.Marshal(), 0)
// 6. 构建 B+Tree
builder := NewBTreeBuilder(w.file, btreeStart)
for _, ko := range keyOffsets {
err := builder.Add(ko.key, ko.offset, ko.size)
if err != nil {
return fmt.Errorf("failed to add to btree: %w", err)
}
}
rootOffset, err := builder.Build()
if err != nil {
return fmt.Errorf("failed to build btree: %w", err)
}
// 7. 写入数据块
currentDataOffset = dataStart
for _, data := range dataBlocks {
_, err := w.file.WriteAt(data, currentDataOffset)
if err != nil {
return fmt.Errorf("failed to write data block: %w", err)
}
currentDataOffset += int64(len(data))
}
// 8. 更新 Header写入正确的 RootOffset
w.header.RootOffset = rootOffset
_, err = w.file.WriteAt(w.header.Marshal(), 0)
if err != nil {
return fmt.Errorf("failed to write header: %w", err)
}
// 9. Sync 到磁盘
return w.file.Sync()
}
// IndexBTreeReader 使用 B+Tree 读取索引
//
// 读取流程:
// 1. mmap 映射整个文件(零拷贝)
// 2. 读取 Header
// 3. 创建 BTreeReader (指向 RootOffset)
// 4. Get(value):
// a. value → key (MD5 哈希)
// b. BTree.Get(key) → (offset, size)
// c. 读取 mmap[offset:offset+size](零拷贝)
// d. 解码并验证原始 value
// e. 返回 seqs
//
// 性能优化:
// - mmap 零拷贝:不需要加载整个文件到内存
// - B+Tree 索引O(log n) 查询
// - 按需读取:只读取需要的数据块
type IndexBTreeReader struct {
file *os.File
mmap mmap.MMap
header IndexHeader
btree *BTreeReader
}
// NewIndexBTreeReader 创建索引读取器
func NewIndexBTreeReader(file *os.File) (*IndexBTreeReader, error) {
// 读取 Header
headerData := make([]byte, IndexHeaderSize)
_, err := file.ReadAt(headerData, 0)
if err != nil {
return nil, fmt.Errorf("failed to read header: %w", err)
}
header := UnmarshalIndexHeader(headerData)
if header == nil || header.Magic != IndexMagic {
return nil, fmt.Errorf("invalid index file: bad magic")
}
// mmap 整个文件
mmapData, err := mmap.Map(file, mmap.RDONLY, 0)
if err != nil {
return nil, fmt.Errorf("failed to mmap index file: %w", err)
}
// 创建 B+Tree Reader
btree := NewBTreeReader(mmapData, header.RootOffset)
return &IndexBTreeReader{
file: file,
mmap: mmapData,
header: *header,
btree: btree,
}, nil
}
// Get 查询字段值对应的 seq 列表(零拷贝)
//
// 参数:
// value: 字段值(例如 "Alice"
//
// 返回:
// seqs: seq 列表(例如 [1, 5, 10]
// err: 查询错误
//
// 查询流程:
// 1. value → key (MD5 哈希)
// 2. B+Tree.Get(key) → (offset, size)
// 3. 读取 mmap[offset:offset+size](零拷贝)
// 4. 解码并验证原始 value处理哈希冲突
// 5. 返回 seqs
//
// 哈希冲突处理:
// - 如果 storedValue != value说明发生哈希冲突
// - 返回 nil表示未找到
// - 冲突概率极低MD5 64位空间
func (r *IndexBTreeReader) Get(value string) ([]int64, error) {
// 计算 key
key := valueToKey(value)
// 在 B+Tree 中查找
dataOffset, dataSize, found := r.btree.Get(key)
if !found {
return nil, nil
}
// 读取数据块(零拷贝)
if dataOffset+int64(dataSize) > int64(len(r.mmap)) {
return nil, fmt.Errorf("data offset out of range: offset=%d, size=%d, mmap_len=%d", dataOffset, dataSize, len(r.mmap))
}
binaryData := r.mmap[dataOffset : dataOffset+int64(dataSize)]
// 解码二进制数据
storedValue, seqs, err := decodeIndexEntry(binaryData)
if err != nil {
return nil, fmt.Errorf("failed to decode entry: %w", err)
}
// 验证原始值(处理哈希冲突)
if storedValue != value {
// 哈希冲突,返回空
return nil, nil
}
return seqs, nil
}
// GetMetadata 获取元数据
func (r *IndexBTreeReader) GetMetadata() IndexMetadata {
return IndexMetadata{
Version: r.header.IndexVersion,
MinSeq: r.header.MinSeq,
MaxSeq: r.header.MaxSeq,
RowCount: r.header.RowCount,
CreatedAt: r.header.CreatedAt,
UpdatedAt: r.header.UpdatedAt,
}
}
// Close 关闭读取器
func (r *IndexBTreeReader) Close() error {
if r.mmap != nil {
r.mmap.Unmap()
}
return nil
}

628
index_btree_test.go Normal file
View File

@@ -0,0 +1,628 @@
package srdb
import (
"fmt"
"os"
"testing"
"time"
)
func TestIndexBTreeBasic(t *testing.T) {
// 创建临时目录
tmpDir := t.TempDir()
// 创建 Schema
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64},
{Name: "name", Type: FieldTypeString},
{Name: "city", Type: FieldTypeString},
})
// 创建索引管理器
mgr := NewIndexManager(tmpDir, schema)
defer mgr.Close()
// 创建索引
err := mgr.CreateIndex("city")
if err != nil {
t.Fatalf("Failed to create index: %v", err)
}
// 添加测试数据
testData := []struct {
city string
seq int64
}{
{"Beijing", 1},
{"Shanghai", 2},
{"Beijing", 3},
{"Shenzhen", 4},
{"Shanghai", 5},
{"Beijing", 6},
}
for _, td := range testData {
data := map[string]any{
"id": td.seq,
"name": "user_" + string(rune(td.seq)),
"city": td.city,
}
err := mgr.AddToIndexes(data, td.seq)
if err != nil {
t.Fatalf("Failed to add to index: %v", err)
}
}
// 构建索引
err = mgr.BuildAll()
if err != nil {
t.Fatalf("Failed to build index: %v", err)
}
// 关闭并重新打开,测试持久化
mgr.Close()
// 重新打开
mgr2 := NewIndexManager(tmpDir, schema)
defer mgr2.Close()
// 查询索引
idx, exists := mgr2.GetIndex("city")
if !exists {
t.Fatal("Index not found after reload")
}
// 验证索引使用 B+Tree
if !idx.useBTree {
t.Error("Index should be using B+Tree format")
}
// 验证查询结果
testCases := []struct {
city string
expectedSeqs []int64
}{
{"Beijing", []int64{1, 3, 6}},
{"Shanghai", []int64{2, 5}},
{"Shenzhen", []int64{4}},
{"Unknown", nil},
}
for _, tc := range testCases {
seqs, err := idx.Get(tc.city)
if err != nil {
t.Errorf("Failed to query index for %s: %v", tc.city, err)
continue
}
if len(seqs) != len(tc.expectedSeqs) {
t.Errorf("City %s: expected %d seqs, got %d", tc.city, len(tc.expectedSeqs), len(seqs))
continue
}
// 验证 seq 值
seqMap := make(map[int64]bool)
for _, seq := range seqs {
seqMap[seq] = true
}
for _, expectedSeq := range tc.expectedSeqs {
if !seqMap[expectedSeq] {
t.Errorf("City %s: missing expected seq %d", tc.city, expectedSeq)
}
}
}
// 验证元数据
metadata := idx.GetMetadata()
if metadata.MinSeq != 1 || metadata.MaxSeq != 6 || metadata.RowCount != 6 {
t.Errorf("Invalid metadata: MinSeq=%d, MaxSeq=%d, RowCount=%d",
metadata.MinSeq, metadata.MaxSeq, metadata.RowCount)
}
}
func TestIndexBTreeLargeDataset(t *testing.T) {
// 创建临时目录
tmpDir := t.TempDir()
// 创建 Schema
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64},
{Name: "category", Type: FieldTypeString},
})
// 创建索引管理器
mgr := NewIndexManager(tmpDir, schema)
defer mgr.Close()
// 创建索引
err := mgr.CreateIndex("category")
if err != nil {
t.Fatalf("Failed to create index: %v", err)
}
// 添加大量测试数据
numRecords := 10000
numCategories := 100
for i := 1; i <= numRecords; i++ {
category := "cat_" + string(rune('A'+(i%numCategories)))
data := map[string]any{
"id": int64(i),
"category": category,
}
err := mgr.AddToIndexes(data, int64(i))
if err != nil {
t.Fatalf("Failed to add to index: %v", err)
}
}
// 构建索引
startBuild := time.Now()
err = mgr.BuildAll()
if err != nil {
t.Fatalf("Failed to build index: %v", err)
}
buildTime := time.Since(startBuild)
t.Logf("Built index with %d records in %v", numRecords, buildTime)
// 获取索引文件大小
idx, _ := mgr.GetIndex("category")
stat, _ := idx.file.Stat()
t.Logf("Index file size: %d bytes", stat.Size())
// 关闭并重新打开
mgr.Close()
// 重新打开
mgr2 := NewIndexManager(tmpDir, schema)
defer mgr2.Close()
idx2, exists := mgr2.GetIndex("category")
if !exists {
t.Fatal("Index not found after reload")
}
// 验证索引使用 B+Tree
if !idx2.useBTree {
t.Error("Index should be using B+Tree format")
}
// 随机查询测试
startQuery := time.Now()
for i := 0; i < 100; i++ {
category := "cat_" + string(rune('A'+(i%numCategories)))
seqs, err := idx2.Get(category)
if err != nil {
t.Errorf("Failed to query index for %s: %v", category, err)
}
// 验证返回的 seq 数量
expectedCount := numRecords / numCategories
if len(seqs) != expectedCount {
t.Errorf("Category %s: expected %d seqs, got %d", category, expectedCount, len(seqs))
}
}
queryTime := time.Since(startQuery)
t.Logf("Queried 100 categories in %v (avg: %v per query)", queryTime, queryTime/100)
}
func TestIndexBTreeBackwardCompatibility(t *testing.T) {
// 创建临时目录
tmpDir := t.TempDir()
// 创建 Schema
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64},
{Name: "status", Type: FieldTypeString},
})
// 1. 创建索引管理器并用旧方式(通过先禁用新格式)创建索引
mgr := NewIndexManager(tmpDir, schema)
// 创建索引
err := mgr.CreateIndex("status")
if err != nil {
t.Fatalf("Failed to create index: %v", err)
}
// 添加<E6B7BB><E58AA0><EFBFBD>
testData := map[string][]int64{
"active": {1, 3, 5},
"inactive": {2, 4},
}
for status, seqs := range testData {
for _, seq := range seqs {
data := map[string]any{
"id": seq,
"status": status,
}
err := mgr.AddToIndexes(data, seq)
if err != nil {
t.Fatalf("Failed to add to index: %v", err)
}
}
}
// 构建索引(会使用新的 B+Tree 格式)
err = mgr.BuildAll()
if err != nil {
t.Fatalf("Failed to build index: %v", err)
}
// 关闭
mgr.Close()
// 2. 重新加载并验证
mgr2 := NewIndexManager(tmpDir, schema)
defer mgr2.Close()
idx, exists := mgr2.GetIndex("status")
if !exists {
t.Fatal("Failed to load index")
}
// 应该使用 B+Tree 格式
if !idx.useBTree {
t.Error("Index should be using B+Tree format")
}
// 验证查询结果
seqs, err := idx.Get("active")
if err != nil || len(seqs) != 3 {
t.Errorf("Failed to query index: err=%v, seqs=%v", err, seqs)
}
seqs, err = idx.Get("inactive")
if err != nil || len(seqs) != 2 {
t.Errorf("Failed to query index: err=%v, seqs=%v", err, seqs)
}
t.Log("Successfully loaded and queried B+Tree format index")
}
func TestIndexBTreeIncrementalUpdate(t *testing.T) {
// 创建临时目录
tmpDir := t.TempDir()
// 创建 Schema
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64},
{Name: "tag", Type: FieldTypeString},
})
// 创建索引管理器
mgr := NewIndexManager(tmpDir, schema)
defer mgr.Close()
// 创建索引
err := mgr.CreateIndex("tag")
if err != nil {
t.Fatalf("Failed to create index: %v", err)
}
// 添加初始数据
for i := 1; i <= 100; i++ {
tag := "tag_" + string(rune('A'+(i%10)))
data := map[string]any{
"id": int64(i),
"tag": tag,
}
err := mgr.AddToIndexes(data, int64(i))
if err != nil {
t.Fatalf("Failed to add to index: %v", err)
}
}
// 构建索引
err = mgr.BuildAll()
if err != nil {
t.Fatalf("Failed to build index: %v", err)
}
// 获取索引
idx, _ := mgr.GetIndex("tag")
// 验证初始元数据
metadata := idx.GetMetadata()
if metadata.MaxSeq != 100 {
t.Errorf("Expected MaxSeq=100, got %d", metadata.MaxSeq)
}
// 增量更新:添加新数据
getData := func(seq int64) (map[string]any, error) {
tag := "tag_" + string(rune('A'+(int(seq)%10)))
return map[string]any{
"id": seq,
"tag": tag,
}, nil
}
err = idx.IncrementalUpdate(getData, 101, 200)
if err != nil {
t.Fatalf("Failed to incremental update: %v", err)
}
// 验证更新后的元数据
metadata = idx.GetMetadata()
if metadata.MaxSeq != 200 {
t.Errorf("Expected MaxSeq=200 after update, got %d", metadata.MaxSeq)
}
if metadata.RowCount != 200 {
t.Errorf("Expected RowCount=200 after update, got %d", metadata.RowCount)
}
// 验证可以查询到新数据
seqs, err := idx.Get("tag_A")
if err != nil {
t.Fatalf("Failed to query index: %v", err)
}
// tag_A 应该包含 seq 1, 11, 21, ..., 191 (20个)
if len(seqs) != 20 {
t.Errorf("Expected 20 seqs for tag_A, got %d", len(seqs))
}
t.Logf("Incremental update successful: %d records indexed", metadata.RowCount)
}
// TestIndexBTreeWriter 测试 B+Tree 写入器
func TestIndexBTreeWriter(t *testing.T) {
tmpDir := t.TempDir()
filePath := tmpDir + "/test_idx.sst"
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
t.Fatalf("Failed to create file: %v", err)
}
defer file.Close()
// 创建写入器
metadata := IndexMetadata{
MinSeq: 1,
MaxSeq: 10,
RowCount: 10,
CreatedAt: time.Now().UnixNano(),
UpdatedAt: time.Now().UnixNano(),
}
writer := NewIndexBTreeWriter(file, metadata)
// 添加测试数据
testData := map[string][]int64{
"apple": {1, 2, 3},
"banana": {4, 5},
"cherry": {6, 7, 8, 9},
"date": {10},
}
for value, seqs := range testData {
writer.Add(value, seqs)
}
// 构建索引
err = writer.Build()
if err != nil {
t.Fatalf("Failed to build index: %v", err)
}
t.Log("B+Tree index written successfully")
// 关闭并重新打开文件
file.Close()
// 读取索引
file, err = os.Open(filePath)
if err != nil {
t.Fatalf("Failed to reopen file: %v", err)
}
reader, err := NewIndexBTreeReader(file)
if err != nil {
t.Fatalf("Failed to create reader: %v", err)
}
defer reader.Close()
// 验证读取
for value, expectedSeqs := range testData {
seqs, err := reader.Get(value)
if err != nil {
t.Errorf("Failed to get %s: %v", value, err)
continue
}
if len(seqs) != len(expectedSeqs) {
t.Errorf("%s: expected %d seqs, got %d", value, len(expectedSeqs), len(seqs))
t.Logf(" Expected: %v", expectedSeqs)
t.Logf(" Got: %v", seqs)
} else {
// 验证每个 seq
seqMap := make(map[int64]bool)
for _, seq := range seqs {
seqMap[seq] = true
}
for _, expectedSeq := range expectedSeqs {
if !seqMap[expectedSeq] {
t.Errorf("%s: missing seq %d", value, expectedSeq)
}
}
}
}
// 测试不存在的值
seqs, err := reader.Get("unknown")
if err != nil {
t.Errorf("Failed to get unknown: %v", err)
}
if seqs != nil {
t.Errorf("Expected nil for unknown value, got %v", seqs)
}
t.Log("All B+Tree reads successful")
}
// TestValueToKey 测试哈希函数
func TestValueToKey(t *testing.T) {
testCases := []string{
"apple",
"banana",
"cherry",
"Beijing",
"Shanghai",
"Shenzhen",
}
keyMap := make(map[int64]string)
for _, value := range testCases {
key := valueToKey(value)
t.Logf("valueToKey(%s) = %d", value, key)
// 检查哈希冲突
if existingValue, exists := keyMap[key]; exists {
t.Errorf("Hash collision: %s and %s both hash to %d", value, existingValue, key)
}
keyMap[key] = value
// 验证哈希的一致性
key2 := valueToKey(value)
if key != key2 {
t.Errorf("Hash inconsistency for %s: %d != %d", value, key, key2)
}
}
}
// TestIndexBTreeDataTypes 测试不同数据类型
func TestIndexBTreeDataTypes(t *testing.T) {
tmpDir := t.TempDir()
filePath := tmpDir + "/test_types.sst"
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
t.Fatalf("Failed to create file: %v", err)
}
defer file.Close()
metadata := IndexMetadata{
MinSeq: 1,
MaxSeq: 5,
RowCount: 5,
CreatedAt: time.Now().UnixNano(),
UpdatedAt: time.Now().UnixNano(),
}
writer := NewIndexBTreeWriter(file, metadata)
// 测试不同类型的值(都转换为字符串)
testData := map[string][]int64{
"123": {1}, // 数字字符串
"true": {2}, // 布尔字符串
"hello": {3}, // 普通字符串
"世界": {4}, // 中文
"": {5}, // 空字符串
}
for value, seqs := range testData {
writer.Add(value, seqs)
}
err = writer.Build()
if err != nil {
t.Fatalf("Failed to build: %v", err)
}
file.Close()
// 重新读取
file, err = os.Open(filePath)
if err != nil {
t.Fatalf("Failed to reopen: %v", err)
}
reader, err := NewIndexBTreeReader(file)
if err != nil {
t.Fatalf("Failed to create reader: %v", err)
}
defer reader.Close()
// 验证
for value, expectedSeqs := range testData {
seqs, err := reader.Get(value)
if err != nil {
t.Errorf("Failed to get '%s': %v", value, err)
continue
}
if len(seqs) != len(expectedSeqs) {
t.Errorf("'%s': expected %d seqs, got %d", value, len(expectedSeqs), len(seqs))
}
}
t.Log("All data types tested successfully")
}
// 测试大数据量
func TestIndexBTreeLargeData(t *testing.T) {
tmpDir := t.TempDir()
filePath := tmpDir + "/test_large.sst"
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
t.Fatalf("Failed to create file: %v", err)
}
defer file.Close()
metadata := IndexMetadata{
MinSeq: 1,
MaxSeq: 1000,
RowCount: 1000,
CreatedAt: time.Now().UnixNano(),
UpdatedAt: time.Now().UnixNano(),
}
writer := NewIndexBTreeWriter(file, metadata)
// 添加 1000 个不同的值
for i := range 1000 {
value := fmt.Sprintf("value_%d", i)
seqs := []int64{int64(i + 1)}
writer.Add(value, seqs)
}
err = writer.Build()
if err != nil {
t.Fatalf("Failed to build: %v", err)
}
fileInfo, _ := file.Stat()
t.Logf("Index file size: %d bytes", fileInfo.Size())
file.Close()
// 重新读取
file, err = os.Open(filePath)
if err != nil {
t.Fatalf("Failed to reopen: %v", err)
}
reader, err := NewIndexBTreeReader(file)
if err != nil {
t.Fatalf("Failed to create reader: %v", err)
}
defer reader.Close()
// 随机验证 100 个值
for i := range 100 {
value := fmt.Sprintf("value_%d", i*10)
seqs, err := reader.Get(value)
if err != nil {
t.Errorf("Failed to get %s: %v", value, err)
continue
}
if len(seqs) != 1 || seqs[0] != int64(i*10+1) {
t.Errorf("%s: expected [%d], got %v", value, i*10+1, seqs)
}
}
t.Log("Large data test successful")
}

View File

@@ -31,18 +31,13 @@ func (m *mapFieldset) Get(key string) (Field, any, error) {
return Field{}, nil, fmt.Errorf("field %s not found", key)
}
// 如果有 schema,返回字段定义
if m.schema != nil {
field, err := m.schema.GetField(key)
if err != nil {
// 字段在 schema 中不存在,返回默认 Field
return Field{Name: key}, value, nil
}
return *field, value, nil
// 从 Schema 获取字段定义
field, err := m.schema.GetField(key)
if err != nil {
// 字段在 schema 中不存在,返回默认 Field
return Field{Name: key}, value, nil
}
// 没有 schema返回默认 Field
return Field{Name: key}, value, nil
return *field, value, nil
}
type Expr interface {
@@ -364,12 +359,12 @@ func Or(exprs ...Expr) Expr {
type QueryBuilder struct {
conds []Expr
fields []string // 要选择的字段nil 表示选择所有字段
engine *Engine
table *Table
}
func newQueryBuilder(engine *Engine) *QueryBuilder {
func newQueryBuilder(table *Table) *QueryBuilder {
return &QueryBuilder{
engine: engine,
table: table,
}
}
@@ -384,7 +379,7 @@ func (qb *QueryBuilder) Match(data map[string]any) bool {
return true
}
fs := newMapFieldset(data, qb.engine.schema)
fs := newMapFieldset(data, qb.table.schema)
for _, cond := range qb.conds {
if !cond.Match(fs) {
return false
@@ -477,15 +472,15 @@ func (qb *QueryBuilder) NotNull(field string) *QueryBuilder {
// Rows 返回所有匹配的数据(游标模式 - 惰性加载)
func (qb *QueryBuilder) Rows() (*Rows, error) {
if qb.engine == nil {
return nil, fmt.Errorf("engine is nil")
if qb.table == nil {
return nil, fmt.Errorf("table is nil")
}
rows := &Rows{
schema: qb.engine.schema,
schema: qb.table.schema,
fields: qb.fields,
qb: qb,
engine: qb.engine,
table: qb.table,
visited: make(map[int64]bool),
}
@@ -495,7 +490,7 @@ func (qb *QueryBuilder) Rows() (*Rows, error) {
var allKeys []int64
// 1. 从 Active MemTable 读取数据
activeMemTable := qb.engine.memtableManager.GetActive()
activeMemTable := qb.table.memtableManager.GetActive()
if activeMemTable != nil {
activeKeys := activeMemTable.Keys()
for _, key := range activeKeys {
@@ -510,7 +505,7 @@ func (qb *QueryBuilder) Rows() (*Rows, error) {
}
// 2. 从所有 Immutable MemTables 读取数据
immutables := qb.engine.memtableManager.GetImmutables()
immutables := qb.table.memtableManager.GetImmutables()
for _, imm := range immutables {
immKeys := imm.MemTable.Keys()
for _, key := range immKeys {
@@ -530,13 +525,13 @@ func (qb *QueryBuilder) Rows() (*Rows, error) {
}
// 3. 收集所有 SST 文件的 keys
sstReaders := qb.engine.sstManager.GetReaders()
sstReaders := qb.table.sstManager.GetReaders()
for _, reader := range sstReaders {
// 获取文件中实际存在的 key 列表(已在 GetAllKeys 中排序)
keys := reader.GetAllKeys()
// 记录所有 keys实际数据稍后统一从 engine 读取)
// 记录所有 keys实际数据稍后统一从 table 读取)
for _, key := range keys {
// 如果 key 已存在(来自更新的数据源),跳过
if _, exists := keyToRow[key]; !exists {
@@ -561,15 +556,15 @@ func (qb *QueryBuilder) Rows() (*Rows, error) {
// 排序
slices.Sort(uniqueKeys)
// 统一从 engine 读取所有数据(避免 compaction 导致的文件删除)
// 统一从 table 读取所有数据(避免 compaction 导致的文件删除)
rows.cachedRows = make([]*SSTableRow, 0, len(uniqueKeys))
for _, seq := range uniqueKeys {
// 如果已经从 MemTable 读取,直接使用
row := keyToRow[seq]
if row == nil {
// 从 engine 读取(会搜索 MemTable + 所有 SST包括 compaction 后的新文件)
// 从 table 读取(会搜索 MemTable + 所有 SST包括 compaction 后的新文件)
var err error
row, err = qb.engine.Get(seq)
row, err = qb.table.Get(seq)
if err != nil {
// 数据不存在(理论上不应该发生,因为 key 来自索引)
continue
@@ -689,7 +684,7 @@ type Rows struct {
schema *Schema
fields []string // 要选择的字段nil 表示选择所有字段
qb *QueryBuilder
engine *Engine
table *Table
// 迭代状态
currentRow *Row
@@ -762,7 +757,7 @@ func (r *Rows) next() bool {
if r.memIterator != nil {
if seq, ok := r.memIterator.next(); ok {
if !r.visited[seq] {
row, err := r.engine.Get(seq)
row, err := r.table.Get(seq)
if err == nil && r.qb.Match(row.Data) {
r.visited[seq] = true
r.currentRow = &Row{schema: r.schema, fields: r.fields, inner: row}
@@ -780,7 +775,7 @@ func (r *Rows) next() bool {
if r.immutableIterator != nil {
if seq, ok := r.immutableIterator.next(); ok {
if !r.visited[seq] {
row, err := r.engine.Get(seq)
row, err := r.table.Get(seq)
if err == nil && r.qb.Match(row.Data) {
r.visited[seq] = true
r.currentRow = &Row{schema: r.schema, fields: r.fields, inner: row}
@@ -796,8 +791,8 @@ func (r *Rows) next() bool {
}
// 检查是否有更多 Immutable MemTables
if r.immutableIterator == nil && r.immutableIndex < len(r.engine.memtableManager.GetImmutables()) {
immutables := r.engine.memtableManager.GetImmutables()
if r.immutableIterator == nil && r.immutableIndex < len(r.table.memtableManager.GetImmutables()) {
immutables := r.table.memtableManager.GetImmutables()
if r.immutableIndex < len(immutables) {
r.immutableIterator = newMemtableIterator(immutables[r.immutableIndex].MemTable.Keys())
continue
@@ -813,7 +808,7 @@ func (r *Rows) next() bool {
sstReader.index++
if !r.visited[seq] {
row, err := r.engine.Get(seq)
row, err := r.table.Get(seq)
if err == nil && r.qb.Match(row.Data) {
r.visited[seq] = true
r.currentRow = &Row{schema: r.schema, fields: r.fields, inner: row}

View File

@@ -63,7 +63,7 @@ func (s *Schema) GetField(name string) (*Field, error) {
return &s.Fields[i], nil
}
}
return nil, fmt.Errorf("field %s not found", name)
return nil, NewErrorf(ErrCodeFieldNotFound, "field %s not found", name)
}
// GetIndexedFields 获取所有需要索引的字段
@@ -140,7 +140,7 @@ func (s *Schema) ExtractIndexValue(field string, data map[string]any) (any, erro
value, exists := data[field]
if !exists {
return nil, fmt.Errorf("field %s not found in data", field)
return nil, NewErrorf(ErrCodeFieldNotFound, "field %s not found in data", field)
}
// 类型转换

View File

@@ -3,7 +3,6 @@ package srdb
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"os"
"path/filepath"
@@ -12,7 +11,6 @@ import (
"sync"
"github.com/edsrzf/mmap-go"
"github.com/golang/snappy"
)
const (
@@ -22,13 +20,9 @@ const (
SSTableHeaderSize = 256 // 文件头大小
SSTableBlockSize = 64 * 1024 // 数据块大小 (64 KB)
// 压缩类型
SSTableCompressionNone = 0
SSTableCompressionSnappy = 1
// 二进制编码格式:
// [Magic: 4 bytes][Seq: 8 bytes][Time: 8 bytes][DataLen: 4 bytes][Data: variable]
SSTableRowMagic = 0x524F5733 // "ROW3"
SSTableRowMagic = 0x524F5731 // "ROW1"
)
// SSTableHeader SST 文件头 (256 bytes)
@@ -36,7 +30,7 @@ type SSTableHeader struct {
// 基础信息 (32 bytes)
Magic uint32 // Magic Number: 0x53535433
Version uint32 // 版本号
Compression uint8 // 压缩类型
Compression uint8 // 压缩类型(保留字段用于向后兼容)
Reserved1 [3]byte
Flags uint32 // 标志位
Reserved2 [16]byte
@@ -175,14 +169,9 @@ func encodeSSTableRowBinary(row *SSTableRow, schema *Schema) ([]byte, error) {
return nil, err
}
// 如果没有 Schema,回退到 JSON 编码(只编码 Data 部分Seq/Time 已经写入)
// 强制要求 Schema
if schema == nil {
dataJSON, err := json.Marshal(row.Data)
if err != nil {
return nil, err
}
buf.Write(dataJSON)
return buf.Bytes(), nil
return nil, fmt.Errorf("schema is required for encoding SSTable rows")
}
// 按字段分别编码和压缩
@@ -191,8 +180,8 @@ func encodeSSTableRowBinary(row *SSTableRow, schema *Schema) ([]byte, error) {
return nil, err
}
// 1. 先编码所有字段到各自的 buffer
compressedFields := make([][]byte, len(schema.Fields))
// 1. 先编码所有字段到各自的 buffer(无压缩)
fieldData := make([][]byte, len(schema.Fields))
for i, field := range schema.Fields {
fieldBuf := new(bytes.Buffer)
@@ -209,28 +198,28 @@ func encodeSSTableRowBinary(row *SSTableRow, schema *Schema) ([]byte, error) {
}
}
// 压缩每个字段
compressedFields[i] = snappy.Encode(nil, fieldBuf.Bytes())
// 直接使用二进制数据(无压缩)
fieldData[i] = fieldBuf.Bytes()
}
// 2. 写入字段偏移表(相对于数据区起始位置)
currentOffset := 0
for _, compressed := range compressedFields {
for _, data := range fieldData {
// 写入字段偏移(相对于数据区)
if err := binary.Write(buf, binary.LittleEndian, uint32(currentOffset)); err != nil {
return nil, err
}
// 写入压缩后大小
if err := binary.Write(buf, binary.LittleEndian, uint32(len(compressed))); err != nil {
// 写入数据大小
if err := binary.Write(buf, binary.LittleEndian, uint32(len(data))); err != nil {
return nil, err
}
currentOffset += len(compressed)
currentOffset += len(data)
}
// 3. 写入压缩后的字段数据
for _, compressed := range compressedFields {
if _, err := buf.Write(compressed); err != nil {
// 3. 写入字段数据
for _, data := range fieldData {
if _, err := buf.Write(data); err != nil {
return nil, err
}
}
@@ -341,14 +330,9 @@ func decodeSSTableRowBinaryPartial(data []byte, schema *Schema, fields []string)
return nil, err
}
// 如果没有 Schema,回退到 JSON 解码(读取剩余的 JSON 数据)
// 强制要求 Schema
if schema == nil {
remainingData := data[16:] // Skip Seq (8 bytes) + Time (8 bytes)
row.Data = make(map[string]any)
if err := json.Unmarshal(remainingData, &row.Data); err != nil {
return nil, err
}
return row, nil
return nil, fmt.Errorf("schema is required for decoding SSTable rows")
}
// 读取字段数量
@@ -400,28 +384,22 @@ func decodeSSTableRowBinaryPartial(data []byte, schema *Schema, fields []string)
continue
}
// 读取压缩的字段数据
// 读取字段数据(无压缩)
info := fieldInfos[i]
compressedPos := dataStart + int64(info.offset)
fieldPos := dataStart + int64(info.offset)
// Seek 到字段位置
if _, err := buf.Seek(compressedPos, 0); err != nil {
if _, err := buf.Seek(fieldPos, 0); err != nil {
return nil, fmt.Errorf("seek to field %s: %w", field.Name, err)
}
compressedData := make([]byte, info.size)
if _, err := buf.Read(compressedData); err != nil {
fieldData := make([]byte, info.size)
if _, err := buf.Read(fieldData); err != nil {
return nil, fmt.Errorf("read field %s: %w", field.Name, err)
}
// 解字段数据
decompressed, err := snappy.Decode(nil, compressedData)
if err != nil {
return nil, fmt.Errorf("decompress field %s: %w", field.Name, err)
}
// 解析字段值
fieldBuf := bytes.NewReader(decompressed)
// 解字段值(直接从二进制数据
fieldBuf := bytes.NewReader(fieldData)
value, err := readFieldBinaryValue(fieldBuf, field.Type, true)
if err != nil {
return nil, fmt.Errorf("parse field %s: %w", field.Name, err)
@@ -489,31 +467,29 @@ func readFieldBinaryValue(buf *bytes.Reader, typ FieldType, keep bool) (any, err
// SSTableWriter SST 文件写入器
type SSTableWriter struct {
file *os.File
builder *BTreeBuilder
dataOffset int64
dataStart int64 // 数据起始位置
rowCount int64
minKey int64
maxKey int64
minTime int64
maxTime int64
compression uint8
schema *Schema // Schema 用于优化编码
file *os.File
builder *BTreeBuilder
dataOffset int64
dataStart int64 // 数据起始位置
rowCount int64
minKey int64
maxKey int64
minTime int64
maxTime int64
schema *Schema // Schema 用于优化编码
}
// NewSSTableWriter 创建 SST 写入器
func NewSSTableWriter(file *os.File, schema *Schema) *SSTableWriter {
return &SSTableWriter{
file: file,
builder: NewBTreeBuilder(file, SSTableHeaderSize),
dataOffset: 0, // 先写数据,后面会更新
compression: SSTableCompressionSnappy,
minKey: -1,
maxKey: -1,
minTime: -1,
maxTime: -1,
schema: schema,
file: file,
builder: NewBTreeBuilder(file, SSTableHeaderSize),
dataOffset: 0, // 先写数据,后面会更新
minKey: -1,
maxKey: -1,
minTime: -1,
maxTime: -1,
schema: schema,
}
}
@@ -541,18 +517,13 @@ func (w *SSTableWriter) Add(row *SSTableRow) error {
}
w.rowCount++
// 序列化数据(使用 Schema 优化的二进制格式)
data := encodeSSTableRow(row, w.schema)
// 压缩数据
var compressed []byte
if w.compression == SSTableCompressionSnappy {
compressed = snappy.Encode(nil, data)
} else {
compressed = data
// 序列化数据(使用 Schema 优化的二进制格式,无压缩
data, err := encodeSSTableRow(row, w.schema)
if err != nil {
return fmt.Errorf("encode row: %w", err)
}
// 写入数据块
// 写入数据块(不压缩)
// 第一次写入时,确定数据起始位置
if w.dataStart == 0 {
// 预留足够空间给 B+Tree 索引
@@ -563,19 +534,19 @@ func (w *SSTableWriter) Add(row *SSTableRow) error {
}
offset := w.dataOffset
_, err := w.file.WriteAt(compressed, offset)
_, err = w.file.WriteAt(data, offset)
if err != nil {
return err
}
// 添加到 B+Tree
err = w.builder.Add(row.Seq, offset, int32(len(compressed)))
err = w.builder.Add(row.Seq, offset, int32(len(data)))
if err != nil {
return err
}
// 更新数据偏移
w.dataOffset += int64(len(compressed))
w.dataOffset += int64(len(data))
return nil
}
@@ -595,7 +566,7 @@ func (w *SSTableWriter) Finish() error {
header := &SSTableHeader{
Magic: SSTableMagicNumber,
Version: SSTableVersion,
Compression: w.compression,
Compression: 0, // 不使用压缩(保留字段用于向后兼容)
IndexOffset: SSTableHeaderSize,
IndexSize: indexSize,
RootOffset: rootOffset,
@@ -620,19 +591,13 @@ func (w *SSTableWriter) Finish() error {
}
// encodeSSTableRow 编码行数据 (使用二进制格式)
func encodeSSTableRow(row *SSTableRow, schema *Schema) []byte {
func encodeSSTableRow(row *SSTableRow, schema *Schema) ([]byte, error) {
// 使用二进制格式编码
encoded, err := encodeSSTableRowBinary(row, schema)
if err != nil {
// 降级到 JSON (不应该发生)
data := map[string]any{
"_seq": row.Seq,
"_time": row.Time,
"data": row.Data,
}
encoded, _ = json.Marshal(data)
return nil, fmt.Errorf("failed to encode row: %w", err)
}
return encoded
return encoded, nil
}
// SSTableReader SST 文件读取器
@@ -704,21 +669,9 @@ func (r *SSTableReader) Get(key int64) (*SSTableRow, error) {
return nil, fmt.Errorf("invalid data offset")
}
compressed := r.mmap[dataOffset : dataOffset+int64(dataSize)]
data := r.mmap[dataOffset : dataOffset+int64(dataSize)]
// 4. 压缩
var data []byte
var err error
if r.header.Compression == SSTableCompressionSnappy {
data, err = snappy.Decode(nil, compressed)
if err != nil {
return nil, err
}
} else {
data = compressed
}
// 5. 反序列化
// 4. 反序列化(无压缩
row, err := decodeSSTableRow(data, r.schema)
if err != nil {
return nil, err
@@ -745,21 +698,9 @@ func (r *SSTableReader) GetPartial(key int64, fields []string) (*SSTableRow, err
return nil, fmt.Errorf("invalid data offset")
}
compressed := r.mmap[dataOffset : dataOffset+int64(dataSize)]
data := r.mmap[dataOffset : dataOffset+int64(dataSize)]
// 4. 压缩
var data []byte
var err error
if r.header.Compression == SSTableCompressionSnappy {
data, err = snappy.Decode(nil, compressed)
if err != nil {
return nil, err
}
} else {
data = compressed
}
// 5. 按需反序列化(只解析需要的字段)
// 4. 按需反序列化(只解析需要的字段,无压缩
row, err := decodeSSTableRowBinaryPartial(data, r.schema, fields)
if err != nil {
return nil, err
@@ -799,27 +740,13 @@ func (r *SSTableReader) Close() error {
return nil
}
// decodeSSTableRow 解码行数据
// decodeSSTableRow 解码行数据(只支持二进制格式)
func decodeSSTableRow(data []byte, schema *Schema) (*SSTableRow, error) {
// 尝试使用二进制格式解码
// 使用二进制格式解码
row, err := decodeSSTableRowBinary(data, schema)
if err == nil {
return row, nil
}
// 降级到 JSON (兼容旧数据)
var decoded map[string]any
err = json.Unmarshal(data, &decoded)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to decode row: %w", err)
}
row = &SSTableRow{
Seq: int64(decoded["_seq"].(float64)),
Time: int64(decoded["_time"].(float64)),
Data: decoded["data"].(map[string]any),
}
return row, nil
}

View File

@@ -15,8 +15,14 @@ func TestSSTable(t *testing.T) {
}
defer os.Remove("test.sst")
// 创建 Schema
schema := NewSchema("test", []Field{
{Name: "name", Type: FieldTypeString},
{Name: "age", Type: FieldTypeInt64},
})
// 2. 写入数据
writer := NewSSTableWriter(file, nil)
writer := NewSSTableWriter(file, schema)
// 添加 1000 行数据
for i := int64(1); i <= 1000; i++ {
@@ -51,6 +57,9 @@ func TestSSTable(t *testing.T) {
}
defer reader.Close()
// 设置 Schema
reader.SetSchema(schema)
// 验证 Header
header := reader.GetHeader()
if header.RowCount != 1000 {
@@ -100,7 +109,7 @@ func TestSSTableHeaderSerialization(t *testing.T) {
header := &SSTableHeader{
Magic: SSTableMagicNumber,
Version: SSTableVersion,
Compression: SSTableCompressionSnappy,
Compression: 0, // 不使用压缩
IndexOffset: 256,
IndexSize: 1024,
RootOffset: 512,
@@ -154,11 +163,16 @@ func TestSSTableHeaderSerialization(t *testing.T) {
}
func BenchmarkSSTableGet(b *testing.B) {
// 创建 Schema
schema := NewSchema("test", []Field{
{Name: "value", Type: FieldTypeInt64},
})
// 创建测试文件
file, _ := os.Create("bench.sst")
defer os.Remove("bench.sst")
writer := NewSSTableWriter(file, nil)
writer := NewSSTableWriter(file, schema)
for i := int64(1); i <= 10000; i++ {
row := &SSTableRow{
Seq: i,
@@ -176,6 +190,9 @@ func BenchmarkSSTableGet(b *testing.B) {
reader, _ := NewSSTableReader("bench.sst")
defer reader.Close()
// 设置 Schema
reader.SetSchema(schema)
// 性能测试
for i := 0; b.Loop(); i++ {

815
table.go
View File

@@ -1,82 +1,720 @@
package srdb
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"sync"
"sync/atomic"
"time"
)
const (
DefaultMemTableSize = 64 * 1024 * 1024 // 64 MB
DefaultAutoFlushTimeout = 30 * time.Second // 30 秒无写入自动 flush
)
// Table 表
type Table struct {
name string // 表名
dir string // 表目录
schema *Schema // Schema
engine *Engine // Engine 实例
createdAt int64 // 创建时间
dir string
schema *Schema
indexManager *IndexManager
walManager *WALManager // WAL 管理器
sstManager *SSTableManager // SST 管理器
memtableManager *MemTableManager // MemTable 管理器
versionSet *VersionSet // MANIFEST 管理器
compactionManager *CompactionManager // Compaction 管理器
seq atomic.Int64
flushMu sync.Mutex
// 自动 flush 相关
autoFlushTimeout time.Duration
lastWriteTime atomic.Int64 // 最后写入时间UnixNano
stopAutoFlush chan struct{}
}
// createTable 创建新表
func createTable(name string, schema *Schema, db *Database) (*Table, error) {
// 创建表目录
tableDir := filepath.Join(db.dir, name)
err := os.MkdirAll(tableDir, 0755)
// TableOptions 配置选项
type TableOptions struct {
Dir string
MemTableSize int64
Name string // 表名
Fields []Field // 字段列表(可选)
AutoFlushTimeout time.Duration // 自动 flush 超时时间0 表示禁用
}
// OpenTable 打开数据库
func OpenTable(opts *TableOptions) (*Table, error) {
if opts.MemTableSize == 0 {
opts.MemTableSize = DefaultMemTableSize
}
// 创建主目录
err := os.MkdirAll(opts.Dir, 0755)
if err != nil {
os.RemoveAll(tableDir)
return nil, err
}
// 创建 EngineEngine 会自动保存 Schema 到文件)
engine, err := OpenEngine(&EngineOptions{
Dir: tableDir,
MemTableSize: DefaultMemTableSize,
Schema: schema,
})
// 创建子目录
walDir := filepath.Join(opts.Dir, "wal")
sstDir := filepath.Join(opts.Dir, "sst")
idxDir := filepath.Join(opts.Dir, "idx")
err = os.MkdirAll(walDir, 0755)
if err != nil {
return nil, err
}
err = os.MkdirAll(sstDir, 0755)
if err != nil {
return nil, err
}
err = os.MkdirAll(idxDir, 0755)
if err != nil {
os.RemoveAll(tableDir)
return nil, err
}
// 处理 Schema
var sch *Schema
if opts.Name != "" && len(opts.Fields) > 0 {
// 从 Name 和 Fields 创建 Schema
sch = NewSchema(opts.Name, opts.Fields)
// 保存到磁盘(带校验和)
schemaPath := filepath.Join(opts.Dir, "schema.json")
schemaFile, err := NewSchemaFile(sch)
if err != nil {
return nil, fmt.Errorf("create schema file: %w", err)
}
schemaData, err := json.MarshalIndent(schemaFile, "", " ")
if err != nil {
return nil, fmt.Errorf("marshal schema: %w", err)
}
err = os.WriteFile(schemaPath, schemaData, 0644)
if err != nil {
return nil, fmt.Errorf("write schema: %w", err)
}
} else {
// 尝试从磁盘恢复
schemaPath := filepath.Join(opts.Dir, "schema.json")
schemaData, err := os.ReadFile(schemaPath)
if err == nil {
// 文件存在,尝试解析
schemaFile := &SchemaFile{}
err = json.Unmarshal(schemaData, schemaFile)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal schema from %s: %w", schemaPath, err)
}
// 验证校验和
err = schemaFile.Verify()
if err != nil {
return nil, fmt.Errorf("failed to verify schema from %s: %w", schemaPath, err)
}
sch = schemaFile.Schema
} else if !os.IsNotExist(err) {
// 其他读取错误
return nil, fmt.Errorf("failed to read schema file %s: %w", schemaPath, err)
} else {
// Schema 文件不存在
return nil, fmt.Errorf("schema is required but schema.json not found in %s", opts.Dir)
}
}
// 强制要求 Schema
if sch == nil {
return nil, fmt.Errorf("schema is required to open table")
}
// 创建索引管理器
indexMgr := NewIndexManager(idxDir, sch)
// 创建 SST Manager
sstMgr, err := NewSSTableManager(sstDir)
if err != nil {
return nil, err
}
// 设置 Schema用于优化编解码
sstMgr.SetSchema(sch)
// 创建 MemTable Manager
memMgr := NewMemTableManager(opts.MemTableSize)
// 创建/恢复 MANIFEST
manifestDir := opts.Dir
versionSet, err := NewVersionSet(manifestDir)
if err != nil {
return nil, fmt.Errorf("create version set: %w", err)
}
// 创建 Table暂时不设置 WAL Manager
table := &Table{
name: name,
dir: tableDir,
schema: schema,
engine: engine,
createdAt: time.Now().Unix(),
dir: opts.Dir,
schema: sch,
indexManager: indexMgr,
walManager: nil, // 先不设置,恢复后再创建
sstManager: sstMgr,
memtableManager: memMgr,
versionSet: versionSet,
}
// 先恢复数据(包括从 WAL 恢复)
err = table.recover()
if err != nil {
return nil, err
}
// 恢复完成后,创建 WAL Manager 用于后续写入
walMgr, err := NewWALManager(walDir)
if err != nil {
return nil, err
}
table.walManager = walMgr
table.memtableManager.SetActiveWAL(walMgr.GetCurrentNumber())
// 创建 Compaction Manager
table.compactionManager = NewCompactionManager(sstDir, versionSet, sstMgr)
// 设置 Schema
table.compactionManager.SetSchema(sch)
// 启动时清理孤儿文件(崩溃恢复后的清理)
table.compactionManager.CleanupOrphanFiles()
// 启动后台 Compaction 和垃圾回收
table.compactionManager.Start()
// 验证并修复索引
table.verifyAndRepairIndexes()
// 设置自动 flush 超时时间
if opts.AutoFlushTimeout > 0 {
table.autoFlushTimeout = opts.AutoFlushTimeout
} else {
table.autoFlushTimeout = DefaultAutoFlushTimeout
}
table.stopAutoFlush = make(chan struct{})
table.lastWriteTime.Store(time.Now().UnixNano())
// 启动自动 flush 监控
go table.autoFlushMonitor()
return table, nil
}
// openTable 打开已存在的表
func openTable(name string, db *Database) (*Table, error) {
tableDir := filepath.Join(db.dir, name)
// Insert 插入数据
func (t *Table) Insert(data map[string]any) error {
// 1. 验证 Schema
if err := t.schema.Validate(data); err != nil {
return NewError(ErrCodeSchemaValidationFailed, err)
}
// 打开 EngineEngine 会自动从 schema.json 恢复 Schema
eng, err := OpenEngine(&EngineOptions{
Dir: tableDir,
MemTableSize: DefaultMemTableSize,
// Schema 不设置,让 Engine 自动从磁盘恢复
})
// 2. 生成 _seq
seq := t.seq.Add(1)
// 3. 添加系统字段
row := &SSTableRow{
Seq: seq,
Time: time.Now().UnixNano(),
Data: data,
}
// 3. 序列化
rowData, err := json.Marshal(row)
if err != nil {
return nil, err
return err
}
// 从 Engine 获取 Schema
sch := eng.GetSchema()
table := &Table{
name: name,
dir: tableDir,
schema: sch,
engine: eng,
// 4. 写入 WAL
entry := &WALEntry{
Type: WALEntryTypePut,
Seq: seq,
Data: rowData,
}
err = t.walManager.Append(entry)
if err != nil {
return err
}
return table, nil
// 5. 写入 MemTable Manager
t.memtableManager.Put(seq, rowData)
// 6. 添加到索引
t.indexManager.AddToIndexes(data, seq)
// 7. 更新最后写入时间
t.lastWriteTime.Store(time.Now().UnixNano())
// 8. 检查是否需要切换 MemTable
if t.memtableManager.ShouldSwitch() {
go t.switchMemTable()
}
return nil
}
// Get 查询数据
func (t *Table) Get(seq int64) (*SSTableRow, error) {
// 1. 先查 MemTable Manager (Active + Immutables)
data, found := t.memtableManager.Get(seq)
if found {
var row SSTableRow
err := json.Unmarshal(data, &row)
if err != nil {
return nil, err
}
return &row, nil
}
// 2. 查询 SST 文件
return t.sstManager.Get(seq)
}
// GetPartial 按需查询数据(只读取指定字段)
func (t *Table) GetPartial(seq int64, fields []string) (*SSTableRow, error) {
// 1. 先查 MemTable Manager (Active + Immutables)
data, found := t.memtableManager.Get(seq)
if found {
var row SSTableRow
err := json.Unmarshal(data, &row)
if err != nil {
return nil, err
}
// MemTable 中的数据已经完全解析,需要手动过滤字段
if len(fields) > 0 {
filteredData := make(map[string]any)
for _, field := range fields {
if val, ok := row.Data[field]; ok {
filteredData[field] = val
}
}
row.Data = filteredData
}
return &row, nil
}
// 2. 查询 SST 文件(按需解码)
return t.sstManager.GetPartial(seq, fields)
}
// switchMemTable 切换 MemTable
func (t *Table) switchMemTable() error {
t.flushMu.Lock()
defer t.flushMu.Unlock()
// 1. 切换到新的 WAL
oldWALNumber, err := t.walManager.Rotate()
if err != nil {
return err
}
newWALNumber := t.walManager.GetCurrentNumber()
// 2. 切换 MemTable (Active → Immutable)
_, immutable := t.memtableManager.Switch(newWALNumber)
// 3. 异步 Flush Immutable
go t.flushImmutable(immutable, oldWALNumber)
return nil
}
// flushImmutable 将 Immutable MemTable 刷新到 SST
func (t *Table) flushImmutable(imm *ImmutableMemTable, walNumber int64) error {
// 1. 收集所有行
var rows []*SSTableRow
iter := imm.NewIterator()
for iter.Next() {
var row SSTableRow
err := json.Unmarshal(iter.Value(), &row)
if err == nil {
rows = append(rows, &row)
}
}
if len(rows) == 0 {
// 没有数据,直接清理
t.walManager.Delete(walNumber)
t.memtableManager.RemoveImmutable(imm)
return nil
}
// 2. 从 VersionSet 分配文件编号
fileNumber := t.versionSet.AllocateFileNumber()
// 3. 创建 SST 文件到 L0
reader, err := t.sstManager.CreateSST(fileNumber, rows)
if err != nil {
return err
}
// 4. 创建 FileMetadata
header := reader.GetHeader()
// 获取文件大小
sstPath := reader.GetPath()
fileInfo, err := os.Stat(sstPath)
if err != nil {
return fmt.Errorf("stat sst file: %w", err)
}
fileMeta := &FileMetadata{
FileNumber: fileNumber,
Level: 0, // Flush 到 L0
FileSize: fileInfo.Size(),
MinKey: header.MinKey,
MaxKey: header.MaxKey,
RowCount: header.RowCount,
}
// 5. 更新 MANIFEST
edit := NewVersionEdit()
edit.AddFile(fileMeta)
// 持久化当前的文件编号计数器(关键修复:防止重启后文件编号重用)
// 使用 fileNumber + 1 确保并发安全,避免竞态条件
edit.SetNextFileNumber(fileNumber + 1)
err = t.versionSet.LogAndApply(edit)
if err != nil {
return fmt.Errorf("log and apply version edit: %w", err)
}
// 6. 删除对应的 WAL
t.walManager.Delete(walNumber)
// 7. 从 Immutable 列表中移除
t.memtableManager.RemoveImmutable(imm)
// 8. 持久化索引(防止崩溃丢失索引数据)
t.indexManager.BuildAll()
// 9. Compaction 由后台线程负责,不在 flush 路径中触发
// 避免同步 compaction 导致刚创建的文件立即被删除
// t.compactionManager.MaybeCompact()
return nil
}
// recover 恢复数据
func (t *Table) recover() error {
// 1. 恢复 SST 文件SST Manager 已经在 NewManager 中恢复了)
// 只需要获取最大 seq
maxSeq := t.sstManager.GetMaxSeq()
if maxSeq > t.seq.Load() {
t.seq.Store(maxSeq)
}
// 2. 恢复所有 WAL 文件到 MemTable Manager
walDir := filepath.Join(t.dir, "wal")
pattern := filepath.Join(walDir, "*.wal")
walFiles, err := filepath.Glob(pattern)
if err == nil && len(walFiles) > 0 {
// 按文件名排序
sort.Strings(walFiles)
// 依次读取每个 WAL
for _, walPath := range walFiles {
reader, err := NewWALReader(walPath)
if err != nil {
continue
}
entries, err := reader.Read()
reader.Close()
if err != nil {
continue
}
// 重放 WAL 到 Active MemTable
for _, entry := range entries {
// 验证 Schema
var row SSTableRow
if err := json.Unmarshal(entry.Data, &row); err != nil {
return fmt.Errorf("failed to unmarshal row during recovery (seq=%d): %w", entry.Seq, err)
}
// 验证 Schema
if err := t.schema.Validate(row.Data); err != nil {
return NewErrorf(ErrCodeSchemaValidationFailed, "schema validation failed during recovery (seq=%d)", entry.Seq, err)
}
t.memtableManager.Put(entry.Seq, entry.Data)
if entry.Seq > t.seq.Load() {
t.seq.Store(entry.Seq)
}
}
}
}
return nil
}
// autoFlushMonitor 自动 flush 监控
func (t *Table) autoFlushMonitor() {
ticker := time.NewTicker(t.autoFlushTimeout / 2) // 每半个超时时间检查一次
defer ticker.Stop()
for {
select {
case <-ticker.C:
// 检查是否超时
lastWrite := time.Unix(0, t.lastWriteTime.Load())
if time.Since(lastWrite) >= t.autoFlushTimeout {
// 检查 MemTable 是否有数据
active := t.memtableManager.GetActive()
if active != nil && active.Size() > 0 {
// 触发 flush
t.Flush()
}
}
case <-t.stopAutoFlush:
return
}
}
}
// Flush 手动刷新 Active MemTable 到磁盘
func (t *Table) Flush() error {
// 检查 Active MemTable 是否有数据
active := t.memtableManager.GetActive()
if active == nil || active.Size() == 0 {
return nil // 没有数据,无需 flush
}
// 强制切换 MemTableswitchMemTable 内部有锁)
return t.switchMemTable()
}
// Close 关闭引擎
func (t *Table) Close() error {
// 1. 停止自动 flush 监控(如果还在运行)
if t.stopAutoFlush != nil {
select {
case <-t.stopAutoFlush:
// 已经关闭,跳过
default:
close(t.stopAutoFlush)
}
}
// 2. 停止 Compaction Manager
if t.compactionManager != nil {
t.compactionManager.Stop()
}
// 3. 刷新 Active MemTable确保所有数据都写入磁盘
// 检查 memtableManager 是否存在(可能已被 Destroy
if t.memtableManager != nil {
t.Flush()
}
// 3. 关闭 WAL Manager
if t.walManager != nil {
t.walManager.Close()
}
// 4. 等待所有 Immutable Flush 完成
// TODO: 添加更优雅的等待机制
if t.memtableManager != nil {
for t.memtableManager.GetImmutableCount() > 0 {
time.Sleep(100 * time.Millisecond)
}
}
// 5. 保存所有索引
if t.indexManager != nil {
t.indexManager.BuildAll()
t.indexManager.Close()
}
// 6. 关闭 VersionSet
if t.versionSet != nil {
t.versionSet.Close()
}
// 7. 关闭 WAL Manager
if t.walManager != nil {
t.walManager.Close()
}
// 6. 关闭 SST Manager
if t.sstManager != nil {
t.sstManager.Close()
}
return nil
}
// Clean 清除所有数据(保留 Table 可用)
func (t *Table) Clean() error {
t.flushMu.Lock()
defer t.flushMu.Unlock()
// 0. 停止自动 flush 监控(临时)
if t.stopAutoFlush != nil {
close(t.stopAutoFlush)
}
// 1. 停止 Compaction Manager
if t.compactionManager != nil {
t.compactionManager.Stop()
}
// 2. 等待所有 Immutable Flush 完成
for t.memtableManager.GetImmutableCount() > 0 {
time.Sleep(100 * time.Millisecond)
}
// 3. 清空 MemTable
t.memtableManager = NewMemTableManager(DefaultMemTableSize)
// 2. 删除所有 WAL 文件
if t.walManager != nil {
t.walManager.Close()
walDir := filepath.Join(t.dir, "wal")
os.RemoveAll(walDir)
os.MkdirAll(walDir, 0755)
// 重新创建 WAL Manager
walMgr, err := NewWALManager(walDir)
if err != nil {
return fmt.Errorf("recreate wal manager: %w", err)
}
t.walManager = walMgr
t.memtableManager.SetActiveWAL(walMgr.GetCurrentNumber())
}
// 3. 删除所有 SST 文件
if t.sstManager != nil {
t.sstManager.Close()
sstDir := filepath.Join(t.dir, "sst")
os.RemoveAll(sstDir)
os.MkdirAll(sstDir, 0755)
// 重新创建 SST Manager
sstMgr, err := NewSSTableManager(sstDir)
if err != nil {
return fmt.Errorf("recreate sst manager: %w", err)
}
t.sstManager = sstMgr
// 设置 Schema
t.sstManager.SetSchema(t.schema)
}
// 4. 删除所有索引文件
if t.indexManager != nil {
t.indexManager.Close()
indexFiles, _ := filepath.Glob(filepath.Join(t.dir, "idx_*.sst"))
for _, f := range indexFiles {
os.Remove(f)
}
// 重新创建 Index Manager
t.indexManager = NewIndexManager(t.dir, t.schema)
}
// 5. 重置 MANIFEST
if t.versionSet != nil {
t.versionSet.Close()
manifestDir := t.dir
os.Remove(filepath.Join(manifestDir, "MANIFEST"))
os.Remove(filepath.Join(manifestDir, "CURRENT"))
// 重新创建 VersionSet
versionSet, err := NewVersionSet(manifestDir)
if err != nil {
return fmt.Errorf("recreate version set: %w", err)
}
t.versionSet = versionSet
}
// 6. 重新创建 Compaction Manager
sstDir := filepath.Join(t.dir, "sst")
t.compactionManager = NewCompactionManager(sstDir, t.versionSet, t.sstManager)
t.compactionManager.SetSchema(t.schema)
t.compactionManager.Start()
// 7. 重置序列号
t.seq.Store(0)
// 8. 更新最后写入时间
t.lastWriteTime.Store(time.Now().UnixNano())
// 9. 重启自动 flush 监控
t.stopAutoFlush = make(chan struct{})
go t.autoFlushMonitor()
return nil
}
// Destroy 销毁 Table 并删除所有数据文件
func (t *Table) Destroy() error {
// 1. 先关闭 Table
if err := t.Close(); err != nil {
return fmt.Errorf("close table: %w", err)
}
// 2. 删除整个数据目录
if err := os.RemoveAll(t.dir); err != nil {
return fmt.Errorf("remove data directory: %w", err)
}
// 3. 标记 Table 为不可用(将所有管理器设为 nil
t.walManager = nil
t.sstManager = nil
t.memtableManager = nil
t.versionSet = nil
t.compactionManager = nil
t.indexManager = nil
return nil
}
// TableStats 统计信息
type TableStats struct {
MemTableSize int64
MemTableCount int
SSTCount int
TotalRows int64
}
// GetVersionSet 获取 VersionSet用于高级操作
func (t *Table) GetVersionSet() *VersionSet {
return t.versionSet
}
// GetCompactionManager 获取 Compaction Manager用于高级操作
func (t *Table) GetCompactionManager() *CompactionManager {
return t.compactionManager
}
// GetMemtableManager 获取 Memtable Manager
func (t *Table) GetMemtableManager() *MemTableManager {
return t.memtableManager
}
// GetSSTManager 获取 SST Manager
func (t *Table) GetSSTManager() *SSTableManager {
return t.sstManager
}
// GetMaxSeq 获取当前最大的 seq 号
func (t *Table) GetMaxSeq() int64 {
return t.seq.Load() - 1 // seq 是下一个要分配的,所以最大的是 seq - 1
}
// GetName 获取表名
func (t *Table) GetName() string {
return t.name
return t.schema.Name
}
// GetDir 获取表目录
func (t *Table) GetDir() string {
return t.dir
}
// GetSchema 获取 Schema
@@ -84,71 +722,72 @@ func (t *Table) GetSchema() *Schema {
return t.schema
}
// Insert 插入数据
func (t *Table) Insert(data map[string]any) error {
return t.engine.Insert(data)
}
// Stats 获取统计信息
func (t *Table) Stats() *TableStats {
memStats := t.memtableManager.GetStats()
sstStats := t.sstManager.GetStats()
// Get 查询数据
func (t *Table) Get(seq int64) (*SSTableRow, error) {
return t.engine.Get(seq)
}
stats := &TableStats{
MemTableSize: memStats.TotalSize,
MemTableCount: memStats.TotalCount,
SSTCount: sstStats.FileCount,
}
// Query 创建查询构建器
func (t *Table) Query() *QueryBuilder {
return t.engine.Query()
// 计算总行数
stats.TotalRows = int64(memStats.TotalCount)
readers := t.sstManager.GetReaders()
for _, reader := range readers {
header := reader.GetHeader()
stats.TotalRows += header.RowCount
}
return stats
}
// CreateIndex 创建索引
func (t *Table) CreateIndex(field string) error {
return t.engine.CreateIndex(field)
return t.indexManager.CreateIndex(field)
}
// DropIndex 删除索引
func (t *Table) DropIndex(field string) error {
return t.engine.DropIndex(field)
return t.indexManager.DropIndex(field)
}
// ListIndexes 列出所有索引
func (t *Table) ListIndexes() []string {
return t.engine.ListIndexes()
return t.indexManager.ListIndexes()
}
// Stats 获取统计信息
func (t *Table) Stats() *TableStats {
return t.engine.Stats()
// GetIndexMetadata 获取索引元数据
func (t *Table) GetIndexMetadata() map[string]IndexMetadata {
return t.indexManager.GetIndexMetadata()
}
// GetEngine 获取底层 Engine
func (t *Table) GetEngine() *Engine {
return t.engine
// RepairIndexes 手动修复索引
func (t *Table) RepairIndexes() error {
return t.verifyAndRepairIndexes()
}
// Close 关闭表
func (t *Table) Close() error {
if t.engine != nil {
return t.engine.Close()
// Query 创建查询构建器
func (t *Table) Query() *QueryBuilder {
return newQueryBuilder(t)
}
// verifyAndRepairIndexes 验证并修复索引
func (t *Table) verifyAndRepairIndexes() error {
// 获取当前最大 seq
currentMaxSeq := t.seq.Load()
// 创建 getData 函数
getData := func(seq int64) (map[string]any, error) {
row, err := t.Get(seq)
if err != nil {
return nil, err
}
return row.Data, nil
}
return nil
}
// GetCreatedAt 获取表创建时间
func (t *Table) GetCreatedAt() int64 {
return t.createdAt
}
// Clean 清除表的所有数据(保留表结构和 Table 可用)
func (t *Table) Clean() error {
if t.engine != nil {
return t.engine.Clean()
}
return nil
}
// Destroy 销毁表并删除所有数据文件(不从 Database 中删除)
func (t *Table) Destroy() error {
if t.engine != nil {
return t.engine.Destroy()
}
return nil
// 验证并修复
return t.indexManager.VerifyAndRepair(currentMaxSeq, getData)
}

View File

@@ -1,314 +0,0 @@
package srdb
import (
"os"
"testing"
)
func TestTableClean(t *testing.T) {
dir := "./test_table_clean_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db.Close()
schema := NewSchema("users", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: true, Comment: "ID"},
{Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"},
})
table, err := db.CreateTable("users", schema)
if err != nil {
t.Fatal(err)
}
// 2. 插入数据
for i := 0; i < 100; i++ {
err := table.Insert(map[string]any{
"id": int64(i),
"name": "user" + string(rune(i)),
})
if err != nil {
t.Fatal(err)
}
}
// 3. 验证数据存在
stats := table.Stats()
t.Logf("Before Clean: %d rows", stats.TotalRows)
if stats.TotalRows == 0 {
t.Error("Expected data in table")
}
// 4. 清除数据
err = table.Clean()
if err != nil {
t.Fatal(err)
}
// 5. 验证数据已清除
stats = table.Stats()
t.Logf("After Clean: %d rows", stats.TotalRows)
if stats.TotalRows != 0 {
t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows)
}
// 6. 验证表仍然可用
err = table.Insert(map[string]any{
"id": int64(100),
"name": "new_user",
})
if err != nil {
t.Fatal(err)
}
stats = table.Stats()
if stats.TotalRows != 1 {
t.Errorf("Expected 1 row after insert, got %d", stats.TotalRows)
}
}
func TestTableDestroy(t *testing.T) {
dir := "./test_table_destroy_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db.Close()
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
})
table, err := db.CreateTable("test", schema)
if err != nil {
t.Fatal(err)
}
// 2. 插入数据
for i := 0; i < 50; i++ {
table.Insert(map[string]any{"id": int64(i)})
}
// 3. 验证数据存在
stats := table.Stats()
t.Logf("Before Destroy: %d rows", stats.TotalRows)
if stats.TotalRows == 0 {
t.Error("Expected data in table")
}
// 4. 获取表目录路径
tableDir := table.dir
// 5. 销毁表
err = table.Destroy()
if err != nil {
t.Fatal(err)
}
// 6. 验证表目录已删除
if _, err := os.Stat(tableDir); !os.IsNotExist(err) {
t.Error("Table directory should be deleted")
}
// 7. 注意Table.Destroy() 只删除文件,不从 Database 中删除
// 表仍然在 Database 的元数据中,但文件已被删除
tables := db.ListTables()
found := false
for _, name := range tables {
if name == "test" {
found = true
break
}
}
if !found {
t.Error("Table should still be in database metadata (use Database.DestroyTable to remove from metadata)")
}
}
func TestTableCleanWithIndex(t *testing.T) {
dir := "./test_table_clean_index_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db.Close()
schema := NewSchema("users", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: true, Comment: "ID"},
{Name: "email", Type: FieldTypeString, Indexed: true, Comment: "Email"},
{Name: "name", Type: FieldTypeString, Indexed: false, Comment: "Name"},
})
table, err := db.CreateTable("users", schema)
if err != nil {
t.Fatal(err)
}
// 2. 创建索引
err = table.CreateIndex("id")
if err != nil {
t.Fatal(err)
}
err = table.CreateIndex("email")
if err != nil {
t.Fatal(err)
}
// 3. 插入数据
for i := 0; i < 50; i++ {
table.Insert(map[string]any{
"id": int64(i),
"email": "user" + string(rune(i)) + "@example.com",
"name": "User " + string(rune(i)),
})
}
// 4. 验证索引存在
indexes := table.ListIndexes()
if len(indexes) != 2 {
t.Errorf("Expected 2 indexes, got %d", len(indexes))
}
// 5. 清除数据
err = table.Clean()
if err != nil {
t.Fatal(err)
}
// 6. 验证数据已清除
stats := table.Stats()
if stats.TotalRows != 0 {
t.Errorf("Expected 0 rows after clean, got %d", stats.TotalRows)
}
// 7. 验证索引已被清除Clean 会删除索引数据)
indexes = table.ListIndexes()
if len(indexes) != 0 {
t.Logf("Note: Indexes were cleared (expected behavior), got %d", len(indexes))
}
// 8. 重新创建索引
table.CreateIndex("id")
table.CreateIndex("email")
// 9. 验证可以继续插入数据
err = table.Insert(map[string]any{
"id": int64(100),
"email": "new@example.com",
"name": "New User",
})
if err != nil {
t.Fatal(err)
}
stats = table.Stats()
if stats.TotalRows != 1 {
t.Errorf("Expected 1 row, got %d", stats.TotalRows)
}
}
func TestTableCleanAndQuery(t *testing.T) {
dir := "./test_table_clean_query_data"
defer os.RemoveAll(dir)
// 1. 创建数据库和表
db, err := Open(dir)
if err != nil {
t.Fatal(err)
}
defer db.Close()
schema := NewSchema("test", []Field{
{Name: "id", Type: FieldTypeInt64, Indexed: false, Comment: "ID"},
{Name: "status", Type: FieldTypeString, Indexed: false, Comment: "Status"},
})
table, err := db.CreateTable("test", schema)
if err != nil {
t.Fatal(err)
}
// 2. 插入数据
for i := 0; i < 30; i++ {
table.Insert(map[string]any{
"id": int64(i),
"status": "active",
})
}
// 3. 查询数据
rows, err := table.Query().Eq("status", "active").Rows()
if err != nil {
t.Fatal(err)
}
count := 0
for rows.Next() {
count++
}
rows.Close()
t.Logf("Before Clean: found %d rows", count)
if count != 30 {
t.Errorf("Expected 30 rows, got %d", count)
}
// 4. 清除数据
err = table.Clean()
if err != nil {
t.Fatal(err)
}
// 5. 再次查询
rows, err = table.Query().Eq("status", "active").Rows()
if err != nil {
t.Fatal(err)
}
count = 0
for rows.Next() {
count++
}
rows.Close()
t.Logf("After Clean: found %d rows", count)
if count != 0 {
t.Errorf("Expected 0 rows after clean, got %d", count)
}
// 6. 插入新数据并查询
table.Insert(map[string]any{
"id": int64(100),
"status": "active",
})
rows, err = table.Query().Eq("status", "active").Rows()
if err != nil {
t.Fatal(err)
}
count = 0
for rows.Next() {
count++
}
rows.Close()
if count != 1 {
t.Errorf("Expected 1 row, got %d", count)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@ import (
"strconv"
"strings"
"code.tczkiot.com/srdb"
"code.tczkiot.com/wlw/srdb"
)
//go:embed static
@@ -94,7 +94,7 @@ func (ui *WebUI) handleListTables(w http.ResponseWriter, r *http.Request) {
tables = append(tables, TableListItem{
Name: name,
CreatedAt: table.GetCreatedAt(),
CreatedAt: 0, // TODO: Table 不再有 createdAt 字段
Fields: fields,
})
}
@@ -193,8 +193,7 @@ func (ui *WebUI) handleTableManifest(w http.ResponseWriter, r *http.Request, tab
return
}
engine := table.GetEngine()
versionSet := engine.GetVersionSet()
versionSet := table.GetVersionSet()
version := versionSet.GetCurrent()
// 构建每层的信息
@@ -216,7 +215,7 @@ func (ui *WebUI) handleTableManifest(w http.ResponseWriter, r *http.Request, tab
}
// 获取 Compaction Manager 和 Picker
compactionMgr := engine.GetCompactionManager()
compactionMgr := table.GetCompactionManager()
picker := compactionMgr.GetPicker()
levels := make([]LevelInfo, 0)