199 lines
4.2 KiB
Go
199 lines
4.2 KiB
Go
package services
|
|
|
|
import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"errors"
	"net/http"
	"os"
	"strconv"
	"sync"
	"time"

	"github.com/rs/zerolog/log"
)
|
|
|
|
// Fallback configuration used when the corresponding environment variables
// are missing or invalid.
const (
	// defaultAPISessionExpirationDuration is the session lifetime used when
	// API_SESSION_EXPIRATION_DURATION is missing or invalid.
	defaultAPISessionExpirationDuration = 5 * time.Minute

	// defaultAdminPassword is used when API_ADMIN_PASSWORD is empty.
	// NOTE(review): a well-known default password is risky in production;
	// consider refusing to start instead of falling back.
	defaultAdminPassword = "admin"

	// defaultAdminUsername is used when API_ADMIN_USERNAME is empty.
	defaultAdminUsername = "admin"
)
|
|
|
|
// Sentinel errors returned by Authentication; compare them with errors.Is.
var (
	// ErrSessionIDCollision is returned by Authenticate when a freshly
	// generated session ID is already registered.
	ErrSessionIDCollision = errors.New("sessionId collision")

	// ErrUnauthorized is returned by Authenticate when the supplied
	// credentials do not match the configured admin account.
	ErrUnauthorized = errors.New("unauthorized")
)
|
|
|
|
var adminPassword = sync.OnceValue[string](func() string {
|
|
adminPassword := os.Getenv("API_ADMIN_PASSWORD")
|
|
if adminPassword == "" {
|
|
log.Error().Msg("API_ADMIN_PASSWORD env var is empty, set to default")
|
|
return defaultAdminPassword
|
|
}
|
|
return adminPassword
|
|
})
|
|
|
|
var adminUsername = sync.OnceValue[string](func() string {
|
|
adminUsername := os.Getenv("API_ADMIN_USERNAME")
|
|
if adminUsername == "" {
|
|
log.Error().Msg("API_ADMIN_USERNAME env var is empty, set to default")
|
|
return defaultAdminUsername
|
|
}
|
|
return adminUsername
|
|
})
|
|
|
|
var sessionExpirationTime = sync.OnceValue[time.Duration](func() time.Duration {
|
|
sessionExpirationDuration, err := strconv.Atoi(os.Getenv("API_SESSION_EXPIRATION_DURATION"))
|
|
if err != nil {
|
|
log.Warn().Err(err).Dur("default", defaultAPISessionExpirationDuration).Msg("unable to load API_SESSION_EXPIRATION_DURATION, set to default")
|
|
return defaultAPISessionExpirationDuration
|
|
}
|
|
return time.Duration(sessionExpirationDuration * int(time.Second))
|
|
})
|
|
|
|
// generateSessionID returns a fresh session identifier. It consists of 32 bytes
// from crypto/rand, hex-encoded into a 64-character lowercase string. A non-nil
// error from the random source is returned unchanged, with an empty ID.
func generateSessionID() (string, error) {
	var raw [32]byte //nolint

	_, err := rand.Read(raw[:])
	if err != nil {
		return "", err
	}

	return hex.EncodeToString(raw[:]), nil
}
|
|
|
|
// Session is a logged-in session as created by Authentication.Authenticate.
type Session struct {
	// l guards the fields below for GenerateCookie.
	l sync.RWMutex

	// sessionID is the random identifier sent to the client; Authenticate
	// also uses it as the key in Authentication.sessions.
	sessionID string
	// expirationTime is the instant after which the session counts as expired.
	expirationTime time.Time
}

