diff --git a/dag/consumer.go b/dag/consumer.go new file mode 100644 index 0000000..2b238af --- /dev/null +++ b/dag/consumer.go @@ -0,0 +1,101 @@ +package dag + +import ( + "context" + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/consts" + "log" +) + +func (tm *DAG) Consume(ctx context.Context) error { + if tm.consumer != nil { + tm.server.Options().SetSyncMode(true) + return tm.consumer.Consume(ctx) + } + return nil +} + +func (tm *DAG) AssignTopic(topic string) { + tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL())) + tm.consumerTopic = topic +} + +func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) { + if tm.consumer != nil { + result.Topic = tm.consumerTopic + if tm.consumer.Conn() == nil { + tm.onTaskCallback(ctx, result) + } else { + tm.consumer.OnResponse(ctx, result) + } + } +} + +func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) { + if node, ok := tm.nodes.Get(topic); ok { + log.Printf("DAG - CONSUMER ~> ready on %s", topic) + node.isReady = true + } +} + +func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) { + if node, ok := tm.nodes.Get(topic); ok { + log.Printf("DAG - CONSUMER ~> down on %s", topic) + node.isReady = false + } +} + +func (tm *DAG) Pause(_ context.Context) error { + tm.paused = true + return nil +} + +func (tm *DAG) Resume(_ context.Context) error { + tm.paused = false + return nil +} + +func (tm *DAG) Close() error { + var err error + tm.nodes.ForEach(func(_ string, n *Node) bool { + err = n.processor.Close() + if err != nil { + return false + } + return true + }) + return nil +} + +func (tm *DAG) PauseConsumer(ctx context.Context, id string) { + tm.doConsumer(ctx, id, consts.CONSUMER_PAUSE) +} + +func (tm *DAG) ResumeConsumer(ctx context.Context, id string) { + tm.doConsumer(ctx, id, consts.CONSUMER_RESUME) +} + +func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) { + if node, ok := tm.nodes.Get(id); ok { + switch action { + case consts.CONSUMER_PAUSE: + err := node.processor.Pause(ctx) + if err == nil { + node.isReady = false + log.Printf("[INFO] - Consumer %s paused successfully", node.ID) + } else { + log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.ID, err) + } + case consts.CONSUMER_RESUME: + err := node.processor.Resume(ctx) + if err == nil { + node.isReady = true + log.Printf("[INFO] - Consumer %s resumed successfully", node.ID) + } else { + log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.ID, err) + } + } + } else { + log.Printf("[WARNING] - Consumer %s not found", id) + } +} diff --git a/dag/dag.go b/dag/dag.go index 1fd1e54..8534bf9 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -11,7 +11,6 @@ import ( "golang.org/x/time/rate" - "github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/sio" "github.com/oarkflow/mq" @@ -49,6 +48,7 @@ type DAG struct { opts []mq.Option conditions map[string]map[string]string consumerTopic string + hasPageNode bool reportNodeResultCallback func(mq.Result) Error error Notifier *sio.Server @@ -67,7 +67,11 @@ func NewDAG(name, key string, finalResultCallback func(taskID string, result mq. conditions: make(map[string]map[string]string), finalResult: finalResultCallback, } - opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose)) + opts = append(opts, + mq.WithCallback(d.onTaskCallback), + mq.WithConsumerOnSubscribe(d.onConsumerJoin), + mq.WithConsumerOnClose(d.onConsumerClose), + ) d.server = mq.NewBroker(opts...) d.opts = opts options := d.server.Options() @@ -107,14 +111,6 @@ func (tm *DAG) GetType() string { return tm.key } -func (tm *DAG) Consume(ctx context.Context) error { - if tm.consumer != nil { - tm.server.Options().SetSyncMode(true) - return tm.consumer.Consume(ctx) - } - return nil -} - func (tm *DAG) Stop(ctx context.Context) error { tm.nodes.ForEach(func(_ string, n *Node) bool { err := n.processor.Stop(ctx) @@ -130,66 +126,14 @@ func (tm *DAG) GetKey() string { return tm.key } -func (tm *DAG) AssignTopic(topic string) { - tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL())) - tm.consumerTopic = topic -} - -func (tm *DAG) callbackToConsumer(ctx context.Context, result mq.Result) { - if tm.consumer != nil { - result.Topic = tm.consumerTopic - if tm.consumer.Conn() == nil { - tm.onTaskCallback(ctx, result) - } else { - tm.consumer.OnResponse(ctx, result) - } - } -} - -func (tm *DAG) onConsumerJoin(_ context.Context, topic, _ string) { - if node, ok := tm.nodes.Get(topic); ok { - log.Printf("DAG - CONSUMER ~> ready on %s", topic) - node.isReady = true - } -} - -func (tm *DAG) onConsumerClose(_ context.Context, topic, _ string) { - if node, ok := tm.nodes.Get(topic); ok { - log.Printf("DAG - CONSUMER ~> down on %s", topic) - node.isReady = false - } -} - -func (tm *DAG) Pause(_ context.Context) error { - tm.paused = true - return nil -} - -func (tm *DAG) Resume(_ context.Context) error { - tm.paused = false - return nil -} - -func (tm *DAG) Close() error { - var err error - tm.nodes.ForEach(func(_ string, n *Node) bool { - err = n.processor.Close() - if err != nil { - return false - } - return true - }) - return nil +func (tm *DAG) SetNotifyResponse(callback mq.Callback) { + tm.server.SetNotifyHandler(callback) } func (tm *DAG) SetStartNode(node string) { tm.startNode = node } -func (tm *DAG) SetNotifyResponse(callback mq.Callback) { - tm.server.SetNotifyHandler(callback) -} - func (tm *DAG) GetStartNode() string { return tm.startNode } @@ -217,6 +161,9 @@ func (tm *DAG) AddNode(nodeType NodeType, name, nodeID string, handler mq.Proces if len(startNode) > 0 && startNode[0] { tm.startNode = nodeID } + if nodeType == Page && !tm.hasPageNode { + tm.hasPageNode = true + } return tm } @@ -299,15 +246,6 @@ func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { method, ok := ctx.Value("method").(string) if method == "GET" && exists && node.NodeType == Page { ctx = context.WithValue(ctx, "initial_node", currentNode) - /* - if isLastNode, err := tm.IsLastNode(currentNode); err != nil && isLastNode { - if manager.result != nil { - fmt.Println(string(manager.result.Payload)) - resultCh <- *manager.result - return <-resultCh - } - } - */ if manager.result != nil { task.Payload = manager.result.Payload } @@ -339,7 +277,19 @@ func (tm *DAG) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { task.Topic = firstNode ctx = context.WithValue(ctx, ContextIndex, "0") manager.ProcessTask(ctx, firstNode, task.Payload) - return <-resultCh + if tm.hasPageNode { + return <-resultCh + } + // Timeout handling + select { + case result := <-resultCh: + return result + case <-time.After(30 * time.Second): // Set a timeout duration + return mq.Result{ + Error: fmt.Errorf("timeout waiting for task result"), + Ctx: ctx, + } + } } func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result { @@ -483,36 +433,3 @@ func (tm *DAG) ScheduleTask(ctx context.Context, payload []byte, opts ...mq.Sche tm.pool.Scheduler().AddTask(ctxx, t, opts...) return mq.Result{CreatedAt: t.CreatedAt, TaskID: t.ID, Topic: t.Topic, Status: "PENDING"} } - -func (tm *DAG) PauseConsumer(ctx context.Context, id string) { - tm.doConsumer(ctx, id, consts.CONSUMER_PAUSE) -} - -func (tm *DAG) ResumeConsumer(ctx context.Context, id string) { - tm.doConsumer(ctx, id, consts.CONSUMER_RESUME) -} - -func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) { - if node, ok := tm.nodes.Get(id); ok { - switch action { - case consts.CONSUMER_PAUSE: - err := node.processor.Pause(ctx) - if err == nil { - node.isReady = false - log.Printf("[INFO] - Consumer %s paused successfully", node.ID) - } else { - log.Printf("[ERROR] - Failed to pause consumer %s: %v", node.ID, err) - } - case consts.CONSUMER_RESUME: - err := node.processor.Resume(ctx) - if err == nil { - node.isReady = true - log.Printf("[INFO] - Consumer %s resumed successfully", node.ID) - } else { - log.Printf("[ERROR] - Failed to resume consumer %s: %v", node.ID, err) - } - } - } else { - log.Printf("[WARNING] - Consumer %s not found", id) - } -} diff --git a/dag/task_manager.go b/dag/task_manager.go index 63717d7..8ea4384 100644 --- a/dag/task_manager.go +++ b/dag/task_manager.go @@ -379,16 +379,17 @@ func (tm *TaskManager) handleEdges(currentResult nodeResult, edges []Edge) { func (tm *TaskManager) retryDeferredTasks() { const maxRetries = 5 - retries := 0 - for retries < maxRetries { + backoff := time.Second + + for retries := 0; retries < maxRetries; retries++ { select { case <-tm.stopCh: log.Println("Stopping Deferred task Retrier") return - case <-time.After(RetryInterval): + case <-time.After(backoff): tm.deferredTasks.ForEach(func(taskID string, task *task) bool { tm.send(task.ctx, task.nodeID, taskID, task.payload) - retries++ + backoff = backoff * 2 // Exponential backoff return true }) }