turn into lib

This commit is contained in:
rmanach 2024-11-16 11:00:46 +01:00
parent 3f1afb63d4
commit 2a0ecf2d4e
11 changed files with 1024 additions and 654 deletions

View File

@ -1,6 +1,3 @@
run: lint
go run main.go
lint:
golangci-lint run ./...

View File

@ -1,18 +1,32 @@
# cycle-scheduler
cycle-scheduler is a simple scheduler handling tasks and executes them at regular interval. If a task is not in desired state, the task is re-scheduled with a backoff.
cycle-scheduler is a simple scheduler lib, handling tasks and executes them at regular interval. If a task is not in desired state, the task is re-scheduled with a backoff.
## Run
You can run sample tests from `main.go` to see the scheduler in action:
```bash
make run
```
**NOTE**: this should be not used for long-running tasks, it's more suitable for shorts tasks like polling etc...
You can adjust the clock interval and the number of workers as needed in `main.go` constants section:
## Examples
* Init a new scheduler with 4 workers
```go
const (
MaxWorkers = 5
Interval = 2000 * time.Millisecond
import (
"context"
)
ctx := context.Background()
s := NewSchedulerCycle(ctx, 4)
// add a task
taskID := s.Delay(
func(ctx context.Context) (any, error) {
// execution
return any, nil
},
WithExecInterval(2*time.Millisecond)
)
<-ctx.Done()
<-s.Done()
```
**NOTE**: for `Delay` optionals arguments, check the `NewTask` method documentation for more details.

View File