// GenerateCookie builds the "session_id" cookie that carries this session's ID
// to the client. The cookie is HttpOnly and expires at the session's
// expirationTime. It is marked Secure when isSecure is true.
//
// Path is pinned to "/" so that the cookie reaches every endpoint. Without it,
// browsers scope the cookie to the directory of the URL that served the login
// response, and it may never be sent to other routes. SameSite=Lax keeps the
// cookie off cross-site subrequests such as cross-origin form POSTs, which is
// baseline CSRF hardening.
func (s *Session) GenerateCookie(isSecure bool) *http.Cookie {
	s.l.RLock()
	defer s.l.RUnlock()

	return &http.Cookie{
		Name:     "session_id",
		Value:    s.sessionID,
		Path:     "/",
		Expires:  s.expirationTime,
		HttpOnly: true,
		Secure:   isSecure,
		SameSite: http.SameSiteLaxMode,
	}
}
|
|
|
|
type IAuthenticate interface {
|
|
IsLogged(r *http.Request) bool
|
|
Authenticate(username, password string) (*Session, error)
|
|
IsSecure() bool
|
|
}
|
|
|
|
// Compile-time assertion that *Authentication satisfies IAuthenticate; the
// nil pointer costs nothing at runtime.
var _ IAuthenticate = (*Authentication)(nil)
|
|
|
|
type Authentication struct {
|
|
l sync.RWMutex
|
|
|
|
ctx context.Context
|
|
fnCancel context.CancelFunc
|
|
|
|
sessions map[string]*Session
|
|
isSecure bool
|
|
}
|
|
|
|
func NewAuthentication(ctx context.Context, isSecure bool) *Authentication {
|
|
ctxChild, fnCancel := context.WithCancel(ctx)
|
|
|
|
s := &Authentication{
|
|
ctx: ctxChild,
|
|
fnCancel: fnCancel,
|
|
sessions: map[string]*Session{},
|
|
isSecure: isSecure,
|
|
}
|
|
s.purgeWorker()
|
|
|
|
return s
|
|
}
|
|
|
|
func (a *Authentication) purge() {
|
|
a.l.Lock()
|
|
defer a.l.Unlock()
|
|
|
|
now := time.Now()
|
|
toDelete := []*Session{}
|
|
for _, session := range a.sessions {
|
|
if now.After(session.expirationTime) {
|
|
toDelete = append(toDelete, session)
|
|
}
|
|
}
|
|
|
|
for _, session := range toDelete {
|
|
log.Debug().Str("sessionId", session.sessionID).Msg("purge expired session")
|
|
delete(a.sessions, session.sessionID)
|
|
}
|
|
}
|
|
|
|
func (a *Authentication) purgeWorker() {
|
|
ticker := time.NewTicker(10 * time.Second) //nolint
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
a.purge()
|
|
case <-a.ctx.Done():
|
|
log.Info().Msg("purge worker stopped")
|
|
ticker.Stop()
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (a *Authentication) IsSecure() bool {
|
|
return a.isSecure
|
|
}
|
|
|
|
func (a *Authentication) Stop() {
|
|
a.fnCancel()
|
|
}
|
|
|
|
func (a *Authentication) Done() <-chan struct{} {
|
|
return a.ctx.Done()
|
|
}
|
|
|
|
func (a *Authentication) Authenticate(username, password string) (*Session, error) {
|
|
if username != adminUsername() || password != adminPassword() {
|
|
return nil, ErrUnauthorized
|
|
}
|
|
|
|
sessionID, err := generateSessionID()
|
|
if err != nil {
|
|
log.Err(err).Msg("unable to generate sessionId")
|
|
return nil, err
|
|
}
|
|
|
|
a.l.Lock()
|
|
defer a.l.Unlock()
|
|
|
|
if _, ok := a.sessions[sessionID]; ok {
|
|
log.Error().Str("sessionId", sessionID).Msg("sessionId collision")
|
|
return nil, ErrSessionIDCollision
|
|
}
|
|
|
|
now := time.Now().Add(sessionExpirationTime())
|
|
session := Session{expirationTime: now, sessionID: sessionID}
|
|
a.sessions[sessionID] = &session
|
|
|
|
return &session, nil
|
|
}
|
|
|
|
func (a *Authentication) IsLogged(r *http.Request) bool {
|
|
cookie, err := r.Cookie("session_id")
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
a.l.RLock()
|
|
defer a.l.RUnlock()
|
|
|
|
_, ok := a.sessions[cookie.Value]
|
|
return ok
|
|
}
|