diff --git a/.env.example b/.env.example index e69de29..b37f586 100644 --- a/.env.example +++ b/.env.example @@ -0,0 +1,5 @@ +WHEREIS_API_BASE_URL=https://api-whereis.thegux.fr +WHEREIS_API_EMAIL= +WHEREIS_API_PASSWORD= + +WHEREIS_CERT_VERIFY=true \ No newline at end of file diff --git a/Makefile b/Makefile index 7d31afe..29e7ba2 100644 --- a/Makefile +++ b/Makefile @@ -22,4 +22,4 @@ build: check $(PYTHON) -m hatch -v build -t wheel publish: build - $(PYTHON) -m twine upload --repository gitea dist/*.whl \ No newline at end of file + $(PYTHON) -m twine upload --repository whereis-client dist/*.whl \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..62754b3 --- /dev/null +++ b/README.md @@ -0,0 +1,33 @@ +# whereis-client + +A simple Python client library to interact easyly with the WhereIs REST API. + +## Install + +```bash +pip install --index-url https://gitea.thegux.fr/api/packages/rmanach/pypi/simple/ whereis-client== --extra-index-url https://pypi.org/simple +``` + +## How to ? +In order to use the client you need to provide some environments variables defined in `.env.example`. You can either: +- Copy the `.env.example` into `.env` and feed the variables (`.env` should be next to your main script using the lib) +- Add those environments variables to your environment profile. + +Once it's done, you're good to use the client. + +```python +from whereis_client import Client + +cli = Client.from_env() +lst_gps_positions = cli.get_gps_positions( + date_start="2022-12-25", date_end="2022-12-30" +) +``` + +For some code samples on how to use the client, take a look at the [main.py](main.py) sample script. + +Enjoy ! + +## Contact + +If you have any issues, feel free to contact **admin@thegux.fr** for fixes or upgrades. diff --git a/main.py b/main.py index 7040273..5da3376 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,10 @@ import logging import sys +import json +import time -from src import VERSION +from src import VERSION, Client, OrderField +from src.exceptions import WhereIsException, UnauthorizedException stdout_handler = logging.StreamHandler(stream=sys.stdout) @@ -12,5 +15,113 @@ logging.basicConfig( ) if __name__ == "__main__": + """ + This a sample script to demonstrate how to deal with WhereIs API client. + + All you need is to provide some environments variables describe in the `.env.example`. + You can: + - Copy `.env.example` into `.env` and feed env variables or, + - Directly export those mandatories variables + + Once its done, you can initialize the Client: + + cli = Client.from_env() + + and use all the available methods to interact with the WhereIs API. + + Sessions: + - create_session + - get_sessions + - get_session + - update_session + - delete_session + - update_session_users + - close_session + - watch_session_events + - stop_watch_session + WS token: + - get_wstokens + - create_wstoken + - delete_wstoken + GPS positions: + - get_gps_positions + """ logging.info(f"WhereIs client v{VERSION}") + # initialize the client + try: + cli = Client.from_env() + except WhereIsException as e: + logging.error(f"unable to initialize WhereIs API client: {e}") + except Exception as e: + logging.error("unexpected error occurred while initializing client", exc_info=True) + + # get ws tokens + tokens = cli.get_wstokens() + print(json.dumps(tokens, indent=2)) + + # delete ws token + if len(tokens) > 0: + token = cli.delete_wstoken(tokens[0].get("id")) + print(json.dumps(tokens, indent=2)) + + # create ws token + try: + token = cli.create_wstoken() + except WhereIsException as e: + logging.error(f"error occurred while creating a ws token, status code: {e.error_code}") + print(json.dumps(e.content, indent=2)) + + # retrieve all user/public sessions + sessions = cli.get_sessions() + print(json.dumps(sessions, indent=2)) + + # create and update a session (must have `Streamer` role) + try: + session = cli.create_session(name="RUN-01") + session = cli.update_session(session.get("id"), name="RUN-02", description="My second run") + except WhereIsException as e: + logging.error(f"error occurred while creating/updating a session, status code: {e.error_code}") + print(json.dumps(e.content, indent=2)) + + # retrieve all user/public sessions with ordering + sessions = cli.get_sessions(ordering=OrderField.AscDateStart) + print(json.dumps(sessions, indent=2)) + + # retrieve all user/public sessions with search + sessions = cli.get_sessions(search="run", ordering=OrderField.AscDateStart) + print(json.dumps(sessions, indent=2)) + + # retrieve session by id + session = cli.get_session(session.get("id")) + print(json.dumps(session, indent=2)) + + # update session users + try: + session = cli.update_session_users(session.get("id"), []) + except WhereIsException as e: + logging.error(f"error occurred while updating users session, status code: {e.error_code}") + print(json.dumps(e.content, indent=2)) + + # close a session + try: + session = cli.close_session("does-not-exist") + except WhereIsException as e: + logging.error(f"error occurred while closing session, status code: {e.error_code}") + print(json.dumps(e.content, indent=2)) + + # get session events + cli.watch_session_events(session.get("id")) + + # get gps positions from "2024-12-25T23:00:00" to infinity... + gps_positions = cli.get_gps_positions(date_start="2024-12-25T23:00:00") + print(json.dumps(gps_positions, indent=2)) + + # close the session + session = cli.close_session(session.get("id")) + + # delete a session + cli.delete_session(session.get("id")) + + # stop session events watcher + cli.stop_watch_session(session.get("id"), force=True) diff --git a/pyproject.toml b/pyproject.toml index 8b96326..fe053d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,8 @@ dynamic = ["version"] description = "WhereIs API client library" dependencies = [ "requests==2.32.3", + "python-dotenv==1.0.1", + "sseclient-py==1.8.0", ] [tool.hatch.version] @@ -20,7 +22,7 @@ packages = ["src"] only-include = ["src"] [tool.hatch.build.targets.wheel.sources] -"src" = "whereis-client" +"src" = "whereis_client" [tool.ruff] select = ["E", "F", "I"] diff --git a/requirements.txt b/requirements.txt index 1786051..b67a478 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ requests==2.32.3 -python-dotenv==1.0.1 \ No newline at end of file +python-dotenv==1.0.1 +sseclient-py==1.8.0 \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index 1cf6267..d714875 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1 +1,5 @@ +from .client import Client, OrderField + +__all__ = ["Client", "OrderField"] + VERSION = "0.1.0" diff --git a/src/client.py b/src/client.py new file mode 100644 index 0000000..6fbd58b --- /dev/null +++ b/src/client.py @@ -0,0 +1,550 @@ +import logging +import os +from dataclasses import dataclass, field +from datetime import datetime as dt +from enum import Enum +from threading import Event, Thread +from typing import Any, Callable +from urllib.parse import urljoin +from uuid import UUID + +import requests +import urllib3 +from dotenv import dotenv_values +from requests import Session +from sseclient import SSEClient + +from .exceptions import UnauthorizedException, WhereIsException + +API_DEFAULT_URL = "https://api-whereis.thegux.fr" + +__all__ = ["Client", "OrderField"] + + +def refresh(): + """ + Catch 401 status code (UnauthorizedException) + refresh the access token and retry. + """ + + def decorator(func): + def wrapper(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], Client): + for i in range(2): + try: + return func(*args, **kwargs) + except UnauthorizedException as e: + if i == 1: # second attempt + logging.error( + "second call attempt failed after refreshing token" + ) + raise e + logging.warning("refresh access token...") + args[0]._refresh() + + return wrapper + + return decorator + + +class OrderField(Enum): + """ + Ordering query param available fields. + + - AscDateStart: date start ascending order + - DesDateStart: date start descending order + - AscDateEnd: date end ascending order + - DesDateEnd: date end descending order + """ + + AscDateStart = "date_start" + DesDateStart = "-date_start" + AscDateEnd = "date_end" + DesDateEnd = "-date_end" + + +@dataclass(frozen=True, slots=True) +class SessionWatcher: + """ + Handle the SSE client connection in a daemon thread. + + NOTE: Should not be instanciated directly, use + the `Client` instead. + """ + + session_id: UUID + _thread: Thread + _client: SSEClient + _event: Event + + @classmethod + def from_session_id( + cls, + base_url: str, + id_: UUID, + headers: dict[str, Any], + callback: Callable[[str], None] | None = None, + verify: bool = True, + ) -> "SessionWatcher": + session_url = urljoin(base_url, f"/sessions/{id_}/events/") + + headers = {**headers, "Accept": "text/event-stream"} + resp = requests.get(session_url, stream=True, headers=headers, verify=verify) + client = SSEClient(resp) # type: ignore + event = Event() + + def _job(): + logging.debug(f"SSE client daemon started for session: {id_}") + + for evt in client.events(): + logging.debug(f"event received for session: {id_}") + if callback is not None: + callback(evt.data) + + if event.is_set(): + logging.debug(f"SSE client daemon stopped for session: {id_}") + break + + t = Thread(target=_job, daemon=True) + t.start() + + return SessionWatcher(id_, t, client, event) + + def stop(self, force: bool = False): + """ + Send an event to stop the events stream and wait + for the thread to finish. + + If you want to stop the stream savagely, you can use + the optional arg: `force`. + """ + self._event.set() + self._thread.join(timeout=0 if force else None) + logging.debug(f"SSE stream client for session: {self.session_id} closed") + + +@dataclass(slots=True) +class Client: + """ + WhereIs API Client main class. + + Example: + -------- + cli = Client.from_env() + sessions = cli.get_sessions() + """ + + base_url: str + email: str + password: str + + session: Session = field(init=False) + + sessions_watcher: dict[UUID, SessionWatcher] = field( + default_factory=dict, init=False + ) + + def _login(self) -> WhereIsException | None: + """Get the access token and store it in the `Session` header""" + data = { + "email": self.email, + "password": self.password, + } + + login_url = urljoin(self.base_url, "/auth/token/") + logging.info(f"login: {login_url}") + + res = self.session.post(login_url, json=data) + if res.status_code >= 400: + raise WhereIsException(login_url, res) + + access_token = res.json()["access"] + self.session.headers.update({"Authorization": f"Bearer {access_token}"}) + + return None + + def _refresh(self) -> WhereIsException | None: + """Refresh the access token and store it in the `Session` header""" + refresh_url = urljoin(self.base_url, "/auth/refresh/") + logging.info(f"refresh: {refresh_url}") + + res = self.session.post(refresh_url) + if res.status_code >= 400: + raise WhereIsException(refresh_url, res) + + access_token = res.json()["access"] + self.session.headers.update({"Authorization": f"Bearer {access_token}"}) + + return None + + @classmethod + def from_env(cls) -> "Client": + """ + Initialize the client from env variables (.env & global) and + log in into the application. + + If the login fails, an exception is raised. + """ + env_data = { + "WHEREIS_API_EMAIL": os.getenv("WHEREIS_API_EMAIL", ""), + "WHEREIS_API_PASSWORD": os.getenv("WHEREIS_API_PASSWORD", ""), + "WHEREIS_API_BASE_URL": os.getenv("WHEREIS_API_BASE_URL", API_DEFAULT_URL), + } + dotenv_data = dotenv_values() + env_data.update(dotenv_data) # type: ignore + + cli = Client( + env_data.get("WHEREIS_API_BASE_URL", ""), + env_data.get("WHEREIS_API_EMAIL", ""), + env_data.get("WHEREIS_API_PASSWORD", ""), + ) + cli.session = Session() + cli.session.headers.update({"content-type": "application/json"}) + + is_cert_verify = env_data.get("WHEREIS_CERT_VERIFY", "") != "false" + if not is_cert_verify: + urllib3.disable_warnings() + + cli.session.verify = is_cert_verify + + cli._login() + + logging.info(f"client successfully initialized for user: {cli.email}") + return cli + + @refresh() + def _get_sessions_page(self, url: str) -> dict[str, Any] | WhereIsException: + """Get paginate sessions.""" + res = self.session.get(url) + if res.status_code == 401: + raise UnauthorizedException(url, res) + + if res.status_code >= 400: + raise WhereIsException(url, res) + + return res.json() + + def get_sessions( + self, search: str | None = None, ordering: OrderField | None = None + ) -> list[dict[str, Any]] | WhereIsException: + """ + Get all user and public sessions. + If an error occurred duning the API call an exception is raised. + + Params: + ------- + search: str, search sessions over username and description + ordering: OrderField, ordering sessions by dates + """ + sessions_url = urljoin(self.base_url, "/sessions/") + + has_query_param = False + if search is not None: + if not has_query_param: + sessions_url += "?" + has_query_param = True + sessions_url += f"search={search}" + + if ordering is not None: + if not has_query_param: + sessions_url += "?" + else: + sessions_url += "&" + sessions_url += f"ordering={ordering.value}" + + logging.info(f"get sessions: {sessions_url}") + + lst_sessions = [] + while sessions_url is not None: + data = self._get_sessions_page(sessions_url) + + sessions_url = data.get("next") + lst_sessions.extend(data.get("results", [])) + + return lst_sessions + + @refresh() + def get_session(self, id_: UUID) -> list[dict[str, Any]] | WhereIsException: + session_url = urljoin(self.base_url, f"/sessions/{id_}/") + logging.info(f"get session: {session_url}") + + res = self.session.get(session_url) + if res.status_code == 401: + raise UnauthorizedException(session_url, res) + + if res.status_code >= 400: + raise WhereIsException(session_url, res) + + return res.json() + + @refresh() + def create_session( + self, name: str, description: str | None = None, is_public: bool = False + ) -> dict[str, Any] | WhereIsException: + sessions_url = urljoin(self.base_url, "/sessions/") + logging.info(f"create session: {sessions_url}") + + data = {"name": name, "description": description, "is_public": is_public} + if name is None: + data.pop("name") + if description is None: + data.pop("description") + if is_public is None: + data.pop("is_public") + + res = self.session.post( + sessions_url, + json=data, + ) + + if res.status_code == 401: + raise UnauthorizedException(sessions_url, res) + + if res.status_code >= 400: + raise WhereIsException(sessions_url, res) + + return res.json() + + @refresh() + def update_session( + self, + id_: UUID, + name: str | None = None, + description: str | None = None, + is_public: bool | None = None, + ) -> dict[str, Any] | WhereIsException: + session_url = urljoin(self.base_url, f"/sessions/{id_}/") + logging.info(f"update session: {session_url}") + + data = {"name": name, "description": description, "is_public": is_public} + if name is None: + data.pop("name") + if description is None: + data.pop("description") + if is_public is None: + data.pop("is_public") + + res = self.session.patch(session_url, json=data) + + if res.status_code == 401: + raise UnauthorizedException(session_url, res) + + if res.status_code >= 400: + raise WhereIsException(session_url, res) + + return res.json() + + @refresh() + def delete_session(self, id_: UUID) -> None | WhereIsException: + """ + Close and delete the session. + + NOTE: The GPS positions associated to the session are not deleted ! + """ + session_url = urljoin(self.base_url, f"/sessions/{id_}/") + logging.info(f"delete session: {session_url}") + + res = self.session.delete(session_url) + + if res.status_code == 401: + raise UnauthorizedException(session_url, res) + + if res.status_code >= 400: + raise WhereIsException(session_url, res) + + return None + + @refresh() + def update_session_users( + self, id_: UUID, users: list[UUID] + ) -> dict[str, Any] | WhereIsException: + """ + Update users sessions. + + WARN: An empty users list parameter cleans all users. + """ + session_url = urljoin(self.base_url, f"/sessions/{id_}/users/") + logging.info(f"update session users: {session_url}") + + res = self.session.post(session_url, json={"users": users}) + + if res.status_code == 401: + raise UnauthorizedException(session_url, res) + + if res.status_code >= 400: + raise WhereIsException(session_url, res) + + return res.json() + + @refresh() + def close_session(self, id_: UUID) -> dict[str, Any] | WhereIsException: + """Close the DEFINITIVELY the session. Users can't be added anymore.""" + session_url = urljoin(self.base_url, f"/sessions/{id_}/close/") + logging.info(f"close session: {session_url}") + + res = self.session.post(session_url) + + if res.status_code == 401: + raise UnauthorizedException(session_url, res) + + if res.status_code >= 400: + raise WhereIsException(session_url, res) + + return res.json() + + def watch_session_events( + self, id_: UUID, callback: Callable[[str], None] | None = None + ): + """ + Watch session events through an SSE client. + + It will launch a daemon thread, listening for incoming events. + You can use the `callback` optional argument to pass a callable + to deal with the events. + + Example: + -------- + def treat_events(evt: str): + # your instructions + print(evt) + + cli.watch_session_events("session-id", treat_events) + + NOTE: You have to manually manage the connection error (IO, Authentication, etc...) + For authentication error, you'll receive this kind of event: + { + "condition": "forbidden", + "text": "Permission denied to channels: session_session-id", + "channels": ["session_session-id"] + } + """ # noqa + if self.sessions_watcher.get(id_) is not None: + logging.warning(f"you're already watching session events: {id_}") + return + + sw = SessionWatcher.from_session_id( + self.base_url, + id_, + self.session.headers, # type: ignore + callback, + self.session.verify, # type: ignore + ) + + logging.info(f"session events (id: {id_}) watcher started") + self.sessions_watcher[id_] = sw + + def stop_watch_session(self, id_: UUID, force: bool = False): + """ + Stop watching events for a session. + + Use `force` optional argument to kill the watcher + instead of waiting for a graceful stop. + """ + if (sw := self.sessions_watcher.get(id_)) is not None: + sw.stop(force) + del self.sessions_watcher[id_] + logging.info(f"session events (id: {id_}) watcher stopped") + + @refresh() + def get_wstokens(self) -> list[dict[str, Any]] | WhereIsException: + wstoken_url = urljoin(self.base_url, "/auth/ws-token/") + logging.info(f"get ws token: {wstoken_url}") + + res = self.session.get(wstoken_url) + + if res.status_code == 401: + raise UnauthorizedException(wstoken_url, res) + + if res.status_code >= 400: + raise WhereIsException(wstoken_url, res) + + return res.json() + + @refresh() + def create_wstoken(self) -> dict[str, Any] | WhereIsException: + """ + Create a websocket JWT to authenticate your real time connection. + + NOTE: only one, and only one ws token per user is allowed. + If it expired, delete it and create a new one. + """ + wstoken_url = urljoin(self.base_url, "/auth/ws-token/") + logging.info(f"create ws token: {wstoken_url}") + + res = self.session.post(wstoken_url) + + if res.status_code == 401: + raise UnauthorizedException(wstoken_url, res) + + if res.status_code >= 400: + raise WhereIsException(wstoken_url, res) + + return res.json() + + @refresh() + def delete_wstoken(self, id_: UUID) -> None | WhereIsException: + wstoken_url = urljoin(self.base_url, f"/auth/ws-token/{id_}/") + logging.info(f"delete ws token: {wstoken_url}") + + res = self.session.delete(wstoken_url) + + if res.status_code == 401: + raise UnauthorizedException(wstoken_url, res) + + if res.status_code >= 400: + raise WhereIsException(wstoken_url, res) + + return None + + @refresh() + def _get_paginate_gps_positions( + self, url: str + ) -> dict[str, Any] | WhereIsException: + res = self.session.get(url) + + if res.status_code == 401: + raise UnauthorizedException(url, res) + + if res.status_code >= 400: + raise WhereIsException(url, res) + + return res.json() + + def get_gps_positions( + self, + date_start: str | None = None, + date_end: str | None = None, + ) -> list[dict[str, Any]] | ValueError | WhereIsException: + """ + Gets GPS positions data. + + You can get GPS positions filtered by date interval using + optionals arguments `date_start` and `date_end`: + - `date_start`: [date_start,] + - `date_end`: [date_end,] + - `date_start` & `date_end`: [date_start,date_end] + + The dates formats must be in any valid ISO 8601 format otherwise + it will raise a ValueError. + """ + gps_url = urljoin(self.base_url, "/gps/positions/") + lst_gps_positions: list[dict[str, Any]] = [] + + if date_start: + ds = dt.fromisoformat(date_start) + gps_url += f"?date_start={ds.isoformat()}" + if date_end: + de = dt.fromisoformat(date_end) + if date_start: + gps_url += f"&date_end={de.isoformat()}" + else: + gps_url += f"?date_end={de.isoformat()}" + + while gps_url is not None: + logging.info(f"get gps data from: {gps_url}") + + data = self._get_paginate_gps_positions(gps_url) + lst_gps_positions.extend([d for d in data["results"]]) + gps_url = data["next"] + + return lst_gps_positions diff --git a/src/exceptions.py b/src/exceptions.py new file mode 100644 index 0000000..9eb5cc3 --- /dev/null +++ b/src/exceptions.py @@ -0,0 +1,23 @@ +from requests import Response + +__all__ = ["WhereIsException", "UnauthorizedException"] + + +class WhereIsException(Exception): + """Handle all WhereIs API errors.""" + + def __init__(self, url: str, response: Response): + self.url = url + try: + self.content = response.json() + except Exception: + self.content = response.content.decode() + super().__init__(self.content) + self.error_code = response.status_code + + def __str__(self): + return f"error calling: {self.url} - {self.error_code} - {self.content}" + + +class UnauthorizedException(WhereIsException): + pass