diff --git a/Makefile b/Makefile index 2b837a3..8eef144 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,3 @@ -run: lint - go run main.go - lint: golangci-lint run ./... diff --git a/README.md b/README.md index 98db783..55461b2 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,36 @@ # cycle-scheduler -cycle-scheduler is a simple scheduler handling tasks and executes them at regular interval. If a task is not in desired state, the task is re-scheduled with a backoff. +cycle-scheduler is a simple scheduler lib, handling tasks and executes them at regular interval. If a task is not in desired state, the task is re-scheduled. -## Run -You can run sample tests from `main.go` to see the scheduler in action: -```bash -make run -``` +**NOTE**: this should be not used for long-running tasks, it's more suitable for shorts tasks like polling etc... -You can adjust the clock interval and the number of workers as needed in `main.go` constants section: +## Examples +* Init a new scheduler with 4 workers ```go -const ( - MaxWorkers = 5 - Interval = 2000 * time.Millisecond +import ( + "context" + + scheduler "gitea.thegux.fr/rmanach/cycle-scheduler.git" ) + +ctx := context.Background() +s := scheduler.NewSchedulerCycle(ctx, 4) + +// add a task with an execution interval of 2 ms (executed every 2 ms) +// and a maximum duration of 30 second. +taskID := s.Delay( + func(ctx context.Context) (any, error) { + // ... + return any, nil + }, + scheduler.WithExecInterval(2*time.Millisecond), + scheduler.WithMaxDuration(30*time.Second) +) + +<-ctx.Done() +<-s.Done() ``` +**NOTE**: for `Delay` optionals arguments, check the `NewTask` method documentation for more details. + + diff --git a/go.mod b/go.mod index be15dc4..7b05801 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module cycle-scheduler +module gitea.thegux.fr/rmanach/cycle-scheduler.git go 1.22.4 diff --git a/internal/job/job.go b/internal/job/job.go deleted file mode 100644 index 781ffde..0000000 --- a/internal/job/job.go +++ /dev/null @@ -1,168 +0,0 @@ -package job - -import ( - "context" - "errors" - "sync" - "time" - - "github.com/google/uuid" - "github.com/rs/zerolog/log" -) - -type State int - -const ( - Pending State = iota - Running - Success - Failed - Abort - Unknown -) - -const UnknownState = "unknown" - -var JobExecTimeout = 10 * time.Second - -func (s State) String() string { - switch s { - case Pending: - return "pending" - case Running: - return "running" - case Success: - return "success" - case Failed: - return "failed" - case Abort: - return "abort" - case Unknown: - return UnknownState - default: - return UnknownState - } -} - -var ( - ErrJobAborted = errors.New("job has been aborted") - ErrJobNotCompletedYet = errors.New("job is not right state, retrying") -) - -type FnJob func(ctx context.Context) error - -type JobDetails struct { - ID uuid.UUID `json:"id"` - State string `json:"state"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt *time.Time `json:"updatedAt,omitempty"` - Err string `json:"error"` -} - -type Job struct { - l sync.RWMutex - id uuid.UUID - createdAt time.Time - updatedAt *time.Time - state State - task FnJob - err error - chAbort chan struct{} -} - -func NewJob(task FnJob) Job { - return Job{ - id: uuid.New(), - createdAt: time.Now().UTC(), - state: Pending, - task: task, - chAbort: make(chan struct{}, 1), - } -} - -func (j *Job) IntoDetails() JobDetails { - j.l.RLock() - defer j.l.RUnlock() - - jd := JobDetails{ - ID: j.id, - CreatedAt: j.createdAt, - State: j.state.String(), - } - - if err := j.err; err != nil { - jd.Err = err.Error() - } - - if ut := j.updatedAt; ut != nil { - jd.UpdatedAt = ut - } - - return jd -} - -func (j *Job) GetID() uuid.UUID { - return j.id -} - -func (j *Job) GetState() State { - j.l.RLock() - defer j.l.RUnlock() - - return j.state -} - -func (j *Job) setState(s State) { - j.l.Lock() - defer j.l.Unlock() - - now := time.Now().UTC() - j.updatedAt = &now - - j.state = s -} - -func (j *Job) setFail(err error) { - j.l.Lock() - defer j.l.Unlock() - - now := time.Now().UTC() - j.updatedAt = &now - - if j.state != Abort { - j.state = Failed - } - - j.err = err -} - -func (j *Job) Abort() { - j.setState(Abort) - j.chAbort <- struct{}{} -} - -func (j *Job) Run(ctx context.Context) { - ctxExec, fnCancel := context.WithTimeout(ctx, JobExecTimeout) - defer fnCancel() - - j.setState(Running) - - log.Info().Str("job", j.GetID().String()).Msg("job running...") - - go func() { - for range j.chAbort { - j.setState(Abort) - fnCancel() - } - }() - - if err := j.task(ctxExec); err != nil { - if errors.Is(err, ErrJobNotCompletedYet) { - j.setState(Pending) - return - } - j.setFail(err) - return - } - j.setState(Success) -} diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go deleted file mode 100644 index dfcd78a..0000000 --- a/internal/scheduler/scheduler.go +++ /dev/null @@ -1,163 +0,0 @@ -package scheduler - -import ( - "context" - "cycle-scheduler/internal/job" - "math" - "sync" - "time" - - "github.com/google/uuid" - "github.com/rs/zerolog/log" -) - -const ExponentialFactor = 1.8 - -// SchedulerCycle is a simple scheduler handling jobs and executes them at regular interval. -// If a task is not in desired state, the task is re-scheduled with a backoff. -type SchedulerCycle struct { - wg sync.WaitGroup - - ctx context.Context - fnCancel context.CancelFunc - - interval time.Duration - tasks tasks - - chTasks chan *task -} - -func NewSchedulerCycle(ctx context.Context, interval time.Duration, workers uint32) *SchedulerCycle { - ctxChild, fnCancel := context.WithCancel(ctx) - - c := SchedulerCycle{ - wg: sync.WaitGroup{}, - ctx: ctxChild, - fnCancel: fnCancel, - interval: interval, - tasks: newTasks(), - chTasks: make(chan *task), - } - - c.run(workers) - - return &c -} - -func (c *SchedulerCycle) backoff(t *task) { - backoff := c.interval + time.Duration(math.Pow(ExponentialFactor, float64(t.attempts.Load()))) - - t.timer.set( - time.AfterFunc(backoff, func() { - select { - case c.chTasks <- t: - default: - log.Error().Str("task id", t.GetID().String()).Msg("unable to execute task to the worker, delayed it") - c.backoff(t) - } - }), - ) -} - -// exec runs the task now or if all the workers are in use, delayed it. -func (c *SchedulerCycle) exec(t *task) { - select { - case c.chTasks <- t: - default: - log.Error().Str("task id", t.GetID().String()).Msg("unable to execute the task to a worker now, delayed it") - c.backoff(t) - } -} - -func (c *SchedulerCycle) getTask(id uuid.UUID) *task { - return c.tasks.get(id) -} - -// run launches a number of worker to execute tasks. -// If a task returns `ErrJobNotCompletedYet`, it re-schedules with a backoff. -func (c *SchedulerCycle) run(n uint32) { - for i := 0; i < int(n); i++ { - c.wg.Add(1) - go func() { - defer c.wg.Done() - for { - select { - case t := <-c.chTasks: - c.execute(t, c.backoff) - case <-c.ctx.Done(): - log.Error().Msg("context done, worker is stopping...") - return - } - } - }() - } -} - -func (c *SchedulerCycle) execute(t *task, fnFallBack func(*task)) { - t.run(c.ctx) - if t.GetState() == job.Pending { - fnFallBack(t) - } -} - -func (c *SchedulerCycle) Stop() { - c.fnCancel() -} - -func (c *SchedulerCycle) Done() <-chan struct{} { - done := make(chan struct{}) - go func() { - <-c.ctx.Done() - c.wg.Wait() - done <- struct{}{} - }() - return done -} - -func (c *SchedulerCycle) Len() int { - return c.tasks.len() -} - -// TasksDone checks whether all the tasks has been completed. -func (c *SchedulerCycle) TasksDone() bool { - return c.tasks.completed() -} - -func (c *SchedulerCycle) GetTasksDetails() []TaskDetails { - return c.tasks.getAllDetails() -} - -// GetTaskDetails returns the task details by id. -func (c *SchedulerCycle) GetTaskDetails(id uuid.UUID) TaskDetails { - return c.tasks.getDetails(id) -} - -// Delay builds a task and add it to the scheduler engine. -func (c *SchedulerCycle) Delay(fnJob job.FnJob) uuid.UUID { - select { - case <-c.Done(): - log.Error().Msg("context done unable to add new job") - default: - } - - t := newTask(fnJob) - - c.tasks.add(t) - - c.exec(t) - - log.Info().Str("task", t.GetID().String()).Msg("task added successfully") - return t.GetID() -} - -// Abort aborts the task given by its id if it exists. -func (c *SchedulerCycle) Abort(id uuid.UUID) bool { - if t := c.getTask(id); t != nil { - t.abort() - - log.Info().Str("task id", t.GetID().String()).Msg("abort task done") - return true - } - - return false -} diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go deleted file mode 100644 index faa8c15..0000000 --- a/internal/scheduler/scheduler_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package scheduler - -import ( - "context" - "cycle-scheduler/internal/job" - "errors" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestSlot(t *testing.T) { - ctx, fnCancel := context.WithCancel(context.Background()) - defer fnCancel() - - s := NewSchedulerCycle(ctx, 1*time.Millisecond, 5) - - s.Delay(func(ctx context.Context) error { - return nil - }) - s.Delay(func(ctx context.Context) error { - return job.ErrJobNotCompletedYet - }) - j3 := s.Delay(func(ctx context.Context) error { - return errors.New("errors") - }) - - time.Sleep(2 * time.Millisecond) - - assert.Equal(t, 3, s.Len()) - assert.Equal(t, job.Failed.String(), s.GetTaskDetails(j3).State) -} diff --git a/internal/scheduler/task.go b/internal/scheduler/task.go deleted file mode 100644 index 026ae32..0000000 --- a/internal/scheduler/task.go +++ /dev/null @@ -1,152 +0,0 @@ -package scheduler - -import ( - "context" - "cycle-scheduler/internal/job" - "sync" - "sync/atomic" - "time" - - "github.com/google/uuid" -) - -// atomicTimer wraps a `time.Timer`. -type atomicTimer struct { - atomic.Pointer[time.Timer] -} - -func (at *atomicTimer) stop() { - timer := at.Load() - if timer != nil { - timer.Stop() - } -} - -// set replaces the current timer. -// It also ensures that the current timer is stopped. -func (at *atomicTimer) set(t *time.Timer) { - timer := at.Load() - if timer != nil { - timer.Stop() - at.Swap(t) - return - } - - at.Swap(t) -} - -type TaskDetails struct { - job.JobDetails - Attempts int `json:"attempts"` -} - -type task struct { - *job.Job - attempts atomic.Uint32 - timer atomicTimer -} - -func newTask(f job.FnJob) *task { - j := job.NewJob(f) - t := task{ - Job: &j, - timer: atomicTimer{}, - } - - return &t -} - -func (t *task) abort() { - t.timer.stop() - t.Job.Abort() -} - -func (t *task) run(ctx context.Context) { - t.attempts.Add(1) - t.Job.Run(ctx) -} - -func (t *task) getDetails() TaskDetails { - return TaskDetails{ - JobDetails: t.IntoDetails(), - Attempts: int(t.attempts.Load()), - } -} - -type tasks struct { - l sync.RWMutex - s map[uuid.UUID]*task -} - -func newTasks() tasks { - return tasks{ - s: make(map[uuid.UUID]*task), - } -} - -func (ts *tasks) add(t *task) { - ts.l.Lock() - defer ts.l.Unlock() - - ts.s[t.GetID()] = t -} - -func (ts *tasks) get(id uuid.UUID) *task { - ts.l.RLock() - defer ts.l.RUnlock() - - j, ok := ts.s[id] - if !ok { - return nil - } - - return j -} - -func (ts *tasks) len() int { - ts.l.RLock() - defer ts.l.RUnlock() - - return len(ts.s) -} - -func (ts *tasks) completed() bool { - ts.l.RLock() - defer ts.l.RUnlock() - - for _, t := range ts.s { - if t.GetState() == job.Pending || t.GetState() == job.Running { - return false - } - } - - return true -} - -func (ts *tasks) getAllDetails() []TaskDetails { - ts.l.RLock() - defer ts.l.RUnlock() - - details := []TaskDetails{} - for _, t := range ts.s { - details = append(details, t.getDetails()) - } - - return details -} - -func (ts *tasks) getDetails(id uuid.UUID) TaskDetails { - ts.l.RLock() - defer ts.l.RUnlock() - - t, ok := ts.s[id] - if !ok { - return TaskDetails{ - JobDetails: job.JobDetails{ - State: job.UnknownState, - }, - } - } - - return t.getDetails() -} diff --git a/main.go b/main.go deleted file mode 100644 index 58b6ceb..0000000 --- a/main.go +++ /dev/null @@ -1,125 +0,0 @@ -package main - -import ( - "context" - "cycle-scheduler/internal/job" - "cycle-scheduler/internal/scheduler" - "encoding/json" - "errors" - "fmt" - "math/rand/v2" - "os" - "os/signal" - "time" - - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" -) - -const ( - MaxWorkers = 5 - Interval = 2000 * time.Millisecond -) - -func initLogger() { - zerolog.TimeFieldFormat = zerolog.TimeFormatUnix - log.Logger = log.With().Caller().Logger().Output(zerolog.ConsoleWriter{Out: os.Stderr}) -} - -func main() { - initLogger() - - ctx, stop := signal.NotifyContext( - context.Background(), - os.Interrupt, - os.Kill, - ) - defer stop() - - s := scheduler.NewSchedulerCycle(ctx, Interval, MaxWorkers) - - // pending test - for i := 0; i < 20; i++ { - go func(i int) { - time.Sleep(time.Duration(i) * time.Second) - s.Delay(func(ctx context.Context) error { - time.Sleep(4 * time.Second) //nolint:mnd // test purpose - if rand.IntN(10)%2 == 0 { //nolint:gosec,mnd // test prupose - return job.ErrJobNotCompletedYet - } - return nil - }) - }(i) - } - - // abort test - j := s.Delay(func(ctx context.Context) error { - time.Sleep(4 * time.Second) //nolint:mnd // test purpose - - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - return job.ErrJobNotCompletedYet - }) - go func() { - time.Sleep(2 * time.Second) //nolint:mnd // test purpose - s.Abort(j) - }() - - // abort test 2 - j2 := s.Delay(func(ctx context.Context) error { - time.Sleep(time.Second) - - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - return job.ErrJobNotCompletedYet - }) - go func() { - time.Sleep(10 * time.Second) //nolint:mnd // test purpose - s.Abort(j2) - }() - - // error test - s.Delay(func(ctx context.Context) error { - time.Sleep(5 * time.Second) //nolint:mnd // test purpose - return errors.New("err") - }) - - // success test - go func() { - time.Sleep(10 * time.Second) //nolint:mnd // test purpose - s.Delay(func(ctx context.Context) error { - time.Sleep(5 * time.Second) //nolint:mnd // test purpose - return nil - }) - }() - - go func() { - for { - time.Sleep(2 * time.Second) //nolint:mnd // test purpose - if s.TasksDone() { - s.Stop() - return - } - } - }() - - <-s.Done() - - ts := s.GetTasksDetails() - for _, t := range ts { - c, err := json.Marshal(&t) - if err != nil { - log.Err(err).Str("task", t.ID.String()).Msg("unable to parse task details into JSON") - continue - } - fmt.Println(string(c)) - } -} diff --git a/scheduler.go b/scheduler.go new file mode 100644 index 0000000..51c6171 --- /dev/null +++ b/scheduler.go @@ -0,0 +1,213 @@ +package scheduler + +import ( + "context" + "sync" + "time" + + "github.com/google/uuid" + + "github.com/rs/zerolog/log" +) + +const ( + ChanLength = 500 + DefaultExecInterval = 30 * time.Second +) + +type IScheduler interface { + Delay(fnJob FnJob, opts ...TaskOption) uuid.UUID +} + +// SchedulerCycle is a simple scheduler handling jobs and executes them at regular interval. +// If a task is not in desired state, the task is re-scheduled. +type SchedulerCycle struct { + wg sync.WaitGroup + + ctx context.Context + fnCancel context.CancelFunc + + tasks tasks + + chTasks chan *task + chDone chan struct{} +} + +func NewSchedulerCycle(ctx context.Context, workers uint32) *SchedulerCycle { + ctxChild, fnCancel := context.WithCancel(ctx) + + c := SchedulerCycle{ + wg: sync.WaitGroup{}, + ctx: ctxChild, + fnCancel: fnCancel, + tasks: newTasks(), + chTasks: make(chan *task, ChanLength), + } + + done := make(chan struct{}) + go func() { + <-c.ctx.Done() + defer c.fnCancel() + c.wg.Wait() + c.stop() + done <- struct{}{} + }() + c.chDone = done + + c.run(workers) + + return &c +} + +// delay sets the task timer when the task should be scheduled. +func (c *SchedulerCycle) delay(t *task) { + interval := DefaultExecInterval + if t.execInterval != nil { + interval = *t.execInterval + } + + t.setTimer( + time.AfterFunc(interval, func() { + select { + case c.chTasks <- t: + default: + log.Warn().Str("id", t.GetID().String()).Msg("queue is full, can't accept new task, delayed it") + c.delay(t) + } + }), + ) +} + +// exec runs the task now or if all the workers are in use, delayed it. +func (c *SchedulerCycle) exec(t *task) { + select { + case c.chTasks <- t: + default: + log.Warn().Str("id", t.GetID().String()).Msg("queue is full, can't accept new task, delayed it") + c.delay(t) + } +} + +func (c *SchedulerCycle) getTask(id uuid.UUID) *task { + return c.tasks.get(id) +} + +// run launches a number of worker to execute tasks. +func (c *SchedulerCycle) run(n uint32) { + for i := 0; i < int(n); i++ { + c.wg.Add(1) + go func() { + defer c.wg.Done() + for { + select { + case t := <-c.chTasks: + c.execute(t, c.delay) + case <-c.ctx.Done(): + log.Error().Msg("context done, worker is stopping...") + return + } + } + }() + } +} + +// execute executes the task. +// +// It does not handle task error, it's up to the task to implement its own callbacks. +// In case of pending state, a callback is executed for actions. For the others states, +// the task is deleted from the scheduler. +func (c *SchedulerCycle) execute(t *task, fnFallBack func(*task)) { + t.Run(c.ctx) + + switch t.GetState() { + case Pending: + fnFallBack(t) + case Success, Failed, Abort, Unknown: + c.tasks.delete(t) + case Running: + c.tasks.delete(t) + log.Debug().Str("id", t.GetID().String()).Msg("weird state (running) after job execution...") + } +} + +// stop aborts all tasks and waits until tasks are stopped. +// If the process can't be stopped within 10s, too bad... +func (c *SchedulerCycle) stop() { + c.tasks.abort() + + if c.TasksDone() { + log.Info().Msg("all tasks has been stopped gracefully") + return + } + + ctxTimeout := 10 * time.Second + ctx, fnCancel := context.WithTimeout(c.ctx, ctxTimeout) + defer fnCancel() + + for { + select { + case <-ctx.Done(): + log.Error().Msg("stop context done, tasks has been stopped gracefully") + return + default: + } + + if c.TasksDone() { + log.Info().Msg("all tasks has been stopped gracefully") + return + } + + time.Sleep(time.Second) + } +} + +func (c *SchedulerCycle) Done() <-chan struct{} { + return c.chDone +} + +func (c *SchedulerCycle) Len() int { + return c.tasks.len() +} + +// TasksDone checks whether all the tasks has been completed. +func (c *SchedulerCycle) TasksDone() bool { + return c.tasks.completed() +} + +func (c *SchedulerCycle) GetTasksDetails() []TaskDetails { + return c.tasks.getAllDetails() +} + +// GetTaskDetails returns the task details by id. +func (c *SchedulerCycle) GetTaskDetails(id uuid.UUID) TaskDetails { + return c.tasks.getDetails(id) +} + +// Delay builds a task and adds it to the scheduler engine. +func (c *SchedulerCycle) Delay(fnJob FnJob, opts ...TaskOption) uuid.UUID { + select { + case <-c.Done(): + log.Error().Msg("context done unable to add new job") + default: + } + + t := NewTask(fnJob, opts...) + + c.tasks.add(t) + c.exec(t) + + log.Info().Str("id", t.GetID().String()).Msg("task added successfully") + return t.GetID() +} + +// Abort aborts the task given by its id if it exists. +func (c *SchedulerCycle) Abort(id uuid.UUID) bool { + if t := c.getTask(id); t != nil { + t.Abort() + + log.Info().Str("id", t.GetID().String()).Msg("abort task done") + return true + } + + return false +} diff --git a/scheduler_test.go b/scheduler_test.go new file mode 100644 index 0000000..21ab991 --- /dev/null +++ b/scheduler_test.go @@ -0,0 +1,93 @@ +package scheduler + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestScheduler(t *testing.T) { + ctx := context.Background() + + var buf int + + s := NewSchedulerCycle(ctx, 1) + taskID := s.Delay(func(ctx context.Context) (any, error) { + time.Sleep(50 * time.Millisecond) + buf += 1 + return nil, nil + }, WithExecInterval(2*time.Millisecond)) + + assert.NotEmpty(t, taskID) + assert.False(t, s.TasksDone()) + + time.Sleep(2 * time.Millisecond) + + details := s.GetTaskDetails(taskID) + assert.Equal(t, "running", details.State) + assert.LessOrEqual(t, details.ElapsedTime, 50*time.Millisecond) + + time.Sleep(50 * time.Millisecond) + + assert.True(t, s.TasksDone()) +} + +func TestSchedulerLoad(t *testing.T) { + ctx := context.Background() + + s := NewSchedulerCycle(ctx, 1) + + for i := 0; i < 500; i++ { + s.Delay(func(ctx context.Context) (any, error) { + time.Sleep(1 * time.Millisecond) + return nil, nil + }, WithExecInterval(1*time.Millisecond)) + } + + assert.Eventually(t, func() bool { + return s.TasksDone() + }, time.Second, 250*time.Millisecond) +} + +func TestSchedulerExecInterval(t *testing.T) { + ctx := context.Background() + + s := NewSchedulerCycle(ctx, 1) + + s.Delay( + func(ctx context.Context) (any, error) { + return nil, ErrJobNotCompletedYet + }, + WithMaxDuration(50*time.Millisecond), + WithExecInterval(2*time.Millisecond), + ) + + time.Sleep(100 * time.Millisecond) + + assert.True(t, s.TasksDone()) +} + +func TestSchedulerContextDone(t *testing.T) { + ctx, fnCancel := context.WithCancel(context.Background()) + + s := NewSchedulerCycle(ctx, 1) + + for i := 0; i < 250; i++ { + s.Delay( + func(ctx context.Context) (any, error) { + return nil, ErrJobNotCompletedYet + }, + WithMaxDuration(100*time.Millisecond), + WithExecInterval(2*time.Millisecond), + ) + } + + go func() { + time.Sleep(50 * time.Millisecond) + fnCancel() + }() + + <-s.Done() +} diff --git a/task.go b/task.go new file mode 100644 index 0000000..d6058b0 --- /dev/null +++ b/task.go @@ -0,0 +1,451 @@ +package scheduler + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/google/uuid" + + "github.com/rs/zerolog/log" +) + +type State int + +const ( + Pending State = iota + Running + Success + Failed + Abort + Unknown +) + +const UnknownState = "unknown" + +func (s State) String() string { + return [...]string{"pending", "running", "success", "failed", "abort", "unknown"}[s] +} + +var ( + ErrJobAborted = errors.New("job has been aborted") + ErrJobNotCompletedYet = errors.New("job is not right state, retrying") + ErrTimeExceeded = errors.New("time exceeded") + ErrExecutionEngine = errors.New("execution engine") +) + +type FnJob func(ctx context.Context) (any, error) + +type FnResult func(ctx context.Context, res any) +type FnError func(ctx context.Context, err error) + +type TaskDetails struct { + State string `json:"state"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt *time.Time `json:"updatedAt,omitempty"` + MaxDuration *time.Duration `json:"maxDuration,omitempty"` + Attempts uint32 `json:"attempts"` + Err string `json:"error"` + ElapsedTime time.Duration `json:"elapsedTime"` + AdditionalInfo map[string]string `json:"additionalInfos"` +} + +func (td *TaskDetails) log() { + args := []any{} + args = append(args, + "state", td.State, + "createdAt", td.CreatedAt.Format("2006-01-02T15:04:05Z"), + "updatedAt", td.UpdatedAt.Format("2006-01-02T15:04:05Z"), + "elapsedTime", td.ElapsedTime.String(), + "attempts", td.Attempts, + ) + + if td.AdditionalInfo != nil { + for k, v := range td.AdditionalInfo { + args = append(args, k, v) + } + } + + if td.Err != "" { + args = append(args, "error", td.Err) + log.Error().Any("args", args).Msg("task failed") + return + } + + log.Info().Any("args", args).Msg("task execution done") +} + +type task struct { + l sync.RWMutex + + id uuid.UUID + createdAt time.Time + updatedAt *time.Time + state State + + fnJob FnJob + fnSuccess FnResult + fnError FnError + + attempts uint32 + timer *time.Timer + maxDuration *time.Duration + execTimeout *time.Duration + execInterval *time.Duration + + res any + err error + + additionalInfos map[string]string + + chAbort chan struct{} +} + +type TaskOption func(t *task) + +func WithMaxDuration(d time.Duration) TaskOption { + return func(t *task) { + t.maxDuration = &d + } +} + +func WithFnSuccess(f FnResult) TaskOption { + return func(t *task) { + t.fnSuccess = f + } +} + +func WithFnError(f FnError) TaskOption { + return func(t *task) { + t.fnError = f + } +} + +func WithExecTimeout(d time.Duration) TaskOption { + return func(t *task) { + t.execTimeout = &d + } +} + +func WithExecInterval(d time.Duration) TaskOption { + return func(t *task) { + t.execInterval = &d + } +} + +func WithAdditionalInfos(k, v string) TaskOption { + return func(t *task) { + if k == "" || v == "" { + return + } + + if t.additionalInfos == nil { + t.additionalInfos = map[string]string{} + } + + t.additionalInfos[k] = v + } +} + +// NewTask builds a task. +// Here the options details that can be set to the task: +// - WithMaxDuration(time.Duration): the task will stop executing if the duration is exceeded (raise ErrTimeExceeded) +// - WithFnSuccess(FnResult): call a function after a success execution +// - WithFnError(FnError): call a function if an error occurred +// - WithExecTimeout(time.Duration): sets a timeout for the task execution +// - WithAdditionalInfos(k, v string): key-value additional informations (does not inferred with the task execution, log purpose) +// +// Scheduler options: +// - WithExecInterval(time.Duration): specify the execution interval +func NewTask(f FnJob, opts ...TaskOption) *task { + t := task{ + id: uuid.New(), + createdAt: time.Now().UTC(), + state: Pending, + fnJob: f, + chAbort: make(chan struct{}, 1), + } + + for _, o := range opts { + o(&t) + } + + return &t +} + +func (t *task) setState(s State) { + t.l.Lock() + defer t.l.Unlock() + + now := time.Now().UTC() + t.updatedAt = &now + + t.state = s +} + +func (t *task) setSuccess(ctx context.Context, res any) { + t.l.Lock() + defer t.l.Unlock() + + now := time.Now().UTC() + t.updatedAt = &now + + t.state = Success + t.res = res + + if t.fnSuccess != nil { + t.fnSuccess(ctx, res) + } +} + +func (t *task) setFail(ctx context.Context, err error) { + t.l.Lock() + defer t.l.Unlock() + + now := time.Now().UTC() + t.updatedAt = &now + + if t.state != Abort { + t.state = Failed + } + + t.err = err + + if t.fnError != nil { + t.fnError(ctx, err) + } +} + +func (t *task) setTimer(tm *time.Timer) { + t.l.Lock() + defer t.l.Unlock() + + t.timer = tm +} + +func (t *task) stopTimer() { + t.l.Lock() + defer t.l.Unlock() + + if t.timer != nil { + t.timer.Stop() + } +} + +func (t *task) incr() { + t.l.Lock() + defer t.l.Unlock() + + t.attempts += 1 +} + +func (t *task) log() { + td := t.IntoDetails() + td.log() +} + +func (t *task) GetID() uuid.UUID { + t.l.RLock() + defer t.l.RUnlock() + + return t.id +} + +func (t *task) GetAttempts() uint32 { + t.l.RLock() + defer t.l.RUnlock() + + return t.attempts +} + +func (t *task) GetState() State { + t.l.RLock() + defer t.l.RUnlock() + + return t.state +} + +// TimeExceeded checks if the task does not reached its max duration execution (maxDuration). +func (t *task) TimeExceeded() bool { + t.l.RLock() + defer t.l.RUnlock() + + if md := t.maxDuration; md != nil { + return time.Since(t.createdAt) >= *md + } + + return false +} + +func (t *task) Abort() { + t.stopTimer() + t.chAbort <- struct{}{} +} + +func (t *task) Run(ctx context.Context) { + if t.TimeExceeded() { + t.setFail(ctx, ErrTimeExceeded) + return + } + + if s := t.GetState(); s != Pending { + log.Error().Msg("unable to launch a task that not in pending state") + t.setFail(ctx, ErrExecutionEngine) + return + } + + var ctxExec context.Context + var fnCancel context.CancelFunc + if t.execTimeout != nil { + ctxExec, fnCancel = context.WithTimeout(ctx, *t.execTimeout) + } else { + ctxExec, fnCancel = context.WithCancel(ctx) + } + defer fnCancel() + + t.incr() + t.setState(Running) + + log.Info().Str("id", t.GetID().String()).Msg("task is running...") + + go func() { + for range t.chAbort { + t.setState(Abort) + fnCancel() + return + } + }() + + defer t.log() + res, err := t.fnJob(ctxExec) + if err != nil { + if errors.Is(err, ErrJobNotCompletedYet) { + t.setState(Pending) + return + } + t.setFail(ctx, err) + return + } + + t.setSuccess(ctx, res) +} + +func (t *task) IntoDetails() TaskDetails { + t.l.RLock() + defer t.l.RUnlock() + + td := TaskDetails{ + CreatedAt: t.createdAt, + UpdatedAt: t.updatedAt, + State: t.state.String(), + Attempts: t.attempts, + MaxDuration: t.maxDuration, + } + + if t.state == Pending || t.state == Running { + td.ElapsedTime = time.Since(t.createdAt) + } else { + td.ElapsedTime = t.updatedAt.Sub(t.createdAt) + } + + if err := t.err; err != nil { + td.Err = err.Error() + } + + if t.additionalInfos != nil { + td.AdditionalInfo = t.additionalInfos + } + + return td +} + +type tasks struct { + l sync.RWMutex + s map[uuid.UUID]*task +} + +func newTasks() tasks { + return tasks{ + s: make(map[uuid.UUID]*task), + } +} + +func (ts *tasks) add(t *task) { + ts.l.Lock() + defer ts.l.Unlock() + + ts.s[t.GetID()] = t +} + +func (ts *tasks) delete(t *task) { + ts.l.Lock() + defer ts.l.Unlock() + + delete(ts.s, t.GetID()) +} + +func (ts *tasks) get(id uuid.UUID) *task { + ts.l.RLock() + defer ts.l.RUnlock() + + t, ok := ts.s[id] + if !ok { + return nil + } + + return t +} + +func (ts *tasks) len() int { + ts.l.RLock() + defer ts.l.RUnlock() + + return len(ts.s) +} + +func (ts *tasks) completed() bool { + ts.l.RLock() + defer ts.l.RUnlock() + + for _, t := range ts.s { + if t.GetState() == Pending || t.GetState() == Running { + return false + } + } + + return true +} + +func (ts *tasks) getAllDetails() []TaskDetails { + ts.l.RLock() + defer ts.l.RUnlock() + + details := []TaskDetails{} + for _, t := range ts.s { + details = append(details, t.IntoDetails()) + } + + return details +} + +func (ts *tasks) getDetails(id uuid.UUID) TaskDetails { + ts.l.RLock() + defer ts.l.RUnlock() + + t, ok := ts.s[id] + if !ok { + return TaskDetails{State: UnknownState} + } + + return t.IntoDetails() +} + +func (ts *tasks) abort() { + ts.l.RLock() + defer ts.l.RUnlock() + + for _, t := range ts.s { + t.Abort() + } +} diff --git a/task_test.go b/task_test.go new file mode 100644 index 0000000..63aeb66 --- /dev/null +++ b/task_test.go @@ -0,0 +1,243 @@ +package scheduler + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestTask(t *testing.T) { + ctx := context.Background() + var i int + + task := NewTask( + func(ctx context.Context) (any, error) { + i += 1 + return nil, nil + }, + ) + + task.Run(ctx) + + assert.Equal(t, nil, task.res) + assert.NotEmpty(t, task.updatedAt) + assert.Equal(t, 1, i) + assert.Equal(t, 1, int(task.GetAttempts())) +} + +func TestAbortTask(t *testing.T) { + ctx := context.Background() + + task := NewTask( + func(ctx context.Context) (any, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + ) + + timer := time.NewTicker(200 * time.Millisecond) + go func() { + <-timer.C + task.Abort() + }() + task.Run(ctx) + + assert.Equal(t, Abort, task.GetState()) +} + +func TestTaskContextDone(t *testing.T) { + ctx, fnCancel := context.WithCancel(context.Background()) + + task := NewTask( + func(ctx context.Context) (any, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + ) + + timer := time.NewTicker(200 * time.Millisecond) + go func() { + <-timer.C + fnCancel() + }() + task.Run(ctx) + + assert.Equal(t, Failed, task.GetState()) +} + +func TestTaskFnSuccess(t *testing.T) { + ctx := context.Background() + var result int + + task := NewTask( + func(ctx context.Context) (any, error) { + return 3, nil + }, + WithFnSuccess(func(ctx context.Context, res any) { + t, ok := res.(int) + if !ok { + return + } + + result = t * 2 + }), + ) + + task.Run(ctx) + + assert.Equal(t, Success, task.GetState()) + assert.Equal(t, 6, result) +} + +func TestTaskFnError(t *testing.T) { + ctx := context.Background() + var result error + + task := NewTask( + func(ctx context.Context) (any, error) { + return 3, errors.New("error occurred...") + }, + WithFnError(func(ctx context.Context, err error) { + result = err + }), + ) + + task.Run(ctx) + + assert.Equal(t, Failed, task.GetState()) + assert.Equal(t, "error occurred...", result.Error()) +} + +func TestTaskWithErrJobNotCompletedYet(t *testing.T) { + ctx := context.Background() + var attempts int + + task := NewTask( + func(ctx context.Context) (any, error) { + if attempts < 2 { + attempts += 1 + return nil, ErrJobNotCompletedYet + } + return "ok", nil + }, + ) + + for i := 0; i < 2; i++ { + task.Run(ctx) + assert.Equal(t, Pending, task.GetState()) + } + + task.Run(ctx) + assert.Equal(t, Success, task.GetState()) + assert.Equal(t, 3, int(task.GetAttempts())) +} + +func TestTaskTimeExceeded(t *testing.T) { + ctx := context.Background() + + task := NewTask( + func(ctx context.Context) (any, error) { + return "ko", nil + }, + WithMaxDuration(5*time.Millisecond), + ) + + time.Sleep(10 * time.Millisecond) + + task.Run(ctx) + assert.Equal(t, Failed, task.GetState()) + assert.Equal(t, 0, int(task.GetAttempts())) +} + +func TestTaskExecTimeout(t *testing.T) { + ctx := context.Background() + + task := NewTask( + func(ctx context.Context) (any, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + WithExecTimeout(5*time.Millisecond), + ) + + time.Sleep(10 * time.Millisecond) + + task.Run(ctx) + assert.Equal(t, Failed, task.GetState()) + assert.Equal(t, 1, int(task.GetAttempts())) +} + +func TestTaskDetails(t *testing.T) { + ctx := context.Background() + + task := NewTask( + func(ctx context.Context) (any, error) { + return "coucou", nil + }, + ) + + details := task.IntoDetails() + + assert.Equal(t, 0, int(details.Attempts)) + assert.Equal(t, "pending", details.State) + assert.False(t, details.CreatedAt.IsZero()) + assert.Empty(t, details.UpdatedAt) + assert.Nil(t, details.MaxDuration) + assert.Empty(t, details.Err) + assert.NotEmpty(t, details.ElapsedTime) + + task.Run(ctx) + + details = task.IntoDetails() + + assert.Equal(t, 1, int(details.Attempts)) + assert.Equal(t, "success", details.State) + assert.False(t, details.CreatedAt.IsZero()) + assert.NotEmpty(t, details.UpdatedAt) + assert.Nil(t, details.MaxDuration) + assert.Empty(t, details.Err) + assert.NotEmpty(t, details.ElapsedTime) +} + +func TestTaskAdditionalInfos(t *testing.T) { + t.Run("with key value", func(t *testing.T) { + elementID := uuid.NewString() + task := NewTask( + func(ctx context.Context) (any, error) { + return "yo", nil + }, + WithAdditionalInfos("transportId", elementID), + WithAdditionalInfos("element", "transport"), + ) + + assert.Equal(t, elementID, task.additionalInfos["transportId"]) + assert.Equal(t, "transport", task.additionalInfos["element"]) + }) + + t.Run("with empty key", func(t *testing.T) { + elementID := uuid.NewString() + task := NewTask( + func(ctx context.Context) (any, error) { + return "hello", nil + }, + WithAdditionalInfos("", elementID), + WithAdditionalInfos("element", "transport"), + ) + + assert.Equal(t, "transport", task.additionalInfos["element"]) + }) + + t.Run("with empty infos", func(t *testing.T) { + task := NewTask( + func(ctx context.Context) (any, error) { + return "hey", nil + }, + ) + + assert.Nil(t, task.additionalInfos) + }) +}