185 lines
5.2 KiB
Go
185 lines
5.2 KiB
Go
// Package taskq 提供基于 Redis 的异步任务队列功能
|
||
// 使用 asynq 库作为底层实现,支持任务注册、发布、消费和重试机制
|
||
package taskq
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"log"
|
||
"maps"
|
||
"reflect"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/hibiken/asynq"
|
||
"github.com/redis/go-redis/v9"
|
||
)
|
||
|
||
// 全局状态变量
|
||
var (
|
||
started atomic.Bool // 服务器启动状态
|
||
exit chan chan struct{} // 优雅退出信号通道
|
||
handlers map[string]asynq.Handler // 任务处理器映射表
|
||
queues map[string]int // 队列优先级配置
|
||
client atomic.Pointer[asynq.Client] // asynq 客户端实例
|
||
redisClient redis.UniversalClient // Redis 客户端实例
|
||
errorType = reflect.TypeOf((*error)(nil)).Elem() // error 类型反射
|
||
contextType = reflect.TypeOf((*context.Context)(nil)).Elem() // context.Context 类型反射
|
||
)
|
||
|
||
// Init 初始化 taskq 系统
|
||
// 创建必要的全局变量和映射表,必须在调用其他函数之前调用
|
||
func Init() {
|
||
exit = make(chan chan struct{}) // 创建优雅退出通道
|
||
handlers = make(map[string]asynq.Handler) // 创建任务处理器映射
|
||
queues = make(map[string]int) // 创建队列优先级映射
|
||
}
|
||
|
||
// Register 注册任务处理器
|
||
// 使用泛型确保类型安全,通过反射验证处理器函数签名
|
||
// 处理器函数签名必须是:func(context.Context, T) error 或 func(context.Context) 或 func(T) error 或 func()
|
||
func Register[T any](t *Task[T]) error {
|
||
rv := reflect.ValueOf(t.Handler)
|
||
if rv.Kind() != reflect.Func {
|
||
return errors.New("taskq: handler must be a function")
|
||
}
|
||
|
||
rt := rv.Type()
|
||
|
||
// 验证返回值:只能是 error 或无返回值
|
||
var returnError bool
|
||
for i := range rt.NumOut() {
|
||
if i == 0 && rt.Out(0).Implements(errorType) {
|
||
returnError = true
|
||
} else {
|
||
return errors.New("taskq: handler function must return either error or nothing")
|
||
}
|
||
}
|
||
|
||
// 验证参数:最多2个参数,第一个必须是 context.Context,第二个必须是结构体
|
||
var inContext bool
|
||
var inData bool
|
||
var dataType reflect.Type
|
||
for i := range rt.NumIn() {
|
||
if i == 0 {
|
||
fi := rt.In(i)
|
||
if !fi.Implements(contextType) {
|
||
return errors.New("taskq: handler function first parameter must be context.Context")
|
||
}
|
||
inContext = true
|
||
continue
|
||
}
|
||
if i != 1 {
|
||
return errors.New("taskq: handler function can have at most 2 parameters")
|
||
}
|
||
fi := rt.In(i)
|
||
if fi.Kind() != reflect.Struct {
|
||
return errors.New("taskq: handler function second parameter must be a struct")
|
||
}
|
||
inData = true
|
||
dataType = fi
|
||
}
|
||
|
||
// 检查服务器是否已启动
|
||
if started.Load() {
|
||
return errors.New("taskq: cannot register handler after server has started")
|
||
}
|
||
|
||
// 设置任务的反射信息
|
||
t.funcValue = rv
|
||
t.dataType = dataType
|
||
t.inputContext = inContext
|
||
t.inputData = inData
|
||
t.returnError = returnError
|
||
|
||
// 注册到全局映射表
|
||
handlers[t.Name] = t
|
||
queues[t.Queue] = t.Priority
|
||
|
||
return nil
|
||
}
|
||
|
||
// SetRedis 设置 Redis 客户端
|
||
// 必须在启动服务器之前调用,用于配置任务队列的存储后端
|
||
func SetRedis(rdb redis.UniversalClient) error {
|
||
if started.Load() {
|
||
return errors.New("taskq: server is already running")
|
||
}
|
||
|
||
redisClient = rdb
|
||
client.Store(asynq.NewClientFromRedisClient(rdb))
|
||
|
||
return nil
|
||
}
|
||
|
||
// Start 启动 taskq 服务器
|
||
// 开始监听任务队列并处理任务,包含健康检查和优雅退出机制
|
||
func Start(ctx context.Context) error {
|
||
// 原子操作确保只启动一次
|
||
if !started.CompareAndSwap(false, true) {
|
||
return errors.New("taskq: server is already running")
|
||
}
|
||
|
||
// 检查 Redis 客户端是否已初始化
|
||
if redisClient == nil {
|
||
return errors.New("taskq: redis client not initialized, call SetRedis() first")
|
||
}
|
||
|
||
// 创建任务路由器
|
||
mux := asynq.NewServeMux()
|
||
for name, handler := range handlers {
|
||
mux.Handle(name, handler)
|
||
}
|
||
|
||
// 创建 asynq 服务器
|
||
srv := asynq.NewServerFromRedisClient(redisClient, asynq.Config{
|
||
Concurrency: 30, // 并发处理数
|
||
Queues: maps.Clone(queues), // 队列配置
|
||
BaseContext: func() context.Context { return ctx }, // 基础上下文
|
||
LogLevel: asynq.DebugLevel, // 日志级别
|
||
})
|
||
|
||
// 启动监控协程:处理优雅退出和健康检查
|
||
ctx, cancel := context.WithCancel(ctx)
|
||
go func() {
|
||
defer cancel()
|
||
|
||
ticker := time.NewTicker(time.Minute) // 每分钟健康检查
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return
|
||
case exit := <-exit: // 收到退出信号
|
||
srv.Stop()
|
||
exit <- struct{}{}
|
||
return
|
||
case <-ticker.C: // 定期健康检查
|
||
err := srv.Ping()
|
||
if err != nil {
|
||
log.Println(err)
|
||
Stop()
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
|
||
// 启动任务处理服务器
|
||
go func() {
|
||
if err := srv.Run(mux); err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
}()
|
||
|
||
return nil
|
||
}
|
||
|
||
// Stop 优雅停止 taskq 服务器
|
||
// 发送停止信号并等待服务器完全关闭
|
||
func Stop() {
|
||
quit := make(chan struct{})
|
||
exit <- quit // 发送退出信号
|
||
<-quit // 等待确认退出
|
||
}
|