177 lines
3.1 KiB
Go
177 lines
3.1 KiB
Go
package scheduler
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// ErrSchedulerMaxCapacityReached is returned by Submit when the scheduler's
// task store already holds the configured maximum number of tasks.
var ErrSchedulerMaxCapacityReached = errors.New("unable to add new task, max capacity reached")

// ErrSchedulerContextDone is returned by Submit once the scheduler's context
// has been cancelled (via Stop or the parent context).
var ErrSchedulerContextDone = errors.New("context done, scheduler stopped")
|
|
|
|
// TaskStatus describes where a Task is in its lifecycle.
type TaskStatus string

// Task lifecycle states. All four are typed as TaskStatus (previously Success
// and Failed were untyped string constants, so e.g. `s := Success` produced a
// plain string rather than a TaskStatus).
const (
	// Pending is the status NewTask assigns to a freshly created task.
	Pending TaskStatus = "pending"
	// Running is set by a worker just before it invokes the task's Job.
	Running TaskStatus = "running"
	// Success is set after the task's Job returns a nil error.
	Success TaskStatus = "success"
	// Failed is set after the task's Job returns a non-nil error.
	Failed TaskStatus = "failed"
)
|
|
|
|
// FnJob is the unit of work executed by a Task; a non-nil error marks the
// task as failed.
type FnJob func() (err error)
|
|
|
|
type taskStore struct {
|
|
l sync.RWMutex
|
|
tasks map[string]*Task
|
|
}
|
|
|
|
func newTaskStore() taskStore {
|
|
return taskStore{
|
|
tasks: map[string]*Task{},
|
|
}
|
|
}
|
|
|
|
func (ts *taskStore) push(task *Task) {
|
|
ts.l.Lock()
|
|
defer ts.l.Unlock()
|
|
|
|
ts.tasks[task.Name] = task
|
|
}
|
|
|
|
func (ts *taskStore) setStatus(task *Task, status TaskStatus) {
|
|
ts.l.Lock()
|
|
defer ts.l.Unlock()
|
|
|
|
if _, ok := ts.tasks[task.Name]; !ok {
|
|
log.Warn().Str("name", task.Name).Msg("unable to update task status, does not exist")
|
|
return
|
|
}
|
|
|
|
ts.tasks[task.Name].Status = status
|
|
}
|
|
|
|
func (ts *taskStore) len() int {
|
|
ts.l.RLock()
|
|
defer ts.l.RUnlock()
|
|
|
|
return len(ts.tasks)
|
|
}
|
|
|
|
type Task struct {
|
|
Name string
|
|
Job FnJob
|
|
Status TaskStatus
|
|
Next []*Task
|
|
}
|
|
|
|
func NewTask(name string, job FnJob, next ...*Task) *Task {
|
|
return &Task{
|
|
Name: name,
|
|
Job: job,
|
|
Next: next,
|
|
Status: Pending,
|
|
}
|
|
}
|
|
|
|
type Scheduler struct {
|
|
ctx context.Context
|
|
fnCancel context.CancelFunc
|
|
wg sync.WaitGroup
|
|
|
|
capacity atomic.Uint32
|
|
workers int
|
|
|
|
chTasks chan *Task
|
|
tasks taskStore
|
|
}
|
|
|
|
func NewScheduler(ctx context.Context, capacity, workers int) *Scheduler {
|
|
ctxChild, fnCancel := context.WithCancel(ctx)
|
|
s := Scheduler{
|
|
ctx: ctxChild,
|
|
fnCancel: fnCancel,
|
|
capacity: atomic.Uint32{},
|
|
workers: workers,
|
|
chTasks: make(chan *Task, capacity),
|
|
tasks: newTaskStore(),
|
|
wg: sync.WaitGroup{},
|
|
}
|
|
s.capacity.Add(uint32(capacity))
|
|
s.run()
|
|
|
|
return &s
|
|
}
|
|
|
|
func (s *Scheduler) run() {
|
|
for i := 0; i < s.workers; i++ {
|
|
s.wg.Add(1)
|
|
go func() {
|
|
defer s.wg.Done()
|
|
for {
|
|
select {
|
|
case t := <-s.chTasks:
|
|
s.tasks.setStatus(t, Running)
|
|
|
|
if err := t.Job(); err != nil {
|
|
log.Err(err).Str("task", t.Name).Msg("error executing task")
|
|
s.tasks.setStatus(t, Failed)
|
|
continue
|
|
}
|
|
|
|
s.tasks.setStatus(t, Success)
|
|
|
|
for _, nt := range t.Next {
|
|
s.Submit(nt)
|
|
}
|
|
case <-s.ctx.Done():
|
|
log.Warn().Msg("context done, stopping worker...")
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
|
|
func (s *Scheduler) Stop() {
|
|
s.fnCancel()
|
|
}
|
|
|
|
func (s *Scheduler) Submit(task *Task) error {
|
|
select {
|
|
case <-s.ctx.Done():
|
|
log.Error().Msg("unable to submit new task, scheduler is stopping...")
|
|
return ErrSchedulerContextDone
|
|
default:
|
|
}
|
|
|
|
cap := s.capacity.Load()
|
|
if s.tasks.len() >= int(cap) {
|
|
return ErrSchedulerMaxCapacityReached
|
|
}
|
|
|
|
s.tasks.push(task)
|
|
s.chTasks <- task
|
|
return nil
|
|
}
|
|
|
|
func (s *Scheduler) Done() <-chan struct{} {
|
|
chDone := make(chan struct{})
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-s.ctx.Done():
|
|
log.Info().Msg("waiting for scheduler task completion...")
|
|
s.wg.Wait()
|
|
chDone <- struct{}{}
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
return chDone
|
|
}
|