import logging
import os
from dataclasses import dataclass, field
from datetime import datetime as dt
from enum import Enum
from functools import wraps
from threading import Event, Thread
from typing import Any, Callable
from urllib.parse import urlencode, urljoin
from uuid import UUID

import requests
import urllib3
from dotenv import dotenv_values
from requests import Response, Session
from sseclient import SSEClient

from .exceptions import UnauthorizedException, WhereIsException

API_DEFAULT_URL = "https://api.locame.duckdns.org"

__all__ = ["Client", "OrderField"]

# Library code logs through a module logger instead of the root logger, so
# importing/using the client never configures the application's logging.
logger = logging.getLogger(__name__)


def refresh():
    """
    Decorator factory retrying a `Client` method once after a token refresh.

    The wrapped method is called; if it raises `UnauthorizedException`
    (HTTP 401), the client's access token is refreshed with
    `Client._refresh()` and the method is called a second time. If the
    second call raises `UnauthorizedException` too, it is logged and
    re-raised. Errors raised by `_refresh()` itself propagate unchanged.

    When the first positional argument is not a `Client` instance, the
    method is simply called once, without any retry logic.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not args or not isinstance(args[0], Client):
                # Nothing to refresh: behave like the undecorated function
                # (the call used to be skipped and None returned silently).
                return func(*args, **kwargs)
            client = args[0]
            try:
                return func(*args, **kwargs)
            except UnauthorizedException:
                logger.warning("refresh access token...")
                client._refresh()
            try:
                return func(*args, **kwargs)
            except UnauthorizedException:
                logger.error("second call attempt failed after refreshing token")
                raise

        return wrapper

    return decorator


class OrderField(Enum):
    """
    Ordering query param available fields.

    - AscDateStart: date start ascending order
    - DesDateStart: date start descending order
    - AscDateEnd: date end ascending order
    - DesDateEnd: date end descending order
    """

    AscDateStart = "date_start"
    DesDateStart = "-date_start"
    AscDateEnd = "date_end"
    DesDateEnd = "-date_end"


@dataclass(frozen=True, slots=True)
class SessionWatcher:
    """
    Handle the SSE client connection of one session in a daemon thread.

    NOTE: Should not be instantiated directly, use the `Client` instead.
    """

    session_id: UUID
    _thread: Thread
    _client: SSEClient
    _event: Event  # set by `stop()` to ask the reader thread to exit

    @classmethod
    def from_session_id(
        cls,
        base_url: str,
        id_: UUID,
        headers: dict[str, Any],
        callback: Callable[[str], None] | None = None,
        verify: bool = True,
    ) -> "SessionWatcher":
        """
        Open the SSE stream of a session and start a daemon thread reading it.

        The HTTP request is issued synchronously in the caller's thread; its
        status code is not checked here. Each received event's `data` is
        passed to `callback` (when given) from the reader thread.

        Params:
        -------
        base_url: str, API base URL
        id_: UUID, session identifier
        headers: dict, request headers (an `Accept: text/event-stream`
            header is added)
        callback: callable receiving each event payload, optional
        verify: bool, TLS certificate verification
        """
        session_url = urljoin(base_url, f"/sessions/{id_}/events/")
        headers = {**headers, "Accept": "text/event-stream"}
        resp = requests.get(session_url, stream=True, headers=headers, verify=verify)
        client = SSEClient(resp)  # type: ignore
        event = Event()

        def _job():
            logger.debug("SSE client daemon started for session: %s", id_)
            try:
                for evt in client.events():
                    logger.debug("event received for session: %s", id_)
                    if callback is not None:
                        callback(evt.data)
                    # The stop flag can only be observed between events.
                    if event.is_set():
                        logger.debug("SSE client daemon stopped for session: %s", id_)
                        break
            finally:
                # Release the streamed HTTP connection once the loop ends,
                # whether it was stopped, exhausted or failed.
                client.close()

        t = Thread(target=_job, daemon=True)
        t.start()
        return cls(id_, t, client, event)

    def stop(self, force: bool = False):
        """
        Ask the reader thread to stop and wait for it to finish.

        The stop flag is only checked by the reader thread after it has
        received (and dispatched) an event, so a graceful stop blocks until
        the next event arrives on the stream. Pass `force=True` to return
        immediately without waiting; the daemon thread then keeps running
        until the next event, or until the process exits.
        """
        self._event.set()
        self._thread.join(timeout=0 if force else None)
        logger.debug("SSE stream client for session: %s closed", self.session_id)


@dataclass(slots=True)
class Client:
    """
    WhereIs API Client main class.

    Example:
    --------
    cli = Client.from_env()
    sessions = cli.get_sessions()
    """

    base_url: str
    email: str
    # Excluded from repr() so the credential never ends up in logs/tracebacks.
    password: str = field(repr=False)
    session: Session = field(init=False)
    sessions_watcher: dict[UUID, SessionWatcher] = field(
        default_factory=dict, init=False
    )

    @staticmethod
    def _check_response(url: str, res: Response) -> None:
        """
        Raise on an error HTTP status.

        Raises:
        -------
        UnauthorizedException: status code is 401 (triggers `refresh` retry)
        WhereIsException: any other status code >= 400
        """
        if res.status_code == 401:
            raise UnauthorizedException(url, res)
        if res.status_code >= 400:
            raise WhereIsException(url, res)

    def _secure_next_url(self, url: str) -> str:
        """
        Upgrade a pagination URL to https when the API base URL is https.

        The API pagination `next` links come back with an http scheme even
        when the API is served over https; only the scheme prefix is changed.
        """
        if self.base_url.startswith("https") and url.startswith("http://"):
            return "https://" + url[len("http://"):]
        return url

    def _login(self) -> None:
        """
        Get the access token and store it in the `Session` header.

        Raises WhereIsException when the API answers with a status >= 400
        (including 401, so that no refresh/retry loop is triggered).
        """
        data = {
            "email": self.email,
            "password": self.password,
        }
        login_url = urljoin(self.base_url, "/auth/token/")
        logger.info("login: %s", login_url)
        res = self.session.post(login_url, json=data)
        if res.status_code >= 400:
            raise WhereIsException(login_url, res)
        access_token = res.json()["access"]
        self.session.headers.update({"Authorization": f"Bearer {access_token}"})
        return None

    def _refresh(self) -> None:
        """
        Refresh the access token and store it in the `Session` header.

        The request is sent without a body. NOTE(review): it presumably relies
        on state (e.g. a cookie) kept by the `Session` since login — confirm
        against the API.

        Raises WhereIsException when the API answers with a status >= 400.
        """
        refresh_url = urljoin(self.base_url, "/auth/refresh/")
        logger.info("refresh: %s", refresh_url)
        res = self.session.post(refresh_url)
        if res.status_code >= 400:
            raise WhereIsException(refresh_url, res)
        access_token = res.json()["access"]
        self.session.headers.update({"Authorization": f"Bearer {access_token}"})
        return None

    @staticmethod
    def _init_client(
        base_url: str, email: str, password: str, verify: bool = True
    ) -> "Client":
        """
        Build a client with a fresh `Session` and log it in.

        When `verify` is False, TLS certificate verification is disabled and
        urllib3's `InsecureRequestWarning` is silenced (process-wide).
        The `Session` is closed if the login fails, and the error re-raised.
        """
        cli = Client(base_url, email, password)
        cli.session = Session()
        cli.session.headers.update({"content-type": "application/json"})
        if not verify:
            # Only silence the warning caused by our own choice, not every
            # urllib3 warning.
            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
        cli.session.verify = verify
        try:
            cli._login()
        except Exception:
            cli.session.close()
            raise
        logger.info("client successfully initialized for user: %s", cli.email)
        return cli

    @classmethod
    def from_env(cls) -> "Client":
        """
        Initialize the client from env variables (.env & global)
        and log in into the application.

        Values found in the `.env` file take precedence over the process
        environment. `WHEREIS_API_BASE_URL` defaults to `API_DEFAULT_URL`.

        If the login fails, an exception is raised.
        """
        env_data = {
            "WHEREIS_API_EMAIL": os.getenv("WHEREIS_API_EMAIL", ""),
            "WHEREIS_API_PASSWORD": os.getenv("WHEREIS_API_PASSWORD", ""),
            "WHEREIS_API_BASE_URL": os.getenv("WHEREIS_API_BASE_URL", API_DEFAULT_URL),
        }
        # python-dotenv yields None for keys declared without a value; those
        # must not overwrite a real value with None.
        env_data.update(
            {key: value for key, value in dotenv_values().items() if value is not None}
        )
        return cls._init_client(
            env_data.get("WHEREIS_API_BASE_URL", ""),
            env_data.get("WHEREIS_API_EMAIL", ""),
            env_data.get("WHEREIS_API_PASSWORD", ""),
        )

    @classmethod
    def from_creds(
        cls, base_url: str, email: str, password: str, verify: bool = False
    ) -> "Client":
        """
        Initialize the client from explicit credentials and log in.

        NOTE: `verify` defaults to False here (TLS certificates are NOT
        checked) — pass `verify=True` for production endpoints.
        """
        return cls._init_client(base_url, email, password, verify)

    @refresh()
    def _get_sessions_page(self, url: str) -> dict[str, Any]:
        """Get one page of the paginated sessions listing."""
        res = self.session.get(url)
        self._check_response(url, res)
        return res.json()

    def get_sessions(
        self, search: str | None = None, ordering: OrderField | None = None
    ) -> list[dict[str, Any]]:
        """
        Get all user and public sessions, following the pagination.

        If an error occurred during the API call an exception is raised.

        Params:
        -------
        search: str, search sessions over username and description
        ordering: OrderField, ordering sessions by dates
        """
        params: dict[str, str] = {}
        if search is not None:
            params["search"] = search
        if ordering is not None:
            params["ordering"] = ordering.value

        first_url = urljoin(self.base_url, "/sessions/")
        if params:
            # urlencode escapes user input such as '&', '#', '+' or spaces.
            first_url = f"{first_url}?{urlencode(params)}"
        logger.info("get sessions: %s", first_url)

        lst_sessions: list[dict[str, Any]] = []
        sessions_url: str | None = first_url
        while sessions_url is not None:
            data = self._get_sessions_page(self._secure_next_url(sessions_url))
            sessions_url = data.get("next")
            lst_sessions.extend(data.get("results", []))
        return lst_sessions

    @refresh()
    def get_session(self, id_: UUID) -> dict[str, Any]:
        """Get a single session by its identifier."""
        session_url = urljoin(self.base_url, f"/sessions/{id_}/")
        logger.info("get session: %s", session_url)
        res = self.session.get(session_url)
        self._check_response(session_url, res)
        return res.json()

    @refresh()
    def create_session(
        self, name: str, description: str | None = None, is_public: bool = False
    ) -> dict[str, Any]:
        """
        Create a session and return the API representation of it.

        Arguments left to None are not sent in the request payload.
        """
        sessions_url = urljoin(self.base_url, "/sessions/")
        logger.info("create session: %s", sessions_url)
        candidates = {"name": name, "description": description, "is_public": is_public}
        data = {key: value for key, value in candidates.items() if value is not None}
        res = self.session.post(sessions_url, json=data)
        self._check_response(sessions_url, res)
        return res.json()

    @refresh()
    def update_session(
        self,
        id_: UUID,
        name: str | None = None,
        description: str | None = None,
        is_public: bool | None = None,
    ) -> dict[str, Any]:
        """
        Partially update a session (HTTP PATCH).

        Only the arguments that are not None are sent.
        """
        session_url = urljoin(self.base_url, f"/sessions/{id_}/")
        logger.info("update session: %s", session_url)
        candidates = {"name": name, "description": description, "is_public": is_public}
        data = {key: value for key, value in candidates.items() if value is not None}
        res = self.session.patch(session_url, json=data)
        self._check_response(session_url, res)
        return res.json()

    @refresh()
    def delete_session(self, id_: UUID) -> None:
        """
        Close and delete the session.

        NOTE: The GPS positions associated to the session are not deleted!
        """
        session_url = urljoin(self.base_url, f"/sessions/{id_}/")
        logger.info("delete session: %s", session_url)
        res = self.session.delete(session_url)
        self._check_response(session_url, res)
        return None

    @refresh()
    def update_session_users(
        self, id_: UUID, users: list[UUID]
    ) -> dict[str, Any]:
        """
        Update users sessions.

        WARN: An empty users list parameter cleans all users.
        """
        session_url = urljoin(self.base_url, f"/sessions/{id_}/users/")
        logger.info("update session users: %s", session_url)
        # UUID objects are not JSON-serializable: send their string form.
        payload = {"users": [str(user) for user in users]}
        res = self.session.post(session_url, json=payload)
        self._check_response(session_url, res)
        return res.json()

    @refresh()
    def close_session(self, id_: UUID) -> dict[str, Any]:
        """Close the session DEFINITIVELY. Users can't be added anymore."""
        session_url = urljoin(self.base_url, f"/sessions/{id_}/close/")
        logger.info("close session: %s", session_url)
        res = self.session.post(session_url)
        self._check_response(session_url, res)
        return res.json()

    def watch_session_events(
        self, id_: UUID, callback: Callable[[str], None] | None = None
    ):
        """
        Watch session events through an SSE client.

        It will launch a daemon thread, listening for incoming events.
        You can use the `callback` optional argument to pass a callable
        to deal with the events. Watching an already watched session only
        logs a warning.

        Example:
        --------
        def treat_events(evt: str):
            # your instructions
            print(evt)

        cli.watch_session_events("session-id", treat_events)

        NOTE: You have to manually manage the connection error (IO, Authentication, etc...)
        For authentication error, you'll receive this kind of event:
        {
            "condition": "forbidden",
            "text": "Permission denied to channels: session_session-id",
            "channels": ["session_session-id"]
        }
        """  # noqa
        if self.sessions_watcher.get(id_) is not None:
            logger.warning("you're already watching session events: %s", id_)
            return

        sw = SessionWatcher.from_session_id(
            self.base_url,
            id_,
            self.session.headers,  # type: ignore
            callback,
            self.session.verify,  # type: ignore
        )
        logger.info("session events (id: %s) watcher started", id_)
        self.sessions_watcher[id_] = sw

    def stop_watch_session(self, id_: UUID, force: bool = False):
        """
        Stop watching events for a session (no-op if it is not watched).

        Use `force` optional argument to return without waiting for the
        watcher thread instead of waiting for a graceful stop.
        """
        if (sw := self.sessions_watcher.get(id_)) is not None:
            sw.stop(force)
            del self.sessions_watcher[id_]
            logger.info("session events (id: %s) watcher stopped", id_)

    @refresh()
    def get_wstokens(self) -> list[dict[str, Any]]:
        """List the websocket tokens of the current user."""
        wstoken_url = urljoin(self.base_url, "/auth/ws-token/")
        logger.info("get ws token: %s", wstoken_url)
        res = self.session.get(wstoken_url)
        self._check_response(wstoken_url, res)
        return res.json()

    @refresh()
    def create_wstoken(self) -> dict[str, Any]:
        """
        Create a websocket JWT to authenticate your real time connection.

        NOTE: only one ws token per user is allowed.
        If it expired, delete it and create a new one.
        """
        wstoken_url = urljoin(self.base_url, "/auth/ws-token/")
        logger.info("create ws token: %s", wstoken_url)
        res = self.session.post(wstoken_url)
        self._check_response(wstoken_url, res)
        return res.json()

    @refresh()
    def delete_wstoken(self, id_: UUID) -> None:
        """Delete a websocket token by its identifier."""
        wstoken_url = urljoin(self.base_url, f"/auth/ws-token/{id_}/")
        logger.info("delete ws token: %s", wstoken_url)
        res = self.session.delete(wstoken_url)
        self._check_response(wstoken_url, res)
        return None

    @refresh()
    def _get_paginate_gps_positions(self, url: str) -> dict[str, Any]:
        """Get one page of the paginated GPS positions listing."""
        res = self.session.get(url)
        self._check_response(url, res)
        return res.json()

    def get_gps_positions(
        self,
        date_start: str | None = None,
        date_end: str | None = None,
    ) -> list[dict[str, Any]]:
        """
        Gets GPS positions data, following the pagination.

        You can get GPS positions filtered by date interval using optionals
        arguments `date_start` and `date_end`:
        - `date_start`: [date_start,]
        - `date_end`: [,date_end]
        - `date_start` & `date_end`: [date_start,date_end]

        The dates must be strings accepted by `datetime.fromisoformat`
        (any ISO 8601 format on Python 3.11+), otherwise a ValueError is
        raised before any request is sent.
        """
        params: dict[str, str] = {}
        if date_start:
            params["date_start"] = dt.fromisoformat(date_start).isoformat()
        if date_end:
            params["date_end"] = dt.fromisoformat(date_end).isoformat()

        first_url = urljoin(self.base_url, "/gps/positions/")
        if params:
            # urlencode escapes the '+' of timezone offsets, which would
            # otherwise be decoded as a space by the server.
            first_url = f"{first_url}?{urlencode(params)}"

        lst_gps_positions: list[dict[str, Any]] = []
        gps_url: str | None = first_url
        while gps_url is not None:
            gps_url = self._secure_next_url(gps_url)
            logger.info("get gps data from: %s", gps_url)
            data = self._get_paginate_gps_positions(gps_url)
            lst_gps_positions.extend(data["results"])
            gps_url = data["next"]
        return lst_gps_positions