whereis-client/src/client.py
2024-12-18 11:30:31 +01:00

547 lines
17 KiB
Python

import logging
import os
from dataclasses import dataclass, field
from datetime import datetime as dt
from enum import Enum
from functools import wraps
from threading import Event, Thread
from typing import Any, Callable
from urllib.parse import urlencode, urljoin
from uuid import UUID

import requests
import urllib3
from dotenv import dotenv_values
from requests import Session
from sseclient import SSEClient

from .exceptions import UnauthorizedException, WhereIsException
# Names exported by `from <module> import *`.
__all__ = ["Client", "OrderField"]
# Base URL used by `Client.from_env` when WHEREIS_API_BASE_URL is not set.
API_DEFAULT_URL: str = "https://api-whereis.thegux.fr"
def refresh():
    """
    Build a decorator that retries a `Client` method once after refreshing
    the access token.

    The decorated method is called. If it raises `UnauthorizedException`,
    `Client._refresh()` is invoked and the method is called a second time.
    If the second call raises `UnauthorizedException` again, the failure is
    logged and the exception is re-raised.

    Calls whose first positional argument is not a `Client` are forwarded
    unchanged to the wrapped function, without any retry.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not args or not isinstance(args[0], Client):
                # Not bound to a Client: there is no token to refresh, so
                # just forward the call (previously it silently returned None).
                return func(*args, **kwargs)
            client = args[0]
            try:
                return func(*args, **kwargs)
            except UnauthorizedException:
                logging.warning("refresh access token...")
                client._refresh()
            # Second and last attempt with the refreshed token.
            try:
                return func(*args, **kwargs)
            except UnauthorizedException:
                logging.error(
                    "second call attempt failed after refreshing token"
                )
                raise

        return wrapper

    return decorator
class OrderField(Enum):
    """
    Values accepted by the `ordering` query parameter of the sessions list.

    The plain value sorts in ascending order; the value prefixed with "-"
    sorts in descending order.
    """

    # session start date: ascending, then descending
    AscDateStart = "date_start"
    DesDateStart = "-date_start"
    # session end date: ascending, then descending
    AscDateEnd = "date_end"
    DesDateEnd = "-date_end"
@dataclass(frozen=True, slots=True)
class SessionWatcher:
"""
Handle the SSE client connection in a daemon thread.
NOTE: Should not be instanciated directly, use
the `Client` instead.
"""
session_id: UUID
_thread: Thread
_client: SSEClient
_event: Event
@classmethod
def from_session_id(
cls,
base_url: str,
id_: UUID,
headers: dict[str, Any],
callback: Callable[[str], None] | None = None,
verify: bool = True,
) -> "SessionWatcher":
session_url = urljoin(base_url, f"/sessions/{id_}/events/")
headers = {**headers, "Accept": "text/event-stream"}
resp = requests.get(session_url, stream=True, headers=headers, verify=verify)
client = SSEClient(resp) # type: ignore
event = Event()
def _job():
logging.debug(f"SSE client daemon started for session: {id_}")
for evt in client.events():
logging.debug(f"event received for session: {id_}")
if callback is not None:
callback(evt.data)
if event.is_set():
logging.debug(f"SSE client daemon stopped for session: {id_}")
break
t = Thread(target=_job, daemon=True)
t.start()
return SessionWatcher(id_, t, client, event)
def stop(self, force: bool = False):
"""
Send an event to stop the events stream and wait
for the thread to finish.
If you want to stop the stream savagely, you can use
the optional arg: `force`.
"""
self._event.set()
self._thread.join(timeout=0 if force else None)
logging.debug(f"SSE stream client for session: {self.session_id} closed")
@dataclass(slots=True)
class Client:
    """
    WhereIs API Client main class.

    Holds a `requests.Session` that carries a bearer access token and wraps
    the API endpoints (sessions, websocket tokens, GPS positions). Methods
    decorated with `@refresh()` retry once after refreshing the access
    token when the API answers 401.

    Example:
    --------
    cli = Client.from_env()
    sessions = cli.get_sessions()
    """

    base_url: str
    email: str
    # repr=False keeps the password out of the generated `__repr__`, so
    # printing or logging a client does not leak the credentials.
    password: str = field(repr=False)
    session: Session = field(init=False)
    sessions_watcher: dict[UUID, SessionWatcher] = field(
        default_factory=dict, init=False
    )

    def __post_init__(self) -> None:
        # Create the HTTP session here so that a client built directly with
        # `Client(...)` (not only through `from_env`) has a usable session.
        self.session = Session()
        self.session.headers.update({"content-type": "application/json"})

    @staticmethod
    def _check_response(url: str, res: requests.Response) -> None:
        """
        Raise the client exception matching an error response.

        Raises:
        -------
        UnauthorizedException: status 401 (caught by `@refresh()` to retry)
        WhereIsException: any other status >= 400
        """
        if res.status_code == 401:
            raise UnauthorizedException(url, res)
        if res.status_code >= 400:
            raise WhereIsException(url, res)

    def _set_access_token(self, access_token: str) -> None:
        """Send `access_token` as bearer `Authorization` on every request."""
        self.session.headers.update({"Authorization": f"Bearer {access_token}"})

    def _login(self) -> None:
        """
        Get an access token with the email/password credentials and store it
        in the `Session` header.

        Raises:
        -------
        WhereIsException: if the API answers with a status >= 400
        """
        data = {
            "email": self.email,
            "password": self.password,
        }
        login_url = urljoin(self.base_url, "/auth/token/")
        logging.info(f"login: {login_url}")
        res = self.session.post(login_url, json=data)
        if res.status_code >= 400:
            raise WhereIsException(login_url, res)
        self._set_access_token(res.json()["access"])

    def _refresh(self) -> None:
        """
        Refresh the access token and store it in the `Session` header.

        Raises:
        -------
        WhereIsException: if the API answers with a status >= 400
        """
        # NOTE(review): no refresh token is sent in the request body, so this
        # relies on the server identifying the caller some other way (e.g. a
        # cookie kept by `self.session`) — TODO confirm against the API.
        refresh_url = urljoin(self.base_url, "/auth/refresh/")
        logging.info(f"refresh: {refresh_url}")
        res = self.session.post(refresh_url)
        if res.status_code >= 400:
            raise WhereIsException(refresh_url, res)
        self._set_access_token(res.json()["access"])

    @classmethod
    def from_env(cls) -> "Client":
        """
        Initialize the client from environment variables and log in.

        Values are read from the process environment and from a `.env` file
        (via `dotenv_values()`); a value set in `.env` takes precedence.
        Recognized variables:
        - WHEREIS_API_EMAIL
        - WHEREIS_API_PASSWORD
        - WHEREIS_API_BASE_URL (defaults to `API_DEFAULT_URL`)
        - WHEREIS_CERT_VERIFY: "false" (case-insensitive) disables TLS
          certificate verification

        Raises:
        -------
        WhereIsException: if the login fails
        """
        env_data: dict[str, str] = {
            "WHEREIS_API_EMAIL": os.getenv("WHEREIS_API_EMAIL", ""),
            "WHEREIS_API_PASSWORD": os.getenv("WHEREIS_API_PASSWORD", ""),
            "WHEREIS_API_BASE_URL": os.getenv("WHEREIS_API_BASE_URL", API_DEFAULT_URL),
            # Previously only read from `.env`; the process env is honored too.
            "WHEREIS_CERT_VERIFY": os.getenv("WHEREIS_CERT_VERIFY", ""),
        }
        # Keys declared without a value in `.env` are parsed as None; skip
        # them instead of overwriting a real value with None.
        env_data.update(
            {key: value for key, value in dotenv_values().items() if value is not None}
        )
        cli = cls(
            env_data["WHEREIS_API_BASE_URL"],
            env_data["WHEREIS_API_EMAIL"],
            env_data["WHEREIS_API_PASSWORD"],
        )
        is_cert_verify = env_data["WHEREIS_CERT_VERIFY"].strip().lower() != "false"
        if not is_cert_verify:
            urllib3.disable_warnings()
        cli.session.verify = is_cert_verify
        cli._login()
        logging.info(f"client successfully initialized for user: {cli.email}")
        return cli

    @refresh()
    def _get_sessions_page(self, url: str) -> dict[str, Any]:
        """Get one page of the paginated sessions list from `url`."""
        res = self.session.get(url)
        self._check_response(url, res)
        return res.json()

    def get_sessions(
        self, search: str | None = None, ordering: OrderField | None = None
    ) -> list[dict[str, Any]]:
        """
        Get all user and public sessions, following the pagination.
        If an error occurs during an API call, an exception is raised.

        Params:
        -------
        search: str, search sessions over username and description
        ordering: OrderField, ordering sessions by dates
        """
        params: dict[str, str] = {}
        if search is not None:
            params["search"] = search
        if ordering is not None:
            params["ordering"] = ordering.value
        sessions_url = urljoin(self.base_url, "/sessions/")
        if params:
            # urlencode escapes spaces, "&", "#"... in the search text.
            sessions_url += f"?{urlencode(params)}"
        logging.info(f"get sessions: {sessions_url}")

        lst_sessions: list[dict[str, Any]] = []
        next_url: str | None = sessions_url
        while next_url is not None:
            data = self._get_sessions_page(next_url)
            next_url = data.get("next")
            lst_sessions.extend(data.get("results", []))
        return lst_sessions

    @refresh()
    def get_session(self, id_: UUID) -> dict[str, Any]:
        """Get the session `id_` (`GET /sessions/{id_}/`)."""
        session_url = urljoin(self.base_url, f"/sessions/{id_}/")
        logging.info(f"get session: {session_url}")
        res = self.session.get(session_url)
        self._check_response(session_url, res)
        return res.json()

    @refresh()
    def create_session(
        self, name: str, description: str | None = None, is_public: bool = False
    ) -> dict[str, Any]:
        """
        Create a new session (`POST /sessions/`).
        Arguments left to None are not sent in the payload.
        """
        sessions_url = urljoin(self.base_url, "/sessions/")
        logging.info(f"create session: {sessions_url}")
        data = {"name": name, "description": description, "is_public": is_public}
        data = {key: value for key, value in data.items() if value is not None}
        res = self.session.post(sessions_url, json=data)
        self._check_response(sessions_url, res)
        return res.json()

    @refresh()
    def update_session(
        self,
        id_: UUID,
        name: str | None = None,
        description: str | None = None,
        is_public: bool | None = None,
    ) -> dict[str, Any]:
        """
        Partially update the session `id_` (`PATCH /sessions/{id_}/`).
        Only the arguments that are not None are sent.
        """
        session_url = urljoin(self.base_url, f"/sessions/{id_}/")
        logging.info(f"update session: {session_url}")
        data = {"name": name, "description": description, "is_public": is_public}
        data = {key: value for key, value in data.items() if value is not None}
        res = self.session.patch(session_url, json=data)
        self._check_response(session_url, res)
        return res.json()

    @refresh()
    def delete_session(self, id_: UUID) -> None:
        """
        Close and delete the session. Users can't be added anymore.
        NOTE: The GPS positions associated to the session are not deleted!
        """
        session_url = urljoin(self.base_url, f"/sessions/{id_}/")
        logging.info(f"delete session: {session_url}")
        res = self.session.delete(session_url)
        self._check_response(session_url, res)
        return None

    @refresh()
    def update_session_users(self, id_: UUID, users: list[UUID]) -> dict[str, Any]:
        """
        Update the users of session `id_`.
        WARN: An empty users list parameter removes all users.
        """
        session_url = urljoin(self.base_url, f"/sessions/{id_}/users/")
        logging.info(f"update session users: {session_url}")
        # UUID objects are not JSON serializable; send their string form.
        payload = {"users": [str(user) for user in users]}
        res = self.session.post(session_url, json=payload)
        self._check_response(session_url, res)
        return res.json()

    @refresh()
    def close_session(self, id_: UUID) -> dict[str, Any]:
        """Close the session `id_` (`POST /sessions/{id_}/close/`)."""
        session_url = urljoin(self.base_url, f"/sessions/{id_}/close/")
        logging.info(f"close session: {session_url}")
        res = self.session.post(session_url)
        self._check_response(session_url, res)
        return res.json()

    def watch_session_events(
        self, id_: UUID, callback: Callable[[str], None] | None = None
    ) -> None:
        """
        Watch session events through an SSE client.
        It launches a daemon thread listening for incoming events.
        You can use the `callback` optional argument to pass a callable
        to deal with the events. Watching an already watched session only
        logs a warning.

        Example:
        --------
        def treat_events(evt: str):
            # your instructions
            print(evt)

        cli.watch_session_events("session-id", treat_events)

        NOTE: You have to manually manage connection errors (IO, Authentication, etc...)
        For an authentication error, you'll receive this kind of event:
        {
            "condition": "forbidden",
            "text": "Permission denied to channels: session_session-id",
            "channels": ["session_session-id"]
        }
        """  # noqa
        if self.sessions_watcher.get(id_) is not None:
            logging.warning(f"you're already watching session events: {id_}")
            return
        sw = SessionWatcher.from_session_id(
            self.base_url,
            id_,
            self.session.headers,  # type: ignore
            callback,
            self.session.verify,  # type: ignore
        )
        logging.info(f"session events (id: {id_}) watcher started")
        self.sessions_watcher[id_] = sw

    def stop_watch_session(self, id_: UUID, force: bool = False) -> None:
        """
        Stop watching events for a session (no-op if it is not watched).
        Use the `force` optional argument to return without waiting
        for the watcher thread to finish.
        """
        if (sw := self.sessions_watcher.get(id_)) is not None:
            sw.stop(force)
            del self.sessions_watcher[id_]
            logging.info(f"session events (id: {id_}) watcher stopped")

    @refresh()
    def get_wstokens(self) -> list[dict[str, Any]]:
        """List the websocket tokens (`GET /auth/ws-token/`)."""
        wstoken_url = urljoin(self.base_url, "/auth/ws-token/")
        logging.info(f"get ws token: {wstoken_url}")
        res = self.session.get(wstoken_url)
        self._check_response(wstoken_url, res)
        return res.json()

    @refresh()
    def create_wstoken(self) -> dict[str, Any]:
        """
        Create a websocket JWT to authenticate your real time connection.
        NOTE: only one ws token per user is allowed.
        If it expired, delete it and create a new one.
        """
        wstoken_url = urljoin(self.base_url, "/auth/ws-token/")
        logging.info(f"create ws token: {wstoken_url}")
        res = self.session.post(wstoken_url)
        self._check_response(wstoken_url, res)
        return res.json()

    @refresh()
    def delete_wstoken(self, id_: UUID) -> None:
        """Delete the websocket token `id_`."""
        wstoken_url = urljoin(self.base_url, f"/auth/ws-token/{id_}/")
        logging.info(f"delete ws token: {wstoken_url}")
        res = self.session.delete(wstoken_url)
        self._check_response(wstoken_url, res)
        return None

    @refresh()
    def _get_paginate_gps_postions(self, url: str) -> dict[str, Any]:
        """Get one page of GPS positions from `url`."""
        res = self.session.get(url)
        self._check_response(url, res)
        return res.json()

    def get_gps_positions(
        self,
        date_start: str | None = None,
        date_end: str | None = None,
    ) -> list[dict[str, Any]]:
        """
        Get GPS positions data, following the pagination.
        You can get GPS positions filtered by date interval using the
        optional arguments `date_start` and `date_end`:
            - `date_start`: [date_start, ...[
            - `date_end`: ]..., date_end]
            - `date_start` & `date_end`: [date_start, date_end]
        Empty strings are ignored like None.

        Raises:
        -------
        ValueError: a date is not in a format accepted by
            `datetime.fromisoformat`
        WhereIsException: an API call failed
        """
        # Parse (and validate) both bounds before sending any request.
        params: dict[str, str] = {}
        if date_start:
            params["date_start"] = dt.fromisoformat(date_start).isoformat()
        if date_end:
            params["date_end"] = dt.fromisoformat(date_end).isoformat()

        gps_url = urljoin(self.base_url, "/gps/positions/")
        if params:
            # urlencode escapes the "+" of UTC offsets (e.g. "+01:00"), which
            # query-string decoding would otherwise turn into a space.
            gps_url += f"?{urlencode(params)}"

        lst_gps_positions: list[dict[str, Any]] = []
        next_url: str | None = gps_url
        while next_url is not None:
            logging.info(f"get gps data from: {next_url}")
            data = self._get_paginate_gps_postions(next_url)
            lst_gps_positions.extend(data["results"])
            next_url = data["next"]
        return lst_gps_positions