@ -1,168 +0,0 @@
package job
import (
"context"
"errors"
"sync"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
type State int
const (
Pending State = iota
Running
Success
Failed
Abort
Unknown
)
const UnknownState = "unknown"
var JobExecTimeout = 10 * time.Second
func (s State) String() string {
switch s {
case Pending:
return "pending"
case Running:
return "running"
case Success:
return "success"
case Failed:
return "failed"
case Abort:
return "abort"
case Unknown:
return UnknownState
default:
return UnknownState
}
}
var (
ErrJobAborted = errors.New("job has been aborted")
ErrJobNotCompletedYet = errors.New("job is not right state, retrying")
)
type FnJob func(ctx context.Context) error
type JobDetails struct {
ID uuid.UUID `json:"id"`
State string `json:"state"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt *time.Time `json:"updatedAt,omitempty"`
Err string `json:"error"`
}
type Job struct {
l sync.RWMutex
id uuid.UUID
createdAt time.Time
updatedAt *time.Time
state State
task FnJob
err error
chAbort chan struct{}
}
func NewJob(task FnJob) Job {
return Job{
id: uuid.New(),
createdAt: time.Now().UTC(),
state: Pending,
task: task,
chAbort: make(chan struct{}, 1),
}
}
func (j *Job) IntoDetails() JobDetails {
j.l.RLock()
defer j.l.RUnlock()
jd := JobDetails{
ID: j.id,
CreatedAt: j.createdAt,
State: j.state.String(),
}
if err := j.err; err != nil {
jd.Err = err.Error()
}
if ut := j.updatedAt; ut != nil {
jd.UpdatedAt = ut
}
return jd
}
func (j *Job) GetID() uuid.UUID {
return j.id
}
func (j *Job) GetState() State {
j.l.RLock()
defer j.l.RUnlock()
return j.state
}
func (j *Job) setState(s State) {
j.l.Lock()
defer j.l.Unlock()
now := time.Now().UTC()
j.updatedAt = &now
j.state = s
}
func (j *Job) setFail(err error) {
j.l.Lock()
defer j.l.Unlock()
now := time.Now().UTC()
j.updatedAt = &now
if j.state != Abort {
j.state = Failed
}
j.err = err
}
func (j *Job) Abort() {
j.setState(Abort)
j.chAbort <- struct{}{}
}
func (j *Job) Run(ctx context.Context) {
ctxExec, fnCancel := context.WithTimeout(ctx, JobExecTimeout)
defer fnCancel()
j.setState(Running)
log.Info().Str("job", j.GetID().String()).Msg("job running...")
go func() {
for range j.chAbort {
j.setState(Abort)
fnCancel()
}
}()
if err := j.task(ctxExec); err != nil {
if errors.Is(err, ErrJobNotCompletedYet) {
j.setState(Pending)
return
}
j.setFail(err)
return
}
j.setState(Success)
}

View File

@ -1,163 +0,0 @@
package scheduler
import (
"context"
"cycle-scheduler/internal/job"
"math"
"sync"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
const ExponentialFactor = 1.8
// SchedulerCycle is a simple scheduler handling jobs and executes them at regular interval.
// If a task is not in desired state, the task is re-scheduled with a backoff.
type SchedulerCycle struct {
wg sync.WaitGroup
ctx context.Context
fnCancel context.CancelFunc
interval time.Duration
tasks tasks
chTasks chan *task
}
func NewSchedulerCycle(ctx context.Context, interval time.Duration, workers uint32) *SchedulerCycle {
ctxChild, fnCancel := context.WithCancel(ctx)
c := SchedulerCycle{
wg: sync.WaitGroup{},
ctx: ctxChild,
fnCancel: fnCancel,
interval: interval,
tasks: newTasks(),
chTasks: make(chan *task),
}
c.run(workers)
return &c
}
func (c *SchedulerCycle) backoff(t *task) {
backoff := c.interval + time.Duration(math.Pow(ExponentialFactor, float64(t.attempts.Load())))
t.timer.set(
time.AfterFunc(backoff, func() {
select {
case c.chTasks <- t:
default:
log.Error().Str("task id", t.GetID().String()).Msg("unable to execute task to the worker, delayed it")
c.backoff(t)
}
}),
)
}
// exec runs the task now or if all the workers are in use, delayed it.
func (c *SchedulerCycle) exec(t *task) {
select {
case c.chTasks <- t:
default:
log.Error().Str("task id", t.GetID().String()).Msg("unable to execute the task to a worker now, delayed it")
c.backoff(t)
}
}
func (c *SchedulerCycle) getTask(id uuid.UUID) *task {
return c.tasks.get(id)
}
// run launches a number of worker to execute tasks.
// If a task returns `ErrJobNotCompletedYet`, it re-schedules with a backoff.
func (c *SchedulerCycle) run(n uint32) {
for i := 0; i < int(n); i++ {
c.wg.Add(1)
go func() {
defer c.wg.Done()
for {
select {
case t := <-c.chTasks:
c.execute(t, c.backoff)
case <-c.ctx.Done():
log.Error().Msg("context done, worker is stopping...")
return
}
}
}()
}
}
func (c *SchedulerCycle) execute(t *task, fnFallBack func(*task)) {
t.run(c.ctx)
if t.GetState() == job.Pending {
fnFallBack(t)
}
}
func (c *SchedulerCycle) Stop() {
c.fnCancel()
}
func (c *SchedulerCycle) Done() <-chan struct{} {
done := make(chan struct{})
go func() {
<-c.ctx.Done()
c.wg.Wait()
done <- struct{}{}
}()
return done
}
func (c *SchedulerCycle) Len() int {
return c.tasks.len()
}
// TasksDone checks whether all the tasks has been completed.
func (c *SchedulerCycle) TasksDone() bool {
return c.tasks.completed()
}
func (c *SchedulerCycle) GetTasksDetails() []TaskDetails {
return c.tasks.getAllDetails()
}
// GetTaskDetails returns the task details by id.
func (c *SchedulerCycle) GetTaskDetails(id uuid.UUID) TaskDetails {
return c.tasks.getDetails(id)
}
// Delay builds a task and add it to the scheduler engine.
func (c *SchedulerCycle) Delay(fnJob job.FnJob) uuid.UUID {
select {
case <-c.Done():
log.Error().Msg("context done unable to add new job")
default:
}
t := newTask(fnJob)
c.tasks.add(t)
c.exec(t)
log.Info().Str("task", t.GetID().String()).Msg("task added successfully")
return t.GetID()
}
// Abort aborts the task given by its id if it exists.
func (c *SchedulerCycle) Abort(id uuid.UUID) bool {
if t := c.getTask(id); t != nil {
t.abort()
log.Info().Str("task id", t.GetID().String()).Msg("abort task done")
return true
}
return false
}

View File

@ -1,33 +0,0 @@
package scheduler
import (
"context"
"cycle-scheduler/internal/job"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestSlot(t *testing.T) {
ctx, fnCancel := context.WithCancel(context.Background())
defer fnCancel()
s := NewSchedulerCycle(ctx, 1*time.Millisecond, 5)
s.Delay(func(ctx context.Context) error {
return nil
})
s.Delay(func(ctx context.Context) error {
return job.ErrJobNotCompletedYet
})
j3 := s.Delay(func(ctx context.Context) error {
return errors.New("errors")
})
time.Sleep(2 * time.Millisecond)
assert.Equal(t, 3, s.Len())
assert.Equal(t, job.Failed.String(), s.GetTaskDetails(j3).State)
}

View File

@ -1,152 +0,0 @@
package scheduler
import (
"context"
"cycle-scheduler/internal/job"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
)
// atomicTimer wraps a `time.Timer`.
type atomicTimer struct {
atomic.Pointer[time.Timer]
}
func (at *atomicTimer) stop() {
timer := at.Load()
if timer != nil {
timer.Stop()
}
}
// set replaces the current timer.
// It also ensures that the current timer is stopped.
func (at *atomicTimer) set(t *time.Timer) {
timer := at.Load()
if timer != nil {
timer.Stop()
at.Swap(t)
return
}
at.Swap(t)
}
type TaskDetails struct {
job.JobDetails
Attempts int `json:"attempts"`
}
type task struct {
*job.Job
attempts atomic.Uint32
timer atomicTimer
}
func newTask(f job.FnJob) *task {
j := job.NewJob(f)
t := task{
Job: &j,
timer: atomicTimer{},
}
return &t
}
func (t *task) abort() {
t.timer.stop()
t.Job.Abort()
}
func (t *task) run(ctx context.Context) {
t.attempts.Add(1)
t.Job.Run(ctx)
}
func (t *task) getDetails() TaskDetails {
return TaskDetails{
JobDetails: t.IntoDetails(),
Attempts: int(t.attempts.Load()),
}
}
type tasks struct {
l sync.RWMutex
s map[uuid.UUID]*task
}
func newTasks() tasks {
return tasks{
s: make(map[uuid.UUID]*task),
}
}
func (ts *tasks) add(t *task) {
ts.l.Lock()
defer ts.l.Unlock()
ts.s[t.GetID()] = t
}
func (ts *tasks) get(id uuid.UUID) *task {
ts.l.RLock()
defer ts.l.RUnlock()
j, ok := ts.s[id]
if !ok {
return nil
}
return j
}
func (ts *tasks) len() int {
ts.l.RLock()
defer ts.l.RUnlock()
return len(ts.s)
}
func (ts *tasks) completed() bool {
ts.l.RLock()
defer ts.l.RUnlock()
for _, t := range ts.s {
if t.GetState() == job.Pending || t.GetState() == job.Running {
return false
}
}
return true
}
func (ts *tasks) getAllDetails() []TaskDetails {
ts.l.RLock()
defer ts.l.RUnlock()
details := []TaskDetails{}
for _, t := range ts.s {
details = append(details, t.getDetails())
}
return details
}
func (ts *tasks) getDetails(id uuid.UUID) TaskDetails {
ts.l.RLock()
defer ts.l.RUnlock()
t, ok := ts.s[id]
if !ok {
return TaskDetails{
JobDetails: job.JobDetails{
State: job.UnknownState,
},
}
}
return t.getDetails()
}

125
main.go
View File

@ -1,125 +0,0 @@
package main
import (
"context"
"cycle-scheduler/internal/job"
"cycle-scheduler/internal/scheduler"
"encoding/json"
"errors"
"fmt"
"math/rand/v2"
"os"
"os/signal"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
const (
MaxWorkers = 5
Interval = 2000 * time.Millisecond
)
func initLogger() {
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
log.Logger = log.With().Caller().Logger().Output(zerolog.ConsoleWriter{Out: os.Stderr})
}
func main() {
initLogger()
ctx, stop := signal.NotifyContext(
context.Background(),
os.Interrupt,
os.Kill,
)
defer stop()
s := scheduler.NewSchedulerCycle(ctx, Interval, MaxWorkers)
// pending test
for i := 0; i < 20; i++ {
go func(i int) {
time.Sleep(time.Duration(i) * time.Second)
s.Delay(func(ctx context.Context) error {
time.Sleep(4 * time.Second) //nolint:mnd // test purpose
if rand.IntN(10)%2 == 0 { //nolint:gosec,mnd // test prupose
return job.ErrJobNotCompletedYet
}
return nil
})
}(i)
}
// abort test
j := s.Delay(func(ctx context.Context) error {
time.Sleep(4 * time.Second) //nolint:mnd // test purpose
select {
case <-ctx.Done():
return ctx.Err()
default:
}
return job.ErrJobNotCompletedYet
})
go func() {
time.Sleep(2 * time.Second) //nolint:mnd // test purpose
s.Abort(j)
}()
// abort test 2
j2 := s.Delay(func(ctx context.Context) error {
time.Sleep(time.Second)
select {
case <-ctx.Done():
return ctx.Err()
default:
}
return job.ErrJobNotCompletedYet
})
go func() {
time.Sleep(10 * time.Second) //nolint:mnd // test purpose
s.Abort(j2)
}()
// error test
s.Delay(func(ctx context.Context) error {
time.Sleep(5 * time.Second) //nolint:mnd // test purpose
return errors.New("err")
})
// success test
go func() {
time.Sleep(10 * time.Second) //nolint:mnd // test purpose
s.Delay(func(ctx context.Context) error {
time.Sleep(5 * time.Second) //nolint:mnd // test purpose
return nil
})
}()
go func() {
for {
time.Sleep(2 * time.Second) //nolint:mnd // test purpose
if s.TasksDone() {
s.Stop()
return
}
}
}()
<-s.Done()
ts := s.GetTasksDetails()
for _, t := range ts {
c, err := json.Marshal(&t)
if err != nil {
log.Err(err).Str("task", t.ID.String()).Msg("unable to parse task details into JSON")
continue
}
fmt.Println(string(c))
}
}

213
scheduler.go Normal file
View File

@ -0,0 +1,213 @@
package scheduler
import (
"context"
"sync"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
const (
ChanLength = 500
DefaultExecInterval = 30 * time.Second
)
type IScheduler interface {
Delay(fnJob FnJob, opts ...TaskOption) uuid.UUID
}
// SchedulerCycle is a simple scheduler handling jobs and executes them at regular interval.
// If a task is not in desired state, the task is re-scheduled with a backoff.
type SchedulerCycle struct {
wg sync.WaitGroup
ctx context.Context
fnCancel context.CancelFunc
tasks tasks
chTasks chan *task
chDone chan struct{}
}
func NewSchedulerCycle(ctx context.Context, workers uint32) *SchedulerCycle {
ctxChild, fnCancel := context.WithCancel(ctx)
c := SchedulerCycle{
wg: sync.WaitGroup{},
ctx: ctxChild,
fnCancel: fnCancel,
tasks: newTasks(),
chTasks: make(chan *task, ChanLength),
}
done := make(chan struct{})
go func() {
<-c.ctx.Done()
defer c.fnCancel()
c.wg.Wait()
c.stop()
done <- struct{}{}
}()
c.chDone = done
c.run(workers)
return &c
}
// delay sets the task timer when the task should be scheduled.
func (c *SchedulerCycle) delay(t *task) {
interval := DefaultExecInterval
if t.execInterval != nil {
interval = *t.execInterval
}
t.setTimer(
time.AfterFunc(interval, func() {
select {
case c.chTasks <- t:
default:
log.Warn().Str("id", t.GetID().String()).Msg("queue is full, can't accept new task, delayed it")
c.delay(t)
}
}),
)
}
// exec runs the task now or if all the workers are in use, delayed it.
func (c *SchedulerCycle) exec(t *task) {
select {
case c.chTasks <- t:
default:
log.Warn().Str("id", t.GetID().String()).Msg("queue is full, can't accept new task, delayed it")
c.delay(t)
}
}
func (c *SchedulerCycle) getTask(id uuid.UUID) *task {
return c.tasks.get(id)
}
// run launches a number of worker to execute tasks.
func (c *SchedulerCycle) run(n uint32) {
for i := 0; i < int(n); i++ {
c.wg.Add(1)
go func() {
defer c.wg.Done()
for {
select {
case t := <-c.chTasks:
c.execute(t, c.delay)
case <-c.ctx.Done():
log.Error().Msg("context done, worker is stopping...")
return
}
}
}()
}
}
// execute executes the task.
//
// It does not handle task error, it's up to the task to implement its own callbacks.
// In case of pending state, a callback is executed for actions. For the others states,
// the task is deleted from the scheduler.
func (c *SchedulerCycle) execute(t *task, fnFallBack func(*task)) {
t.Run(c.ctx)
switch t.GetState() {
case Pending:
fnFallBack(t)
case Success, Failed, Abort, Unknown:
c.tasks.delete(t)
case Running:
c.tasks.delete(t)
log.Debug().Str("id", t.GetID().String()).Msg("weird state (running) after job execution...")
}
}
// stop aborts all tasks and waits until tasks are stopped.
// If the process can't be stopped within 10s, too bad...
func (c *SchedulerCycle) stop() {
c.tasks.abort()
if c.TasksDone() {
log.Info().Msg("all tasks has been stopped gracefully")
return
}
ctxTimeout := 10 * time.Second
ctx, fnCancel := context.WithTimeout(c.ctx, ctxTimeout)
defer fnCancel()
for {
select {
case <-ctx.Done():
log.Error().Msg("stop context done, tasks has been stopped gracefully")
return
default:
}
if c.TasksDone() {
log.Info().Msg("all tasks has been stopped gracefully")
return
}
time.Sleep(time.Second)
}
}
func (c *SchedulerCycle) Done() <-chan struct{} {
return c.chDone
}
func (c *SchedulerCycle) Len() int {
return c.tasks.len()
}
// TasksDone checks whether all the tasks has been completed.
func (c *SchedulerCycle) TasksDone() bool {
return c.tasks.completed()
}
func (c *SchedulerCycle) GetTasksDetails() []TaskDetails {
return c.tasks.getAllDetails()
}
// GetTaskDetails returns the task details by id.
func (c *SchedulerCycle) GetTaskDetails(id uuid.UUID) TaskDetails {
return c.tasks.getDetails(id)
}
// Delay builds a task and adds it to the scheduler engine.
func (c *SchedulerCycle) Delay(fnJob FnJob, opts ...TaskOption) uuid.UUID {
select {
case <-c.Done():
log.Error().Msg("context done unable to add new job")
default:
}
t := NewTask(fnJob, opts...)
c.tasks.add(t)
c.exec(t)
log.Info().Str("id", t.GetID().String()).Msg("task added successfully")
return t.GetID()
}
// Abort aborts the task given by its id if it exists.
func (c *SchedulerCycle) Abort(id uuid.UUID) bool {
if t := c.getTask(id); t != nil {
t.Abort()
log.Info().Str("id", t.GetID().String()).Msg("abort task done")
return true
}
return false
}

93
scheduler_test.go Normal file
View File

@ -0,0 +1,93 @@
package scheduler
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestScheduler(t *testing.T) {
ctx := context.Background()
var buf int
s := NewSchedulerCycle(ctx, 1)
taskID := s.Delay(func(ctx context.Context) (any, error) {
time.Sleep(50 * time.Millisecond)
buf += 1
return nil, nil
}, WithExecInterval(2*time.Millisecond))
assert.NotEmpty(t, taskID)
assert.False(t, s.TasksDone())
time.Sleep(2 * time.Millisecond)
details := s.GetTaskDetails(taskID)
assert.Equal(t, "running", details.State)
assert.LessOrEqual(t, details.ElapsedTime, 50*time.Millisecond)
time.Sleep(50 * time.Millisecond)
assert.True(t, s.TasksDone())
}
func TestSchedulerLoad(t *testing.T) {
ctx := context.Background()
s := NewSchedulerCycle(ctx, 1)
for i := 0; i < 500; i++ {
s.Delay(func(ctx context.Context) (any, error) {
time.Sleep(1 * time.Millisecond)
return nil, nil
}, WithExecInterval(1*time.Millisecond))
}
assert.Eventually(t, func() bool {
return s.TasksDone()
}, time.Second, 250*time.Millisecond)
}
func TestSchedulerExecInterval(t *testing.T) {
ctx := context.Background()
s := NewSchedulerCycle(ctx, 1)
s.Delay(
func(ctx context.Context) (any, error) {
return nil, ErrJobNotCompletedYet
},
WithMaxDuration(50*time.Millisecond),
WithExecInterval(2*time.Millisecond),
)
time.Sleep(100 * time.Millisecond)
assert.True(t, s.TasksDone())
}
func TestSchedulerContextDone(t *testing.T) {
ctx, fnCancel := context.WithCancel(context.Background())
s := NewSchedulerCycle(ctx, 1)
for i := 0; i < 250; i++ {
s.Delay(
func(ctx context.Context) (any, error) {
return nil, ErrJobNotCompletedYet
},
WithMaxDuration(100*time.Millisecond),
WithExecInterval(2*time.Millisecond),
)
}
go func() {
time.Sleep(50 * time.Millisecond)
fnCancel()
}()
<-s.Done()
}

451
task.go Normal file
View File

@ -0,0 +1,451 @@
package scheduler
import (
"context"
"errors"
"sync"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
type State int
const (
Pending State = iota
Running
Success
Failed
Abort
Unknown
)
const UnknownState = "unknown"
func (s State) String() string {
return [...]string{"pending", "running", "success", "failed", "abort", "unknown"}[s]
}
var (
ErrJobAborted = errors.New("job has been aborted")
ErrJobNotCompletedYet = errors.New("job is not right state, retrying")
ErrTimeExceeded = errors.New("time exceeded")
ErrExecutionEngine = errors.New("execution engine")
)
type FnJob func(ctx context.Context) (any, error)
type FnResult func(ctx context.Context, res any)
type FnError func(ctx context.Context, err error)
type TaskDetails struct {
State string `json:"state"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt *time.Time `json:"updatedAt,omitempty"`
MaxDuration *time.Duration `json:"maxDuration,omitempty"`
Attempts uint32 `json:"attempts"`
Err string `json:"error"`
ElapsedTime time.Duration `json:"elapsedTime"`
AdditionalInfo map[string]string `json:"additionalInfos"`
}
func (td *TaskDetails) log() {
args := []any{}
args = append(args,
"state", td.State,
"createdAt", td.CreatedAt.Format("2006-01-02T15:04:05Z"),
"updatedAt", td.UpdatedAt.Format("2006-01-02T15:04:05Z"),
"elapsedTime", td.ElapsedTime.String(),
"attempts", td.Attempts,
)
if td.AdditionalInfo != nil {
for k, v := range td.AdditionalInfo {
args = append(args, k, v)
}
}
if td.Err != "" {
args = append(args, "error", td.Err)
log.Error().Any("args", args).Msg("task failed")
return
}
log.Info().Any("args", args).Msg("task execution done")
}
type task struct {
l sync.RWMutex
id uuid.UUID
createdAt time.Time
updatedAt *time.Time
state State
fnJob FnJob
fnSuccess FnResult
fnError FnError
attempts uint32
timer *time.Timer
maxDuration *time.Duration
execTimeout *time.Duration
execInterval *time.Duration
res any
err error
additionalInfos map[string]string
chAbort chan struct{}
}
type TaskOption func(t *task)
func WithMaxDuration(d time.Duration) TaskOption {
return func(t *task) {
t.maxDuration = &d
}
}
func WithFnSuccess(f FnResult) TaskOption {
return func(t *task) {
t.fnSuccess = f
}
}
func WithFnError(f FnError) TaskOption {
return func(t *task) {
t.fnError = f
}
}
func WithExecTimeout(d time.Duration) TaskOption {
return func(t *task) {
t.execTimeout = &d
}
}
func WithExecInterval(d time.Duration) TaskOption {
return func(t *task) {
t.execInterval = &d
}
}
func WithAdditionalInfos(k, v string) TaskOption {
return func(t *task) {
if k == "" || v == "" {
return
}
if t.additionalInfos == nil {
t.additionalInfos = map[string]string{}
}
t.additionalInfos[k] = v
}
}
// NewTask builds a task.
// Here the options details that can be set to the task:
// - WithMaxDuration(time.Duration): the task will stop executing if the duration is exceeded (raise ErrTimeExceeded)
// - WithFnSuccess(FnResult): call a function after a success execution
// - WithFnError(FnError): call a function if an error occurred
// - WithExecTimeout(time.Duration): sets a timeout for the task execution
// - WithAdditionalInfos(k, v string): key-value additional informations (does not inferred with the task execution, log purpose)
//
// Scheduler options:
// - WithExecInterval(time.Duration): specify the execution interval
func NewTask(f FnJob, opts ...TaskOption) *task {
t := task{
id: uuid.New(),
createdAt: time.Now().UTC(),
state: Pending,
fnJob: f,
chAbort: make(chan struct{}, 1),
}
for _, o := range opts {
o(&t)
}
return &t
}
func (t *task) setState(s State) {
t.l.Lock()
defer t.l.Unlock()
now := time.Now().UTC()
t.updatedAt = &now
t.state = s
}
func (t *task) setSuccess(ctx context.Context, res any) {
t.l.Lock()
defer t.l.Unlock()
now := time.Now().UTC()
t.updatedAt = &now
t.state = Success
t.res = res
if t.fnSuccess != nil {
t.fnSuccess(ctx, res)
}
}
func (t *task) setFail(ctx context.Context, err error) {
t.l.Lock()
defer t.l.Unlock()
now := time.Now().UTC()
t.updatedAt = &now
if t.state != Abort {
t.state = Failed
}
t.err = err
if t.fnError != nil {
t.fnError(ctx, err)
}
}
func (t *task) setTimer(tm *time.Timer) {
t.l.Lock()
defer t.l.Unlock()
t.timer = tm
}
func (t *task) stopTimer() {
t.l.Lock()
defer t.l.Unlock()
if t.timer != nil {
t.timer.Stop()
}
}
func (t *task) incr() {
t.l.Lock()
defer t.l.Unlock()
t.attempts += 1
}
func (t *task) log() {
td := t.IntoDetails()
td.log()
}
func (t *task) GetID() uuid.UUID {
t.l.RLock()
defer t.l.RUnlock()
return t.id
}
func (t *task) GetAttempts() uint32 {
t.l.RLock()
defer t.l.RUnlock()
return t.attempts
}
func (t *task) GetState() State {
t.l.RLock()
defer t.l.RUnlock()
return t.state
}
// TimeExceeded checks if the task does not reached its max duration execution (maxDuration).
func (t *task) TimeExceeded() bool {
t.l.RLock()
defer t.l.RUnlock()
if md := t.maxDuration; md != nil {
return time.Since(t.createdAt) >= *md
}
return false
}
func (t *task) Abort() {
t.stopTimer()
t.chAbort <- struct{}{}
}
func (t *task) Run(ctx context.Context) {
if t.TimeExceeded() {
t.setFail(ctx, ErrTimeExceeded)
return
}
if s := t.GetState(); s != Pending {
log.Error().Msg("unable to launch a task that not in pending state")
t.setFail(ctx, ErrExecutionEngine)
return
}
var ctxExec context.Context
var fnCancel context.CancelFunc
if t.execTimeout != nil {
ctxExec, fnCancel = context.WithTimeout(ctx, *t.execTimeout)
} else {
ctxExec, fnCancel = context.WithCancel(ctx)
}
defer fnCancel()
t.incr()
t.setState(Running)
log.Info().Str("id", t.GetID().String()).Msg("task is running...")
go func() {
for range t.chAbort {
t.setState(Abort)
fnCancel()
return
}
}()
defer t.log()
res, err := t.fnJob(ctxExec)
if err != nil {
if errors.Is(err, ErrJobNotCompletedYet) {
t.setState(Pending)
return
}
t.setFail(ctx, err)
return
}
t.setSuccess(ctx, res)
}
func (t *task) IntoDetails() TaskDetails {
t.l.RLock()
defer t.l.RUnlock()
td := TaskDetails{
CreatedAt: t.createdAt,
UpdatedAt: t.updatedAt,
State: t.state.String(),
Attempts: t.attempts,
MaxDuration: t.maxDuration,
}
if t.state == Pending || t.state == Running {
td.ElapsedTime = time.Since(t.createdAt)
} else {
td.ElapsedTime = t.updatedAt.Sub(t.createdAt)
}
if err := t.err; err != nil {
td.Err = err.Error()
}
if t.additionalInfos != nil {
td.AdditionalInfo = t.additionalInfos
}
return td
}
type tasks struct {
l sync.RWMutex
s map[uuid.UUID]*task
}
func newTasks() tasks {
return tasks{
s: make(map[uuid.UUID]*task),
}
}
func (ts *tasks) add(t *task) {
ts.l.Lock()
defer ts.l.Unlock()
ts.s[t.GetID()] = t
}
func (ts *tasks) delete(t *task) {
ts.l.Lock()
defer ts.l.Unlock()
delete(ts.s, t.GetID())
}
func (ts *tasks) get(id uuid.UUID) *task {
ts.l.RLock()
defer ts.l.RUnlock()
t, ok := ts.s[id]
if !ok {
return nil
}
return t
}
func (ts *tasks) len() int {
ts.l.RLock()
defer ts.l.RUnlock()
return len(ts.s)
}
func (ts *tasks) completed() bool {
ts.l.RLock()
defer ts.l.RUnlock()
for _, t := range ts.s {
if t.GetState() == Pending || t.GetState() == Running {
return false
}
}
return true
}
func (ts *tasks) getAllDetails() []TaskDetails {
ts.l.RLock()
defer ts.l.RUnlock()
details := []TaskDetails{}
for _, t := range ts.s {
details = append(details, t.IntoDetails())
}
return details
}
func (ts *tasks) getDetails(id uuid.UUID) TaskDetails {
ts.l.RLock()
defer ts.l.RUnlock()
t, ok := ts.s[id]
if !ok {
return TaskDetails{State: UnknownState}
}
return t.IntoDetails()
}
func (ts *tasks) abort() {
ts.l.RLock()
defer ts.l.RUnlock()
for _, t := range ts.s {
t.Abort()
}
}

243
task_test.go Normal file
View File

@ -0,0 +1,243 @@
package scheduler
import (
"context"
"errors"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func TestTask(t *testing.T) {
ctx := context.Background()
var i int
task := NewTask(
func(ctx context.Context) (any, error) {
i += 1
return nil, nil
},
)
task.Run(ctx)
assert.Equal(t, nil, task.res)
assert.NotEmpty(t, task.updatedAt)
assert.Equal(t, 1, i)
assert.Equal(t, 1, int(task.GetAttempts()))
}
func TestAbortTask(t *testing.T) {
ctx := context.Background()
task := NewTask(
func(ctx context.Context) (any, error) {
<-ctx.Done()
return nil, ctx.Err()
},
)
timer := time.NewTicker(200 * time.Millisecond)
go func() {
<-timer.C
task.Abort()
}()
task.Run(ctx)
assert.Equal(t, Abort, task.GetState())
}
func TestTaskContextDone(t *testing.T) {
ctx, fnCancel := context.WithCancel(context.Background())
task := NewTask(
func(ctx context.Context) (any, error) {
<-ctx.Done()
return nil, ctx.Err()
},
)
timer := time.NewTicker(200 * time.Millisecond)
go func() {
<-timer.C
fnCancel()
}()
task.Run(ctx)
assert.Equal(t, Failed, task.GetState())
}
func TestTaskFnSuccess(t *testing.T) {
ctx := context.Background()
var result int
task := NewTask(
func(ctx context.Context) (any, error) {
return 3, nil
},
WithFnSuccess(func(ctx context.Context, res any) {
t, ok := res.(int)
if !ok {
return
}
result = t * 2
}),
)
task.Run(ctx)
assert.Equal(t, Success, task.GetState())
assert.Equal(t, 6, result)
}
func TestTaskFnError(t *testing.T) {
ctx := context.Background()
var result error
task := NewTask(
func(ctx context.Context) (any, error) {
return 3, errors.New("error occurred...")
},
WithFnError(func(ctx context.Context, err error) {
result = err
}),
)
task.Run(ctx)
assert.Equal(t, Failed, task.GetState())
assert.Equal(t, "error occurred...", result.Error())
}
func TestTaskWithErrJobNotCompletedYet(t *testing.T) {
ctx := context.Background()
var attempts int
task := NewTask(
func(ctx context.Context) (any, error) {
if attempts < 2 {
attempts += 1
return nil, ErrJobNotCompletedYet
}
return "ok", nil
},
)
for i := 0; i < 2; i++ {
task.Run(ctx)
assert.Equal(t, Pending, task.GetState())
}
task.Run(ctx)
assert.Equal(t, Success, task.GetState())
assert.Equal(t, 3, int(task.GetAttempts()))
}
func TestTaskTimeExceeded(t *testing.T) {
ctx := context.Background()
task := NewTask(
func(ctx context.Context) (any, error) {
return "ko", nil
},
WithMaxDuration(5*time.Millisecond),
)
time.Sleep(10 * time.Millisecond)
task.Run(ctx)
assert.Equal(t, Failed, task.GetState())
assert.Equal(t, 0, int(task.GetAttempts()))
}
func TestTaskExecTimeout(t *testing.T) {
ctx := context.Background()
task := NewTask(
func(ctx context.Context) (any, error) {
<-ctx.Done()
return nil, ctx.Err()
},
WithExecTimeout(5*time.Millisecond),
)
time.Sleep(10 * time.Millisecond)
task.Run(ctx)
assert.Equal(t, Failed, task.GetState())
assert.Equal(t, 1, int(task.GetAttempts()))
}
func TestTaskDetails(t *testing.T) {
ctx := context.Background()
task := NewTask(
func(ctx context.Context) (any, error) {
return "coucou", nil
},
)
details := task.IntoDetails()
assert.Equal(t, 0, int(details.Attempts))
assert.Equal(t, "pending", details.State)
assert.False(t, details.CreatedAt.IsZero())
assert.Empty(t, details.UpdatedAt)
assert.Nil(t, details.MaxDuration)
assert.Empty(t, details.Err)
assert.NotEmpty(t, details.ElapsedTime)
task.Run(ctx)
details = task.IntoDetails()
assert.Equal(t, 1, int(details.Attempts))
assert.Equal(t, "success", details.State)
assert.False(t, details.CreatedAt.IsZero())
assert.NotEmpty(t, details.UpdatedAt)
assert.Nil(t, details.MaxDuration)
assert.Empty(t, details.Err)
assert.NotEmpty(t, details.ElapsedTime)
}
func TestTaskAdditionalInfos(t *testing.T) {
t.Run("with key value", func(t *testing.T) {
elementID := uuid.NewString()
task := NewTask(
func(ctx context.Context) (any, error) {
return "yo", nil
},
WithAdditionalInfos("transportId", elementID),
WithAdditionalInfos("element", "transport"),
)
assert.Equal(t, elementID, task.additionalInfos["transportId"])
assert.Equal(t, "transport", task.additionalInfos["element"])
})
t.Run("with empty key", func(t *testing.T) {
elementID := uuid.NewString()
task := NewTask(
func(ctx context.Context) (any, error) {
return "hello", nil
},
WithAdditionalInfos("", elementID),
WithAdditionalInfos("element", "transport"),
)
assert.Equal(t, "transport", task.additionalInfos["element"])
})
t.Run("with empty infos", func(t *testing.T) {
task := NewTask(
func(ctx context.Context) (any, error) {
return "hey", nil
},
)
assert.Nil(t, task.additionalInfos)
})
}