diff --git a/main.go b/main.go index d25bafd..494b686 100644 --- a/main.go +++ b/main.go @@ -6,8 +6,6 @@ import ( "librapi/services" "os" "os/signal" - "strconv" - "sync" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -16,17 +14,6 @@ import ( "librapi/handlers/upload" ) -const DefaultPort = 8585 - -var APIPort = sync.OnceValue[int](func() int { - port, err := strconv.Atoi(os.Getenv("API_PORT")) - if err != nil { - log.Debug().Err(err).Msg("unable to load API_PORT") - return DefaultPort - } - return port -}) - func initLogger() { zerolog.TimeFieldFormat = zerolog.TimeFormatUnix log.Logger = log.With().Caller().Logger().Output(zerolog.ConsoleWriter{Out: os.Stderr}) @@ -42,7 +29,7 @@ func main() { srv := server.NewServer( ctx, - APIPort(), + services.GetEnv().GetPort(), server.NewHandler("/upload", upload.Handler(sessionStore)), server.NewHandler("/login", login.Handler(sessionStore)), ) diff --git a/server/server.go b/server/server.go index ce91e82..325be6e 100644 --- a/server/server.go +++ b/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "errors" + "librapi/services" "net/http" "strconv" "time" @@ -13,7 +14,6 @@ import ( const ( ServerShutdownTimeout = 10 * time.Second ServerReadTimeout = 5 * time.Second - DefaultPort = 8888 ) type Handler struct { @@ -35,8 +35,8 @@ type ServerOption func() func NewServer(ctx context.Context, port int, handlers ...Handler) Server { if port == 0 { - log.Warn().Int("port", DefaultPort).Msg("no port detected, set to default") - port = DefaultPort + log.Warn().Int("port", services.GetEnv().GetPort()).Msg("no port detected, set to default") + port = services.GetEnv().GetPort() } srvmux := http.NewServeMux() diff --git a/services/environments.go b/services/environments.go new file mode 100644 index 0000000..4c3652f --- /dev/null +++ b/services/environments.go @@ -0,0 +1,63 @@ +package services + +import ( + "os" + "strconv" + "sync" + "time" + + "github.com/rs/zerolog/log" +) + +const defaultAPISessionExpirationDuration = 30 * time.Second +const defaultPort = 8585 + +var GetEnv = sync.OnceValue[env](newEnv) + +type env struct { + adminUsername string + adminPassword string + sessionExpirationDuration time.Duration + port int + isSecure bool +} + +func (e env) GetCredentials() (username, password string) { + return e.adminUsername, e.adminPassword +} + +func (e env) GetSessionExpirationDuration() time.Duration { + return e.sessionExpirationDuration +} + +func (e env) GetPort() int { + return e.port +} + +func (e env) IsSecure() bool { + return e.isSecure +} + +func newEnv() env { + env := env{ + adminUsername: os.Getenv("API_ADMIN_USERNAME"), + adminPassword: os.Getenv("API_ADMIN_PASSWORD"), + isSecure: os.Getenv("API_SECURE") == "true", + } + + sessionExpirationDuration, err := strconv.Atoi(os.Getenv("API_SESSION_EXPIRATION_DURATION")) + env.sessionExpirationDuration = time.Duration(sessionExpirationDuration) + if err != nil { + log.Warn().Err(err).Dur("default", defaultAPISessionExpirationDuration).Msg("unable to load API_SESSION_EXPIRATION_DURATION, set to default") + env.sessionExpirationDuration = defaultAPISessionExpirationDuration + } + + port, err := strconv.Atoi(os.Getenv("API_PORT")) + env.port = port + if err != nil { + log.Warn().Err(err).Int("default", defaultPort).Msg("unable to load API_PORT, set to default") + env.port = defaultPort + } + + return env +} diff --git a/services/sessions.go b/services/sessions.go index 1349a8a..aa55a08 100644 --- a/services/sessions.go +++ b/services/sessions.go @@ -6,29 +6,12 @@ import ( "encoding/hex" "errors" "net/http" - "os" - "strconv" "sync" "time" "github.com/rs/zerolog/log" ) -const defaultAPISessionExpirationDuration = 30 * time.Second - -var APISessionExpirationDuration = sync.OnceValue[time.Duration](func() time.Duration { - expirationDuration, err := strconv.Atoi(os.Getenv("API_SESSION_EXPIRATION_DURATION")) - if err != nil { - log.Debug().Err(err).Msg("unable to load API_SESSION_EXPIRATION_DURATION") - return defaultAPISessionExpirationDuration - } - return time.Duration(expirationDuration * int(time.Second)) -}) - -var APISecure = sync.OnceValue[bool](func() bool { - return os.Getenv("API_SECURE") == "true" -}) - var ( ErrSessionIDCollision = errors.New("sessionId collision") ErrUnauthorized = errors.New("unauthorized") @@ -57,7 +40,7 @@ func (s *Session) GenerateCookie() *http.Cookie { Name: "session_id", Value: s.sessionID, HttpOnly: true, - Secure: APISecure(), + Secure: GetEnv().isSecure, Expires: s.expirationTime, } } @@ -141,7 +124,7 @@ func (s *SessionStore) NewSession() (*Session, error) { return nil, ErrSessionIDCollision } - now := time.Now().Add(APISessionExpirationDuration()) + now := time.Now().Add(GetEnv().GetSessionExpirationDuration()) session := Session{expirationTime: now, sessionID: sessionID} s.sessions[sessionID] = &session