# covid-plotter/drees.py
#
# 357 lines
# 11 KiB
# Python

import argparse
import json
import logging
import os
from datetime import datetime as dt
from enum import Enum
from functools import partial
from multiprocessing import Pool
from typing import Any, Dict, List, Optional, OrderedDict, Tuple
import numpy as np
import pandas as pd
import requests
from matplotlib import dates as md
from matplotlib import pyplot as plt
from numba import njit
# Layout of every log record: timestamp, severity level, then the message.
FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
# Configures the root logger at import time; INFO and above are emitted.
logging.basicConfig(format=FORMAT, level=logging.INFO)
# strptime/strftime pattern used to parse record dates, to label the date axis
# and to build the date keys of the bar-plot table.
DATE_FORMAT = "%Y-%m-%d"
# Search endpoint of the DREES open-data API for the "results by age" dataset.
# NOTE(review): get_data() calls .format(extension=...) on this string, but it
# contains no "{extension}" placeholder, so that call leaves it unchanged.
# NOTE(review): "rows=-1" presumably asks for every row -- confirm against the
# API documentation.
DATA_URL = "https://data.drees.solidarites-sante.gouv.fr/api/records/1.0/search/?dataset=covid-19-resultats-par-age-issus-des-appariements-entre-si-vic-si-dep-et-vac-si&q=&rows=-1&facet=date&facet=vac_statut&facet=age"
# Directory created by get_data() to hold the downloaded/cached payload.
DATA_REPOSITORY = "data"
# Directory in which the plotting functions write their PDF files.
OUTPUT_REPOSITORY = "output"
class Field(str, Enum):
    """Indicator fields read from each DREES record."""

    # Each value is the key looked up in a record's "fields" mapping.
    # NOTE(review): the script description at the bottom of the file lists
    # hospitalisations, criticals and deaths as the main indicators;
    # presumably hc / sc / dc map to them in that order -- confirm against the
    # dataset documentation.
    HC = "hc"
    SC = "sc"
    DC = "dc"
class VacStatus(str, Enum):
    """Vaccination status labels, compared against the "vac_statut" record field."""

    # The values are runtime strings matched verbatim against the data, so
    # they must stay in French exactly as written.
    # Unvaccinated; the bar-plot grouping counts this status separately from
    # all the others.
    NC = "Non-vaccinés"
    # First dose, recent.
    PDR = "Primo dose récente"
    # First dose, effective.
    PDE = "Primo dose efficace"
    # Complete for less than 3 months, without booster.
    CM3MSR = "Complet de moins de 3 mois - sans rappel"
    # Complete for less than 3 months, with booster.
    CM3MAR = "Complet de moins de 3 mois - avec rappel"
    # Complete between 3 and 6 months, without booster.
    CM36MSR = "Complet entre 3 mois et 6 mois - sans rappel"
    # Complete between 3 and 6 months, with booster.
    CM36MAR = "Complet entre 3 mois et 6 mois - avec rappel"
class AgeGroup(str, Enum):
    """Age bracket labels, compared against the "age" record field."""

    # The values are matched verbatim against the data; note that the last
    # bracket uses a semicolon ("[80;+]") while the others use a comma.
    VERY_YOUNG = "[0,19]"
    YOUNG = "[20,39]"
    MID_OLD = "[40,59]"
    OLD = "[60,79]"
    VERY_OLD = "[80;+]"
