diff --git a/main.py b/main.py index 60465d4..17e67d3 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ import logging import sys import json +import time from src import VERSION, Client, OrderField from src.exceptions import WhereIsException, UnauthorizedException @@ -33,6 +34,10 @@ if __name__ == "__main__": - get_session - update_session - delete_session + - update_session_users + - close_session + - watch_session_events + - stop_watch_session """ logging.info(f"WhereIs client v{VERSION}") @@ -77,10 +82,22 @@ if __name__ == "__main__": # close a session try: - session = cli.close_session("fqsfsdqf") + session = cli.close_session("does-not-exist") except WhereIsException as e: logging.error(f"error occurred while updating users session, status code: {e.error_code}") print(json.dumps(e.content, indent=2)) + # get session events + cli.watch_session_events(session.get("id")) + + # doing your stuff... + time.sleep(5) + + # close the session + session = cli.close_session(session.get("id")) + # delete a session cli.delete_session(session.get("id")) + + # stop session events watcher + cli.stop_watch_session(session.get("id")) diff --git a/requirements.txt b/requirements.txt index 1786051..b67a478 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ requests==2.32.3 -python-dotenv==1.0.1 \ No newline at end of file +python-dotenv==1.0.1 +sseclient-py==1.8.0 \ No newline at end of file diff --git a/src/client.py b/src/client.py index b4ca923..4f11b89 100644 --- a/src/client.py +++ b/src/client.py @@ -1,15 +1,17 @@ -import json import logging import os from dataclasses import dataclass, field from enum import Enum -from typing import Any +from threading import Event, Thread +from typing import Any, Callable from urllib.parse import urljoin from uuid import UUID +import requests import urllib3 from dotenv import dotenv_values from requests import Session +from sseclient import SSEClient from .exceptions import UnauthorizedException, WhereIsException @@ -55,6 +57,66 @@ class OrderField(Enum): DesDateEnd = "-date_end" +@dataclass(frozen=True, slots=True) +class SessionWatcher: + """ + Handle the SSE client connection in a daemon thread. + + NOTE: Should not be instanciated directly, use + the `Client` instead. + """ + + session_id: UUID + _thread: Thread + _client: SSEClient + _event: Event + + @classmethod + def from_session_id( + cls, + base_url: str, + id_: UUID, + headers: dict[str, Any], + callback: Callable[[str], None] | None = None, + verify: bool = True, + ) -> "SessionWatcher": + session_url = urljoin(base_url, f"/sessions/{id_}/events/") + + headers = {**headers, "Accept": "text/event-stream"} + resp = requests.get(session_url, stream=True, headers=headers, verify=verify) + client = SSEClient(resp) # type: ignore + event = Event() + + def _job(): + logging.debug(f"SSE client daemon started for session: {id_}") + + for evt in client.events(): + logging.debug(f"event received for session: {id_}") + if callback is not None: + callback(evt.data) + + if event.is_set(): + logging.debug(f"SSE client daemon stopped for session: {id_}") + break + + t = Thread(target=_job, daemon=True) + t.start() + + return SessionWatcher(id_, t, client, event) + + def stop(self, force: bool = False): + """ + Send an event to stop the events stream and wait + for the thread to finish. + + If you want to stop the stream savagely, you can use + the optional arg: `force`. + """ + self._event.set() + self._thread.join(timeout=0 if force else None) + logging.debug(f"SSE stream client for session: {self.session_id} closed") + + @dataclass(slots=True) class Client: """ @@ -72,6 +134,10 @@ class Client: session: Session = field(init=False) + sessions_watcher: dict[UUID, SessionWatcher] = field( + default_factory=dict, init=False + ) + def _login(self) -> WhereIsException | None: """Get the access token and store it in the `Session` header""" data = { @@ -137,7 +203,7 @@ class Client: cli._login() - logging.info(f"client successfully initialized for session: {cli.email}") + logging.info(f"client successfully initialized for user: {cli.email}") return cli @refresh() @@ -148,7 +214,6 @@ class Client: raise UnauthorizedException() if res.status_code >= 400: - print(json.dumps(res.json(), indent=2)) if res.status_code >= 400: raise WhereIsException(url, res) @@ -319,3 +384,56 @@ class Client: raise WhereIsException(session_url, res) return res.json() + + def watch_session_events( + self, id_: UUID, callback: Callable[[str], None] | None = None + ): + """ + Watch session events through an SSE client. + + It will launch a daemon thread, listening for incoming events. + You can use the `callback` optional argument to pass a callable + to deal with the events. + + Example: + -------- + def treat_events(evt: str): + # your instructions + print(evt) + + cli.watch_session_events("session-id", treat_events) + + NOTE: You have to manually manage the connection error (IO, Authentication, etc...) + For authentication error, you'll receive this kind of event: + { + "condition": "forbidden", + "text": "Permission denied to channels: session_session-id", + "channels": ["session_session-id"] + } + """ # noqa + if self.sessions_watcher.get(id_) is not None: + logging.warning(f"you're already watching session events: {id_}") + return + + sw = SessionWatcher.from_session_id( + self.base_url, + id_, + self.session.headers, # type: ignore + callback, + self.session.verify, # type: ignore + ) + + logging.info(f"session events (id: {id_}) watcher started") + self.sessions_watcher[id_] = sw + + def stop_watch_session(self, id_: UUID, force: bool = False): + """ + Stop watching events for a session. + + Use `force` optional argument to kill the watcher + instead of waiting for a graceful stop. + """ + if (sw := self.sessions_watcher.get(id_)) is not None: + sw.stop(force) + del self.sessions_watcher[id_] + logging.info(f"session events (id: {id_}) watcher stopped")