160 lines
5.3 KiB
Python

import argparse
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
from numpy.core.fromnumeric import var
import requests
from cycler import cycler
from datetime import datetime as dt
from typing import List, Optional, Tuple
# Layout of every log line: timestamp, severity, then the message.
FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
# Configures the root logger at import time so INFO-level messages are emitted.
logging.basicConfig(format=FORMAT, level=logging.INFO)
# URL template of the dataset; "{extension}" is substituted in DataProvider._get_data.
DATA_URL = "https://covid.ourworldindata.org/data/owid-covid-data.{extension}"
# Local directory in which the downloaded dataset is cached.
DATA_REPOSITORY = "data"
# Directory that receives the plot image when DataProvider.plot gets no output path.
OUTPUT_REPOSITORY = "static/plots"
class DataProvider:
    """Load the Our World in Data COVID-19 dataset and plot variables for one country.

    Source data: https://github.com/owid/covid-19-data/tree/master/public/data

    The dataset is downloaded into ``DATA_REPOSITORY`` and that cached copy is
    reused on later runs unless a refresh is requested.
    """

    def __init__(self, country_key: Optional[str] = "FRA", refresh=False):
        """Validate the country code and load the dataset.

        Args:
            country_key: Three-character country code such as ``"FRA"``; it is
                upper-cased before use.
            refresh: When true, download the dataset again even if a cached
                copy already exists.

        Raises:
            ValueError: If ``country_key`` is not a three-character string, or
                if the download yields no usable data.
        """
        self._dic_data = None
        # The isinstance check keeps None (allowed by the type hint) from
        # escaping as a TypeError out of len().
        if not isinstance(country_key, str) or len(country_key) != 3:
            raise ValueError(
                "country key provided : '{}' is invalid. It must be a trigram like : 'FRA'".format(
                    country_key
                )
            )
        self._country_key = country_key.upper()
        self._get_data(refresh=refresh)

    def _get_var_by_date(
        self,
        var_name: str,
    ) -> Tuple[List[dt], np.ndarray]:
        """Return the dates and the values of one variable for the configured country.

        Args:
            var_name: Key to read from each dated record of the country.

        Returns:
            A pair ``(dates, values)`` of equal length. ``values`` is a float64
            array in which records lacking ``var_name`` show up as NaN.

        Raises:
            IndexError: If the country code is absent from the loaded data.
        """
        if self._country_key not in self._dic_data:
            logging.info("country keys : {}".format(self._dic_data.keys()))
            raise IndexError(
                "'{}' country trigram key does not exist in the data".format(
                    self._country_key
                )
            )
        records = self._dic_data[self._country_key]["data"]
        lst_date = [dt.strptime(record["date"], "%Y-%m-%d") for record in records]
        lst_var = [record.get(var_name) for record in records]
        # Compare against None explicitly: a truthiness test would also flag a
        # series that legitimately contains only zeros as "no data".
        if all(value is None for value in lst_var):
            logging.warning("no data for variable : '{}'".format(var_name))
        return lst_date, np.asarray(lst_var, np.float64)

    def _get_data(
        self,
        file_path: Optional[str] = None,
        extension: Optional[str] = "json",
        refresh=False,
    ) -> None:
        """Populate ``self._dic_data`` from the local cache or from the remote URL.

        Args:
            file_path: Cache file to read/write. Defaults to the last segment of
                the data URL inside ``DATA_REPOSITORY``.
            extension: Value substituted into ``DATA_URL``; the payload is
                always parsed as JSON.
            refresh: When true, download even if the cache file exists.

        Raises:
            ValueError: If the server answers with an error status or an empty
                body, or if the body is not valid JSON.
        """
        os.makedirs(DATA_REPOSITORY, exist_ok=True)
        data_url = DATA_URL.format(extension=extension)
        if data_url.endswith("/"):
            data_url = data_url[:-1]
        if file_path is None:
            file_path = os.path.join(DATA_REPOSITORY, data_url.split("/")[-1])

        if refresh or not os.path.isfile(file_path):
            response = requests.get(data_url)
            # An HTTP error page is not data; refuse it like an empty body.
            if not response.ok or not response.content:
                raise ValueError("no data provided froim the url : {}".format(data_url))
            # Parse before writing so an unparsable response never ends up in
            # the cache, where it would break every later non-refresh run.
            payload = json.loads(response.content)
            with open(file_path, "wb") as cache_file:
                cache_file.write(response.content)
            self._dic_data = payload
            return

        # Context manager so the cache file handle is closed deterministically.
        with open(file_path, "rb") as cache_file:
            self._dic_data = json.load(cache_file)

    def plot(self, lst_var_name: List[str], output_path: Optional[str] = None) -> None:
        """Plot the requested variables against time and save the figure.

        The first variable is drawn on the main y-axis; each further variable
        gets its own twin y-axis, coloured like its curve.

        Args:
            lst_var_name: Names of the variables to plot. An empty list logs an
                error and returns without producing a file.
            output_path: Destination image path. Defaults to
                ``<country>-<YYYYMMDD>.png`` inside ``OUTPUT_REPOSITORY``.
        """
        if not lst_var_name:
            logging.error("no var name list provided to fetch data")
            return
        os.makedirs(OUTPUT_REPOSITORY, exist_ok=True)
        if output_path is None:
            output_path = os.path.join(
                OUTPUT_REPOSITORY,
                "{}-{}.png".format(
                    self._country_key.lower(), dt.strftime(dt.now(), "%Y%m%d")
                ),
            )
        else:
            # A caller-supplied path may point into a directory that does not
            # exist yet; create it instead of failing at save time.
            parent_dir = os.path.dirname(output_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

        # Widen the figure when extra y-axes have to fit on the right side.
        nb_vars = len(lst_var_name)
        width = 10 if nb_vars < 3 else 10 + 1.2 * nb_vars
        fig, ax = plt.subplots(figsize=(width, 7))
        fig.subplots_adjust(right=0.7)
        colors = plt.rcParams["axes.prop_cycle"]()

        for idx, var_name in enumerate(lst_var_name):
            lst_date, np_var = self._get_var_by_date(var_name)
            color = next(colors)["color"]
            target = ax if idx == 0 else ax.twinx()
            if idx > 1:
                # From the third variable on, push each extra right-hand spine
                # further out so the y-axes do not overlap.
                shift = 1.15 + 0.15 * (idx - 2)
                # Item access works on matplotlib versions that lack the
                # attribute form ``spines.right``.
                target.spines["right"].set_position(("axes", shift))
            target.plot(
                lst_date,
                np_var,
                ".-",
                label=str(var_name),
                color=color,
                markersize=4,
            )
            target.yaxis.label.set_color(color)
            target.tick_params(axis="y", colors=color)
            target.set_ylabel(var_name)

        ax.tick_params(axis="x")
        ax.set_xlabel("date")
        ax.grid(True)
        fig.autofmt_xdate()
        ax.set_title(self._country_key)
        fig.savefig(output_path)
if __name__ == "__main__":
    # Command line: <country> <variable> [<variable> ...] [-r/--refresh]
    cli = argparse.ArgumentParser()
    cli.add_argument("country", type=str, help="country's trigram")
    cli.add_argument("variables", nargs="+", help="variable list to fetch data")
    cli.add_argument(
        "-r",
        "--refresh",
        default=False,
        action="store_true",
        help="redownload data for updates",
    )
    options = cli.parse_args()

    # Load (or refresh) the dataset for the requested country, then plot
    # every variable named on the command line.
    provider = DataProvider(options.country, refresh=options.refresh)
    provider.plot(options.variables)