def get_data(
    file_path: Optional[str] = None,
    extension: Optional[str] = "json",
    refresh=False,
) -> Dict[str, Any]:
    """
    Collect COVID data by age from DREES, using a local file as a cache.

    The payload is downloaded when the cache file does not exist or when
    ``refresh`` is true; otherwise the cached file is parsed and returned.

    Args:
        file_path: location of the cache file. When None, a path inside
            DATA_REPOSITORY named after the last URL segment is used.
        extension: value passed to ``DATA_URL.format(extension=...)``.
        refresh: force a new download even if the cache file exists.

    Returns:
        The decoded JSON payload.

    Raises:
        requests.HTTPError: the server answered with an error status.
        ValueError: the response body is empty, or is not valid JSON
            (json.JSONDecodeError is a ValueError subclass).
    """
    os.makedirs(DATA_REPOSITORY, exist_ok=True)
    data_url = DATA_URL.format(extension=extension)
    if data_url.endswith("/"):
        data_url = data_url[:-1]
    if file_path is None:
        file_path = os.path.join(DATA_REPOSITORY, data_url.split("/")[-1])

    if refresh or not os.path.isfile(file_path):
        r = requests.get(data_url)
        # Fail loudly on 4xx/5xx instead of caching an error page.
        r.raise_for_status()
        if not r.content:
            raise ValueError("no data provided froim the url : {}".format(data_url))
        # Decode before writing so an invalid payload never reaches the cache
        # (a poisoned cache would be reused by every later run).
        data = json.loads(r.content)
        with open(file_path, "wb") as f:
            f.write(r.content)
        return data

    # Context manager guarantees the cache file handle is closed.
    with open(file_path, "rb") as f:
        return json.load(f)
def group_by_age_date(data: Dict[str, Any], fields: List[str]) -> Dict[dt, Any]:
    """
    Group the raw API payload into nested ordered dictionaries.

    The result has the shape::

        date (datetime): {
            age (str): {
                vac_status (str): {
                    'hc': ...,
                    'sc': ...,
                    'dc': ...,
                }
            }
        }

    Args:
        data: decoded payload; its "records" entry is iterated and each
            record's "fields" mapping must contain "date", "age" and
            "vac_statut".
        fields: names of the indicator fields to copy from each record.

    Returns:
        The grouped dictionary, in first-seen order at every level. A field
        missing from a record is stored as None rather than raising, which
        is the value the readers of this structure already test for.

    Raises:
        KeyError: "records", "fields", "date", "age" or "vac_statut" is missing.
        ValueError: a date does not match DATE_FORMAT.
    """
    dic_data_grouped: Dict[dt, Any] = OrderedDict()
    for row in data["records"]:
        row_fields = row["fields"]
        date = dt.strptime(row_fields["date"], DATE_FORMAT)
        age = row_fields["age"]
        vac_status = row_fields["vac_statut"]
        # setdefault creates each nesting level the first time it is seen.
        by_age = dic_data_grouped.setdefault(date, OrderedDict())
        by_vac_status = by_age.setdefault(age, OrderedDict())
        values = by_vac_status.setdefault(vac_status, OrderedDict())
        for field in fields:
            # .get(): one incomplete record must not abort the whole grouping.
            values[field] = row_fields.get(field)
    return dic_data_grouped
def cumulate_array(array: np.ndarray) -> np.ndarray:
    """
    Return the running (cumulative) sum of a one-dimensional array.

    Element ``i`` of the result is the sum of ``array[0]`` .. ``array[i]``;
    an empty input gives an empty output.

    NumPy's native ``cumsum`` replaces the former hand-written loop, which
    shadowed the builtin ``sum`` and needed JIT compilation to be fast.
    """
    return np.cumsum(array)
def get_plot_fig(
    grid: Optional[bool] = True, date_format: Optional[str] = DATE_FORMAT
) -> plt.figure:
    """
    Create a figure whose x axis is set up to display dates.

    A new figure with a single axes is created through pyplot; the axes gets
    the requested grid setting, an automatic date locator and a formatter
    built from ``date_format``, and the tick labels are laid out with
    ``autofmt_xdate``. Only the figure is returned.
    """
    figure, axes = plt.subplots()
    axes.grid(grid)
    x_axis = axes.xaxis
    x_axis.set_major_locator(md.AutoDateLocator())
    x_axis.set_major_formatter(md.DateFormatter(date_format))
    figure.autofmt_xdate()
    return figure
