2024-11-16 11:00:46 +01:00

452 lines
7.9 KiB
Go

package scheduler
import (
"context"
"errors"
"sync"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
// State is the lifecycle state of a task.
type State int

const (
	Pending State = iota // waiting to be executed (again, after ErrJobNotCompletedYet)
	Running              // the job is currently executing
	Success              // the job completed without error
	Failed               // the job (or its scheduling) failed
	Abort                // the task was aborted
	Unknown              // unknown state
)

// UnknownState is the textual state reported for unknown tasks and for State
// values outside the defined range.
const UnknownState = "unknown"

// String returns the lowercase name of the state.
//
// Values outside the defined range yield UnknownState instead of panicking
// with an out-of-range index.
func (s State) String() string {
	switch s {
	case Pending:
		return "pending"
	case Running:
		return "running"
	case Success:
		return "success"
	case Failed:
		return "failed"
	case Abort:
		return "abort"
	default:
		// Covers Unknown as well as any undefined value.
		return UnknownState
	}
}
// Sentinel errors recorded on tasks; compare them with errors.Is.
var (
	// ErrJobAborted reports that a job has been aborted.
	ErrJobAborted = errors.New("job has been aborted")

	// ErrJobNotCompletedYet is returned by a job that has not finished its work
	// yet; Run then puts the task back in the Pending state so it is retried.
	ErrJobNotCompletedYet = errors.New("job is not right state, retrying")

	// ErrTimeExceeded is recorded when a task is run after its maximum duration
	// (see WithMaxDuration) has elapsed.
	ErrTimeExceeded = errors.New("time exceeded")

	// ErrExecutionEngine is recorded when Run is called on a task that is not
	// in the Pending state.
	ErrExecutionEngine = errors.New("execution engine")
)
// FnJob is the work executed by a task. It receives the execution context,
// which is cancelled on abort or when the execution timeout expires, and
// returns a result or an error. Returning ErrJobNotCompletedYet asks for the
// task to be retried.
type FnJob func(ctx context.Context) (any, error)

// FnResult is called with the job result after a successful execution.
type FnResult func(ctx context.Context, res any)

// FnError is called with the recorded error when a task fails.
type FnError func(ctx context.Context, err error)
// TaskDetails is a JSON-serializable snapshot of a task.
type TaskDetails struct {
	// State is the textual task state (see State.String).
	State string `json:"state"`

	// CreatedAt is when the task was created.
	CreatedAt time.Time `json:"createdAt"`

	// UpdatedAt is the time of the last state change, if any.
	UpdatedAt *time.Time `json:"updatedAt,omitempty"`

	// MaxDuration is the optional lifetime bound of the task.
	MaxDuration *time.Duration `json:"maxDuration,omitempty"`

	// Attempts counts the executions started so far.
	Attempts uint32 `json:"attempts"`

	// Err is the text of the recorded error, empty when there is none.
	Err string `json:"error"`

	// ElapsedTime is how long the task has existed (or existed until its
	// last state change once it is no longer pending or running).
	ElapsedTime time.Duration `json:"elapsedTime"`

	// AdditionalInfo holds log-only key/value information.
	AdditionalInfo map[string]string `json:"additionalInfos"`
}
// log emits the details as a single zerolog event: error-level "task failed"
// when Err is set, info-level "task execution done" otherwise. The fields are
// passed under the "args" key as an alternating key/value list.
func (td *TaskDetails) log() {
	// Times are converted to UTC so that the literal "Z" suffix is accurate.
	const layout = "2006-01-02T15:04:05Z"

	args := []any{
		"state", td.State,
		"createdAt", td.CreatedAt.UTC().Format(layout),
	}
	// UpdatedAt is optional (getDetails, for instance, returns details without
	// it for unknown ids); dereferencing it unconditionally would panic.
	if td.UpdatedAt != nil {
		args = append(args, "updatedAt", td.UpdatedAt.UTC().Format(layout))
	}
	args = append(args,
		"elapsedTime", td.ElapsedTime.String(),
		"attempts", td.Attempts,
	)
	// Ranging over a nil map is a no-op, so no explicit nil check is needed.
	for k, v := range td.AdditionalInfo {
		args = append(args, k, v)
	}

	if td.Err != "" {
		args = append(args, "error", td.Err)
		log.Error().Any("args", args).Msg("task failed")
		return
	}
	log.Info().Any("args", args).Msg("task execution done")
}
// task is a unit of work handled by the scheduler. Its mutable state is
// accessed through methods that take l.
type task struct {
	l sync.RWMutex

	// Identity and lifecycle.
	id        uuid.UUID  // unique identifier, assigned by NewTask
	createdAt time.Time  // creation time (UTC)
	updatedAt *time.Time // time of the last state change; nil until the first one
	state     State      // current lifecycle state

	// Work and callbacks.
	fnJob     FnJob    // the job to execute
	fnSuccess FnResult // optional callback on success
	fnError   FnError  // optional callback on failure

	// Execution bookkeeping and limits.
	attempts     uint32         // number of executions started
	timer        *time.Timer    // optional timer, stopped on Abort
	maxDuration  *time.Duration // optional lifetime bound (see TimeExceeded)
	execTimeout  *time.Duration // optional per-execution timeout
	execInterval *time.Duration // optional execution interval (scheduler option)

	// Outcome and metadata.
	res             any               // result recorded on success
	err             error             // error recorded on failure
	additionalInfos map[string]string // log-only key/value information
	chAbort         chan struct{}     // abort requests (see Abort and Run)
}
// TaskOption customizes a task built by NewTask.
type TaskOption func(t *task)

// WithMaxDuration bounds the task lifetime: once d has elapsed since the task
// was created, Run fails it with ErrTimeExceeded.
func WithMaxDuration(d time.Duration) TaskOption {
	return func(tk *task) { tk.maxDuration = &d }
}

// WithFnSuccess registers f to be called with the job result after a
// successful execution.
func WithFnSuccess(f FnResult) TaskOption {
	return func(tk *task) { tk.fnSuccess = f }
}

// WithFnError registers f to be called with the recorded error whenever the
// task fails.
func WithFnError(f FnError) TaskOption {
	return func(tk *task) { tk.fnError = f }
}

// WithExecTimeout bounds each execution: the context passed to the job is
// cancelled after d.
func WithExecTimeout(d time.Duration) TaskOption {
	return func(tk *task) { tk.execTimeout = &d }
}

// WithExecInterval sets the execution interval. The task itself does not read
// it; it is documented as a scheduler option.
func WithExecInterval(d time.Duration) TaskOption {
	return func(tk *task) { tk.execInterval = &d }
}
// WithAdditionalInfos attaches the key/value pair k=v to the task for logging
// purposes. Pairs with an empty key or an empty value are ignored.
func WithAdditionalInfos(k, v string) TaskOption {
	return func(tk *task) {
		if k != "" && v != "" {
			if tk.additionalInfos == nil {
				tk.additionalInfos = make(map[string]string)
			}
			tk.additionalInfos[k] = v
		}
	}
}
// NewTask builds a Pending task that runs f, configured by opts (applied in
// order).
//
// Task options:
//   - WithMaxDuration(time.Duration): once the duration (measured from the
//     task creation) is exceeded, Run fails the task with ErrTimeExceeded.
//   - WithFnSuccess(FnResult): a function called after a successful execution.
//   - WithFnError(FnError): a function called when the task fails.
//   - WithExecTimeout(time.Duration): a timeout applied to each execution.
//   - WithAdditionalInfos(k, v string): key/value information used for logging
//     only; it does not interfere with the task execution.
//
// Scheduler options:
//   - WithExecInterval(time.Duration): the execution interval.
func NewTask(f FnJob, opts ...TaskOption) *task {
	// chAbort is buffered so one abort request can be posted without a
	// receiver being ready.
	tk := &task{
		id:        uuid.New(),
		createdAt: time.Now().UTC(),
		state:     Pending,
		fnJob:     f,
		chAbort:   make(chan struct{}, 1),
	}
	for _, apply := range opts {
		apply(tk)
	}
	return tk
}
// setState moves the task to state s and stamps updatedAt with the current
// UTC time.
func (t *task) setState(s State) {
	t.l.Lock()
	defer t.l.Unlock()
	stamp := time.Now().UTC()
	t.state = s
	t.updatedAt = &stamp
}
// setSuccess records a successful execution: the state becomes Success, the
// result res is stored and updatedAt is stamped. The success callback, if
// any, is then invoked with ctx and res.
//
// The callback runs after the lock has been released. sync.RWMutex is not
// reentrant, so a callback that reads the task (directly, or through the task
// set's getDetails/getAllDetails/completed) would otherwise deadlock.
func (t *task) setSuccess(ctx context.Context, res any) {
	t.l.Lock()
	now := time.Now().UTC()
	t.updatedAt = &now
	t.state = Success
	t.res = res
	onSuccess := t.fnSuccess
	t.l.Unlock()

	if onSuccess != nil {
		onSuccess(ctx, res)
	}
}
// setFail records a failure: err is stored, updatedAt is stamped and the
// state becomes Failed, unless the task was aborted, in which case it keeps
// the Abort state. The error callback, if any, is then invoked with ctx and
// err.
//
// The callback runs after the lock has been released. sync.RWMutex is not
// reentrant, so a callback that reads the task (directly, or through the task
// set's getDetails/getAllDetails/completed) would otherwise deadlock.
func (t *task) setFail(ctx context.Context, err error) {
	t.l.Lock()
	now := time.Now().UTC()
	t.updatedAt = &now
	if t.state != Abort {
		t.state = Failed
	}
	t.err = err
	onError := t.fnError
	t.l.Unlock()

	if onError != nil {
		onError(ctx, err)
	}
}
// setTimer stores tm as the task's timer. A previously stored timer is
// replaced without being stopped.
func (t *task) setTimer(tm *time.Timer) {
	t.l.Lock()
	t.timer = tm
	t.l.Unlock()
}
// stopTimer stops the task's timer when one has been set; it is a no-op
// otherwise. The timer reference itself is kept.
func (t *task) stopTimer() {
	t.l.Lock()
	defer t.l.Unlock()
	if tm := t.timer; tm != nil {
		tm.Stop()
	}
}
// incr increments the attempt counter by one.
func (t *task) incr() {
	t.l.Lock()
	defer t.l.Unlock()
	t.attempts++
}
// log logs a snapshot of the task (see TaskDetails.log).
func (t *task) log() {
	snapshot := t.IntoDetails()
	snapshot.log()
}
// GetID returns the task identifier.
func (t *task) GetID() uuid.UUID {
	t.l.RLock()
	id := t.id
	t.l.RUnlock()
	return id
}
// GetAttempts returns the number of executions started so far.
func (t *task) GetAttempts() uint32 {
	t.l.RLock()
	n := t.attempts
	t.l.RUnlock()
	return n
}
// GetState returns the current lifecycle state of the task.
func (t *task) GetState() State {
	t.l.RLock()
	s := t.state
	t.l.RUnlock()
	return s
}
// TimeExceeded reports whether the task has existed for at least its maximum
// duration (maxDuration), measured from its creation time. A task without a
// maximum duration never exceeds it.
func (t *task) TimeExceeded() bool {
	t.l.RLock()
	defer t.l.RUnlock()
	if t.maxDuration == nil {
		return false
	}
	return time.Since(t.createdAt) >= *t.maxDuration
}
// Abort requests the cancellation of the task. The task's timer, if any, is
// stopped, and an abort signal is posted on chAbort. Run's abort watcher picks
// the signal up, switches the task to Abort and cancels the execution context.
//
// The signal is posted without blocking. chAbort is created with a capacity of
// one, so when a request is already pending the call is a no-op. Without this,
// a second Abort, or an Abort while nothing consumes the channel, would block
// the caller forever.
func (t *task) Abort() {
	t.stopTimer()
	select {
	case t.chAbort <- struct{}{}:
	default:
		// An abort request is already pending; posting another adds nothing.
	}
}
// Run executes the task's job once.
//
// Run fails the task with ErrTimeExceeded when its maximum duration has
// already elapsed. It logs an error and fails the task with ErrExecutionEngine
// when the task is not Pending. Otherwise it increments the attempt counter,
// moves the task to Running and invokes the job. The job's context is derived
// from ctx and bounded by the exec timeout when one is set.
//
// The outcome is recorded as follows:
//   - nil error: Success (see setSuccess).
//   - ErrJobNotCompletedYet: back to Pending so the job can be retried, unless
//     the task was aborted meanwhile, in which case it fails with ErrJobAborted.
//   - any other error: Failed, or Abort if the task was aborted (see setFail).
//
// Once the job has been invoked, the task details are logged when Run returns.
func (t *task) Run(ctx context.Context) {
	if t.TimeExceeded() {
		t.setFail(ctx, ErrTimeExceeded)
		return
	}
	if s := t.GetState(); s != Pending {
		log.Error().Msg("unable to launch a task that not in pending state")
		t.setFail(ctx, ErrExecutionEngine)
		return
	}

	var ctxExec context.Context
	var fnCancel context.CancelFunc
	if t.execTimeout != nil {
		ctxExec, fnCancel = context.WithTimeout(ctx, *t.execTimeout)
	} else {
		ctxExec, fnCancel = context.WithCancel(ctx)
	}
	defer fnCancel()

	t.incr()
	t.setState(Running)
	log.Info().Str("id", t.GetID().String()).Msg("task is running...")

	// Watch for abort requests only while this execution is in flight. The
	// watcher exits as soon as ctxExec is done. Repeated runs (retries after
	// ErrJobNotCompletedYet) therefore do not leak goroutines. A stale watcher
	// from a previous run also cannot consume an abort meant for this one, or
	// flip an already finished task to Abort.
	watcherDone := make(chan struct{})
	go func() {
		defer close(watcherDone)
		select {
		case <-t.chAbort:
			t.setState(Abort)
			fnCancel()
		case <-ctxExec.Done():
		}
	}()

	defer t.log()

	res, err := t.fnJob(ctxExec)

	// Stop the watcher and wait for it, so this run can no longer switch the
	// state to Abort while the outcome is being recorded below.
	fnCancel()
	<-watcherDone

	switch {
	case err == nil:
		t.setSuccess(ctx, res)
	case !errors.Is(err, ErrJobNotCompletedYet):
		t.setFail(ctx, err)
	case t.GetState() == Abort:
		// An aborted task must not be put back into the retry cycle.
		t.setFail(ctx, ErrJobAborted)
	default:
		t.setState(Pending)
	}
}
// IntoDetails returns a snapshot of the task as a TaskDetails value.
//
// ElapsedTime is measured up to now while the task is Pending or Running, and
// up to its last state change otherwise. UpdatedAt, MaxDuration and
// AdditionalInfo share the task's underlying pointers and map, so callers
// should treat them as read-only.
func (t *task) IntoDetails() TaskDetails {
	t.l.RLock()
	defer t.l.RUnlock()

	td := TaskDetails{
		CreatedAt:      t.createdAt,
		UpdatedAt:      t.updatedAt,
		State:          t.state.String(),
		Attempts:       t.attempts,
		MaxDuration:    t.maxDuration,
		AdditionalInfo: t.additionalInfos,
	}

	switch {
	case t.state == Pending || t.state == Running:
		td.ElapsedTime = time.Since(t.createdAt)
	case t.updatedAt != nil:
		td.ElapsedTime = t.updatedAt.Sub(t.createdAt)
	default:
		// Defensive: a settled state without a recorded update time would
		// otherwise dereference a nil pointer; fall back to the live duration.
		td.ElapsedTime = time.Since(t.createdAt)
	}

	if t.err != nil {
		td.Err = t.err.Error()
	}
	return td
}
// tasks is a set of tasks indexed by their ID. Every method takes l, so the
// set may be used from several goroutines.
type tasks struct {
	l sync.RWMutex          // guards s
	s map[uuid.UUID]*task // tasks by ID
}
// newTasks returns an empty, ready-to-use task set.
func newTasks() tasks {
	return tasks{s: map[uuid.UUID]*task{}}
}
// add inserts t into the set, replacing any task with the same ID.
func (ts *tasks) add(t *task) {
	ts.l.Lock()
	defer ts.l.Unlock()
	id := t.GetID()
	ts.s[id] = t
}
// delete removes t (by ID) from the set; removing an absent task is a no-op.
func (ts *tasks) delete(t *task) {
	ts.l.Lock()
	defer ts.l.Unlock()
	id := t.GetID()
	delete(ts.s, id)
}
// get returns the task with the given ID, or nil when there is none.
func (ts *tasks) get(id uuid.UUID) *task {
	ts.l.RLock()
	defer ts.l.RUnlock()
	// A missing key yields the zero value of *task, i.e. nil.
	return ts.s[id]
}
// len returns the number of tasks in the set.
func (ts *tasks) len() int {
	ts.l.RLock()
	n := len(ts.s)
	ts.l.RUnlock()
	return n
}
// completed reports whether no task in the set is Pending or Running. An
// empty set is considered completed.
func (ts *tasks) completed() bool {
	ts.l.RLock()
	defer ts.l.RUnlock()
	for _, t := range ts.s {
		// Read the state once. Two separate reads could see Running and then
		// Pending (after a retry), and wrongly report an in-flight task as done.
		if s := t.GetState(); s == Pending || s == Running {
			return false
		}
	}
	return true
}
// getAllDetails returns a snapshot of every task in the set, in map
// iteration order. The result is never nil, even for an empty set.
func (ts *tasks) getAllDetails() []TaskDetails {
	ts.l.RLock()
	defer ts.l.RUnlock()
	out := make([]TaskDetails, 0, len(ts.s))
	for _, t := range ts.s {
		out = append(out, t.IntoDetails())
	}
	return out
}
// getDetails returns a snapshot of the task with the given ID. For an unknown
// ID it returns details whose State is UnknownState and all other fields zero.
func (ts *tasks) getDetails(id uuid.UUID) TaskDetails {
	ts.l.RLock()
	defer ts.l.RUnlock()
	if t, found := ts.s[id]; found {
		return t.IntoDetails()
	}
	return TaskDetails{State: UnknownState}
}
// abort sends an abort request to every task in the set.
//
// The tasks are collected under the read lock and aborted after it has been
// released. Aborting a task involves a channel send and taking the task's own
// lock. A slow or blocked Abort must not hold the set's lock, or it would
// stall every writer (add, delete) on the set.
func (ts *tasks) abort() {
	ts.l.RLock()
	all := make([]*task, 0, len(ts.s))
	for _, t := range ts.s {
		all = append(all, t)
	}
	ts.l.RUnlock()

	for _, t := range all {
		t.Abort()
	}
}