"""Download the Our World in Data COVID-19 dataset and plot variables for a country.

Usage (command line): ``<script> COUNTRY VARIABLE [VARIABLE ...] [-r]``
"""
import argparse
import json
import logging
import os
from datetime import datetime as dt
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import requests
from cycler import cycler  # noqa: F401  (not used in this file; kept as-is)
# NOTE(review): `var` is not used in this file and comes from a non-public
# numpy module path; it looks like an accidental auto-import. Kept so nothing
# that might rely on it breaks -- confirm and remove.
from numpy.core.fromnumeric import var  # noqa: F401

FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=FORMAT, level=logging.INFO)

DATA_URL = "https://covid.ourworldindata.org/data/owid-covid-data.{extension}"
DATA_REPOSITORY = "data"
OUTPUT_REPOSITORY = "output"
# Seconds to wait for the server (connect / between received bytes) before
# giving up, so a stalled connection cannot hang the script forever.
REQUEST_TIMEOUT = 30


class DataProvider:
    """Load the dataset (from a local cache or the network) and plot it.

    src data : https://github.com/owid/covid-19-data/tree/master/public/data
    """

    def __init__(self, country_key: Optional[str] = "FRA", refresh=False):
        """Validate the country key and load the dataset.

        Args:
            country_key: three-letter country key (e.g. ``"FRA"``); it is
                upper-cased before being used as a lookup key.
            refresh: when true, download the data even if a cached file exists.

        Raises:
            ValueError: if ``country_key`` is not a three-character string.
        """
        self._dic_data = None
        # The isinstance check turns `None` / non-string input into the same
        # explicit ValueError instead of an opaque TypeError from len().
        if not isinstance(country_key, str) or len(country_key) != 3:
            raise ValueError(
                "country key provided : '{}' is invalid. It must be a trigram like : 'FRA'".format(
                    country_key
                )
            )
        self._country_key = country_key.upper()
        self._get_data(refresh=refresh)

    def _get_var_by_date(
        self,
        var_name: str,
    ) -> Tuple[List[dt], np.ndarray]:
        """Return the dates and the values of ``var_name`` for the country.

        Args:
            var_name: key to read from each record of the country's "data" list.

        Returns:
            A tuple ``(dates, values)``: the parsed "date" of every record and
            a float64 array of the same length, with NaN where the record has
            no value for ``var_name``.

        Raises:
            IndexError: if the country key is not present in the loaded data.
        """
        if self._country_key not in self._dic_data:
            logging.info("country keys : {}".format(self._dic_data.keys()))
            raise IndexError(
                "'{}' country trigram key does not exist in the data".format(
                    self._country_key
                )
            )

        records = self._dic_data[self._country_key]["data"]
        lst_date = [dt.strptime(record["date"], "%Y-%m-%d") for record in records]
        lst_var = [record.get(var_name) for record in records]

        # Compare against None explicitly: a series made only of zeros is real
        # data and must not be reported as missing.
        if all(value is None for value in lst_var):
            logging.warning("no data for variable : '{}'".format(var_name))

        # None entries become NaN in the float64 array.
        return lst_date, np.asarray(lst_var, np.float64)

    def _get_data(
        self,
        file_path: Optional[str] = None,
        extension: Optional[str] = "json",
        refresh=False,
    ) -> None:
        """Load the dataset into ``self._dic_data``.

        The data is read from ``file_path`` when that file exists and
        ``refresh`` is false; otherwise it is downloaded and then cached there.

        Args:
            file_path: cache file location. Defaults to the last component of
                the data URL inside ``DATA_REPOSITORY``.
            extension: file extension substituted into ``DATA_URL``. The
                payload is always parsed as JSON.
            refresh: when true, download even if the cache file exists.

        Raises:
            ValueError: if the server returns an empty body, or a body that is
                not valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
            requests.RequestException: on network failure, timeout or an HTTP
                error status.
        """
        os.makedirs(DATA_REPOSITORY, exist_ok=True)

        data_url = DATA_URL.format(extension=extension)
        if data_url.endswith("/"):
            data_url = data_url[:-1]
        if file_path is None:
            file_path = os.path.join(DATA_REPOSITORY, data_url.split("/")[-1])

        if refresh or not os.path.isfile(file_path):
            r = requests.get(data_url, timeout=REQUEST_TIMEOUT)
            # Fail on 4xx/5xx instead of treating an error page as data.
            r.raise_for_status()
            if not r.content:
                raise ValueError("no data provided froim the url : {}".format(data_url))
            # Parse before writing so an invalid payload never replaces or
            # creates the cache file.
            dic_data = json.loads(r.content)
            with open(file_path, "wb") as f:
                f.write(r.content)
            self._dic_data = dic_data
            return

        with open(file_path, "rb") as f:
            self._dic_data = json.load(f)

    def plot(self, lst_var_name: List[str], output_path: Optional[str] = None) -> None:
        """Plot each variable against the date and save the figure.

        The first variable is drawn on the main y axis; every other variable
        gets its own twin y axis on the right, each one shifted further out
        than the previous so the scales do not overlap.

        Args:
            lst_var_name: names of the variables to plot. When empty, an error
                is logged and nothing is drawn.
            output_path: image file to write. Defaults to
                ``<OUTPUT_REPOSITORY>/<country>-<YYYYMMDD>.png``.
        """
        if not lst_var_name:
            logging.error("no var name list provided to fetch data")
            return

        os.makedirs(OUTPUT_REPOSITORY, exist_ok=True)
        if output_path is None:
            output_path = os.path.join(
                OUTPUT_REPOSITORY,
                "{}-{}.png".format(
                    self._country_key.lower(), dt.now().strftime("%Y%m%d")
                ),
            )

        nb_var = len(lst_var_name)
        # Widen the figure when extra right-hand axes have to fit.
        fig_width = 10 if nb_var < 3 else 10 + 1.2 * nb_var
        fig, ax = plt.subplots(figsize=(fig_width, 7))
        try:
            # Leave room on the right for the additional y axes.
            fig.subplots_adjust(right=0.7)
            colors = plt.rcParams["axes.prop_cycle"]()

            for idx, var_name in enumerate(lst_var_name):
                lst_date, np_var = self._get_var_by_date(var_name)
                color = next(colors)["color"]

                cur_ax = ax if idx == 0 else ax.twinx()
                if idx > 1:
                    # From the third variable on, push the spine outwards so
                    # it does not sit on top of the previous twin axis.
                    shift = 1.15 + (0.15 * (idx - 2))
                    cur_ax.spines["right"].set_position(("axes", shift))

                cur_ax.plot(
                    lst_date,
                    np_var,
                    ".-",
                    label="{}".format(var_name),
                    color=color,
                    markersize=4,
                )
                # Colour the axis like its curve so each scale is identifiable.
                cur_ax.yaxis.label.set_color(color)
                cur_ax.tick_params(axis="y", colors=color)
                cur_ax.set_ylabel(var_name)

            ax.tick_params(axis="x")
            ax.set_xlabel("date")
            ax.grid(True)
            fig.autofmt_xdate()
            ax.set_title(self._country_key)
            fig.savefig(output_path)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly.
            plt.close(fig)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("country", type=str, help="country's trigram")
    parser.add_argument("variables", nargs="+", help="variable list to fetch data")
    parser.add_argument(
        "-r",
        "--refresh",
        action="store_true",
        default=False,
        help="redownload data for updates",
    )
    args = parser.parse_args()

    data_provider = DataProvider(args.country, refresh=args.refresh)
    data_provider.plot(args.variables)