def save_and_close_fig(
    fig: plt.figure, output_path: str, has_legend: Optional[bool] = True
):
    """
    Optionally add a legend to ``fig``, write it to ``output_path``, close it.

    Args:
        fig: the figure to save.
        output_path: destination file, passed to ``Figure.savefig``.
        has_legend: when true, a legend is drawn on the figure's current axes.

    Everything is done on the figure that was passed in rather than on
    pyplot's implicit "current figure", so the right figure is saved even if
    another one became current in the meantime.
    """
    if has_legend:
        # Same axes plt.legend() would target when fig is the current figure.
        fig.gca().legend()
    fig.savefig(output_path)
    plt.close(fig)
def get_cumulative_field_by_age(
    dic_data_grouped: Dict[dt, Any], age: str, field: Field
) -> Tuple[np.ndarray, List[dt]]:
    """
    Accumulate one indicator for one age group over the whole data period.

    Every non-None value found for ``age`` is collected, across all
    vaccination statuses and in iteration order, together with the date it
    belongs to. Dates that have no entry for ``age`` are reported through
    ``logging.error`` and skipped.

    Returns:
        ``(cumulative_values, dates)`` where ``cumulative_values`` is the
        running sum of the collected values and ``dates`` has one entry per
        collected value.
    """
    collected: List[int] = []
    collected_dates: List[dt] = []
    for day, by_age in dic_data_grouped.items():
        by_vac_status = by_age.get(age)
        if by_vac_status is None:
            logging.error(f"{age} not found in grouped ages")
            continue
        for indicators in by_vac_status.values():
            value = indicators[field.value]
            if value is None:
                continue
            collected.append(value)
            collected_dates.append(day)
    return cumulate_array(np.array(collected)), collected_dates
def get_values_by_age_vac_field(
    dic_data_grouped: Dict[dt, Any], age: AgeGroup, vac_status: VacStatus, field: Field
) -> Tuple[List[dt], List[float]]:
    """
    Extract one indicator for a given age group and vaccination status.

    A date contributes only when the age group, the vaccination status and a
    non-None value for the field are all present.

    Returns:
        ``(dates, values)``, two lists of equal length in iteration order.
    """
    kept_dates: List[dt] = []
    kept_values: List[float] = []
    for day, by_age in dic_data_grouped.items():
        by_vac_status = by_age.get(age.value)
        if by_vac_status is None:
            continue
        indicators = by_vac_status.get(vac_status.value)
        if indicators is None:
            continue
        value = indicators.get(field.value)
        if value is None:
            continue
        kept_values.append(value)
        kept_dates.append(day)
    return kept_dates, kept_values
def get_values_by_age_vac(
    dic_data_grouped: Dict[dt, Any], age: AgeGroup, vac_status: VacStatus
) -> Tuple[List[dt], List[Dict[str, Any]]]:
    """
    Extract the indicator mappings for a given age group and vaccination status.

    A date contributes only when both the age group and the vaccination
    status are present for it.

    Returns:
        ``(dates, mappings)``, two lists of equal length in iteration order;
        each mapping is the stored per-status dictionary itself (not a copy).
    """
    kept_dates: List[dt] = []
    kept_mappings: List[Dict[str, Any]] = []
    for day, by_age in dic_data_grouped.items():
        by_vac_status = by_age.get(age.value)
        if by_vac_status is None:
            continue
        indicators = by_vac_status.get(vac_status.value)
        if indicators is None:
            continue
        kept_mappings.append(indicators)
        kept_dates.append(day)
    return kept_dates, kept_mappings
def plot_cumulative_field(dic_data_grouped: Dict[dt, Any], field: Field) -> None:
    """
    Plot the cumulative value of one indicator, one curve per age group.

    The figure is written to ``cumulative_<field value>.pdf`` inside
    OUTPUT_REPOSITORY and then closed.
    """
    figure = get_plot_fig()
    for group in AgeGroup:
        group_label = group.value
        cumulated, plot_dates = get_cumulative_field_by_age(
            dic_data_grouped, group_label, field
        )
        plt.plot(plot_dates, cumulated, label=group_label)
    plt.title(
        f"nombre de {field.value} cumulé par age (status vaccinal non pris en compte)"
    )
    plt.xlabel("date")
    output_path = os.path.join(OUTPUT_REPOSITORY, f"cumulative_{field.value}.pdf")
    save_and_close_fig(figure, output_path)
