// Package taskq 提供基于 Redis 的异步任务队列功能 // servlet.go 文件包含 Servlet 结构体,封装了任务队列的完整生命周期管理 package taskq import ( "context" "errors" "log" "maps" "reflect" "sync" "github.com/hibiken/asynq" "github.com/redis/go-redis/v9" ) // Config 配置 Servlet 的选项 type Config struct { Redis redis.UniversalClient Tasks []*Task Plugins []Plugin } // Servlet 封装了 taskq 的完整生命周期管理 // // 生命周期: // 1. 配置阶段:调用 Configure 配置 Servlet // 2. 初始化阶段:调用 Init 初始化插件 // 3. 运行阶段:调用 Start 启动服务器 // 4. 停止阶段:调用 Stop 优雅关闭 type Servlet struct { mu sync.RWMutex handlers map[string]asynq.Handler queues map[string]int client *asynq.Client redisClient redis.UniversalClient plugins []Plugin exit chan chan struct{} } // NewServlet 创建一个新的 Servlet 实例 func NewServlet() *Servlet { return &Servlet{ handlers: make(map[string]asynq.Handler), queues: make(map[string]int), exit: make(chan chan struct{}), } } // Configure 配置 Servlet func (s *Servlet) Configure(cfg Config) error { s.mu.Lock() defer s.mu.Unlock() if cfg.Redis == nil { return errors.New("taskq: redis client is required") } s.redisClient = cfg.Redis s.client = asynq.NewClientFromRedisClient(cfg.Redis) // 注册任务 for _, t := range cfg.Tasks { if err := s.registerTask(t); err != nil { return err } } s.plugins = cfg.Plugins return nil } // registerTask 注册单个任务(内部方法,调用时已持有锁) func (s *Servlet) registerTask(t *Task) error { if t.Queue == "" { return errors.New("taskq: queue name cannot be empty") } if t.Priority < 0 || t.Priority > 255 { return errors.New("taskq: priority must be between 0 and 255") } if t.MaxRetries < 0 { return errors.New("taskq: retry count must be non-negative") } if t.Handler == nil { return errors.New("taskq: handler cannot be nil") } rv := reflect.ValueOf(t.Handler) if rv.Kind() != reflect.Func { return errors.New("taskq: handler must be a function") } rt := rv.Type() // 验证返回值:只能是 error 或无返回值 var returnError bool for i := range rt.NumOut() { if i == 0 && rt.Out(0).Implements(errorType) { returnError = true } else { return errors.New("taskq: handler function must return either error or nothing") } } // 验证参数:支持以下签名 // - func(context.Context, T) error // - func(context.Context) error // - func(T) error // - func() var inContext bool var inData bool var dataType reflect.Type numIn := rt.NumIn() if numIn > 2 { return errors.New("taskq: handler function can have at most 2 parameters") } for i := range numIn { fi := rt.In(i) if fi.Implements(contextType) { if i != 0 { return errors.New("taskq: context.Context must be the first parameter") } inContext = true } else if fi.Kind() == reflect.Struct { if inData { return errors.New("taskq: handler function can only have one data parameter") } inData = true dataType = fi } else { return errors.New("taskq: handler parameter must be context.Context or a struct") } } // 设置任务的反射信息 t.funcValue = rv t.dataType = dataType t.inputContext = inContext t.inputData = inData t.returnError = returnError t.servlet = s s.handlers[t.Name] = t s.queues[t.Queue] = t.Priority return nil } // Init 初始化所有插件 func (s *Servlet) Init(ctx context.Context) error { return s.initPlugins(ctx) } // initPlugins 初始化所有插件 func (s *Servlet) initPlugins(ctx context.Context) error { s.mu.RLock() plugins := s.plugins s.mu.RUnlock() pctx := &Context{ Context: ctx, servlet: s, } for _, p := range plugins { if err := p.Init(pctx); err != nil { return err } } return nil } // Start 启动 taskq 服务器 func (s *Servlet) Start(ctx context.Context) error { s.mu.Lock() rdb := s.redisClient qs := maps.Clone(s.queues) s.mu.Unlock() localCtx, cancel := context.WithCancel(ctx) srv := asynq.NewServerFromRedisClient(rdb, asynq.Config{ Concurrency: 30, Queues: qs, BaseContext: func() context.Context { return localCtx }, LogLevel: asynq.WarnLevel, }) // 启动插件 if err := s.startPlugins(localCtx); err != nil { cancel() return err } go s.runServer(localCtx, srv) go s.runMonitor(localCtx, srv, cancel) return nil } // startPlugins 启动所有插件 func (s *Servlet) startPlugins(ctx context.Context) error { s.mu.RLock() plugins := s.plugins s.mu.RUnlock() pctx := &Context{ Context: ctx, servlet: s, } for _, p := range plugins { if err := p.Start(pctx); err != nil { return err } } return nil } // runServer 运行任务处理服务器 func (s *Servlet) runServer(_ context.Context, srv *asynq.Server) { mux := asynq.NewServeMux() s.mu.RLock() for name, handler := range s.handlers { mux.Handle(name, handler) } s.mu.RUnlock() if err := srv.Run(mux); err != nil { log.Printf("taskq: server error: %v", err) } } func (s *Servlet) runMonitor(ctx context.Context, srv *asynq.Server, cancel context.CancelFunc) { var exit chan struct{} select { case <-ctx.Done(): case exit = <-s.exit: } srv.Shutdown() s.stopPlugins() cancel() if exit != nil { exit <- struct{}{} } } // stopPlugins 停止所有插件(按注册的逆序) func (s *Servlet) stopPlugins() { s.mu.RLock() plugins := s.plugins s.mu.RUnlock() for i := len(plugins) - 1; i >= 0; i-- { if err := plugins[i].Stop(); err != nil { log.Printf("taskq: plugin %s stop error: %v", plugins[i].Name(), err) } } } // Stop 优雅停止 taskq 服务器 // 发送停止信号并等待服务器完全关闭 // 可安全地多次调用 func (s *Servlet) Stop() { exit := make(chan struct{}) s.exit <- exit <-exit } // Client 返回 asynq 客户端 func (s *Servlet) Client() *asynq.Client { s.mu.RLock() defer s.mu.RUnlock() return s.client } // RedisClient 返回 Redis 客户端 func (s *Servlet) RedisClient() redis.UniversalClient { s.mu.RLock() defer s.mu.RUnlock() return s.redisClient } // Queues 返回队列优先级配置的副本 func (s *Servlet) Queues() map[string]int { s.mu.RLock() defer s.mu.RUnlock() return maps.Clone(s.queues) }