def extract_field_values(fields: List[Dict[str, Any]], field: Field) -> np.ndarray:
    """
    Collect the values stored under ``field`` in a list of mappings.

    Mappings that lack the key, or hold None for it, are skipped, so the
    result may be shorter than ``fields``.
    """
    looked_up = (item.get(field) for item in fields)
    return np.asarray([value for value in looked_up if value is not None])
def plot_data_by_age_vac(
    dic_data_grouped: Dict[dt, Any], age: AgeGroup, vac_status: VacStatus
) -> None:
    """
    Plot every indicator for one age group and one vaccination status.

    One curve per ``Field`` member is drawn against the dates, and the figure
    is written to ``<age value>_<vac_status value>.pdf`` inside
    OUTPUT_REPOSITORY and then closed.

    Args:
        dic_data_grouped: data grouped by date, age and vaccination status.
        age: age group to plot.
        vac_status: vaccination status to plot.
    """
    fig = get_plot_fig()
    dates, fields = get_values_by_age_vac(dic_data_grouped, age, vac_status)
    for field in Field:
        # Keep every value paired with its own date: dropping a missing value
        # from y only would leave x and y with different lengths and make
        # plt.plot raise.
        points = [
            (date, value)
            for date, item in zip(dates, fields)
            if (value := item.get(field.value)) is not None
        ]
        field_dates = [date for date, _ in points]
        field_values = np.asarray([value for _, value in points])
        plt.plot(field_dates, field_values, label=f"{field.value}")
    plt.xlabel("date")
    plt.ylabel("nombre")
    # Use .value explicitly: formatting a (str, Enum) member directly gives
    # "AgeGroup.VERY_YOUNG" instead of "[0,19]" from Python 3.11 on.
    plt.title(f"{age.value}ans - {vac_status.value}")
    save_and_close_fig(
        fig, os.path.join(OUTPUT_REPOSITORY, f"{age.value}_{vac_status.value}.pdf")
    )
def group_by_date_age_vac(
    dic_data_grouped: Dict[dt, Any],
    field: Field,
    is_vac: Optional[bool] = True,
    limit_days: Optional[int] = 30,
) -> Dict[str, Any]:
    """
    Compute, per date and age group, the vaccinated or unvaccinated share of
    one indicator.

    For each age group the indicator is summed over the ``VacStatus.NC``
    status on one side and over every other status on the other side; the
    selected side is expressed as a percentage of the total.

    Args:
        dic_data_grouped: data grouped by date, age and vaccination status.
        field: indicator to aggregate.
        is_vac: when true return the share of the non-NC statuses, otherwise
            the share of the NC status.
        limit_days: keep only dates less than this many days away from now;
            None keeps every date.

    Returns:
        ``{date string (DATE_FORMAT): {age: percentage}}`` in iteration
        order. The percentage is 0 when the total for an age group is zero.
    """
    # Single reference instant so every date is compared to the same cutoff.
    now = dt.now()
    dic_data: Dict[str, Any] = OrderedDict()
    for date, dic_age in dic_data_grouped.items():
        if limit_days is not None and abs(date - now).days >= limit_days:
            continue
        shares = dic_data[date.strftime(DATE_FORMAT)] = OrderedDict()
        for age, dic_vac in dic_age.items():
            nb_vac, nb_unvac = 0, 0
            for vac_status, dic_field in dic_vac.items():
                # "or 0" also covers a key that is present but holds None,
                # which would otherwise break the addition.
                value = dic_field.get(field.value) or 0
                if vac_status == VacStatus.NC.value:
                    nb_unvac += value
                else:
                    nb_vac += value
            total = nb_vac + nb_unvac
            selected = nb_vac if is_vac else nb_unvac
            shares[age] = (selected / total) * 100 if total else 0
    return dic_data
def plot_bar_data_by_field(
    dic_data_grouped: Dict[dt, Any], field: Field, is_vac: Optional[bool] = True
) -> None:
    """
    Draw a bar chart of the vaccinated (or unvaccinated) share of one
    indicator, with one group of bars per date and one bar per age group.

    The chart is written to ``vac_age_grouped_<field value>.pdf`` or
    ``unvac_age_grouped_<field value>.pdf`` inside OUTPUT_REPOSITORY.

    Args:
        dic_data_grouped: data grouped by date, age and vaccination status.
        field: indicator to plot.
        is_vac: plot the vaccinated share when true, else the unvaccinated one.

    Note: the global pyplot font size is set to 24 as a side effect.
    """
    plt.rcParams["font.size"] = "24"
    dic_data = group_by_date_age_vac(dic_data_grouped, field, is_vac=is_vac)
    df = pd.DataFrame(dic_data).T
    if df.empty:
        # pandas refuses to plot an empty frame; this happens when no date
        # falls inside the grouping window.
        logging.warning("no data to plot for %s, bar plot skipped", field.value)
        return
    ax = df.plot.bar(figsize=(26, 15))
    ax.set_title(f"{field.value} vaccinate percent grouped by age")
    ax.set_xlabel("date")
    fig = ax.get_figure()
    plt.xticks(rotation=45)
    plt.legend(loc="upper right")
    plt.tight_layout()
    filename = "vac" if is_vac else "unvac"
    fig.savefig(
        os.path.join(OUTPUT_REPOSITORY, f"{filename}_age_grouped_{field.value}.pdf")
    )
    # Release the figure; otherwise every call leaves a large figure open.
    plt.close(fig)
def build_data_pool_args() -> List[Tuple[AgeGroup, VacStatus]]:
    """
    Build the (age group, vaccination status) argument tuples for the pool.

    Every combination is produced, age groups varying slowest, in the
    declaration order of both enums.
    """
    return [
        (age_group, vac_status)
        for age_group in AgeGroup
        for vac_status in VacStatus
    ]
# Script entry point: download (or reuse) the data, group it, then write all
# the plots as PDF files into OUTPUT_REPOSITORY.
if __name__ == "__main__":
    """
This script aims to plot DRESS data
Plots availables :
- cumulative deaths by age
- indicators by vaccine status and age
- indicators vaccine/unvaccine percent grouped by age
Main indicators are :
- hospitalisations
- criticals
- deaths
"""
    # Command line: a single optional flag forcing a new download.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--refresh",
        action="store_true",
        default=False,
        help="redownload data for updates",
    )
    args = parser.parse_args()
    # The plotting functions do not create the output directory themselves.
    os.makedirs(OUTPUT_REPOSITORY, exist_ok=True)
    # The payload is cached in a fixed file inside DATA_REPOSITORY.
    dic_data: Dict[str, Any] = get_data(
        file_path=os.path.join(DATA_REPOSITORY, "dress.json"), refresh=args.refresh
    )
    # Keep only the indicator fields declared in the Field enum.
    dic_data_grouped: Dict[dt, Any] = group_by_age_date(
        dic_data, [x.value for x in Field]
    )
    # One plot per (age group, vaccination status) pair, spread over worker
    # processes; the grouped data is bound as the first argument.
    plot_data_pool_args = build_data_pool_args()
    f = partial(plot_data_by_age_vac, dic_data_grouped)
    with Pool() as pool:
        pool.starmap(f, plot_data_pool_args)
    # Cumulative plots run before the bar plots: plot_bar_data_by_field
    # changes the global font size, which would otherwise affect them.
    for field in Field:
        plot_cumulative_field(dic_data_grouped, field)
    # Two bar charts per indicator: vaccinated share, then unvaccinated share.
    for field in Field:
        plot_bar_data_by_field(dic_data_grouped, field)
        plot_bar_data_by_field(dic_data_grouped, field, is_vac=False)