plot cumulative indicators + plot data by age and vaccine status (multiprocessed)

This commit is contained in:
rmanach 2021-12-21 23:35:49 +01:00
parent a59bf5c619
commit 08cfc98dec

170
drees.py
View File

@ -2,16 +2,24 @@ import argparse
import json
import logging
import os
from datetime import datetime as dt
from enum import Enum
from typing import Any, Dict, List, Optional, OrderedDict
from functools import partial
from multiprocessing import Pool
from typing import Any, Dict, List, Optional, OrderedDict, Tuple
import matplotlib
import numpy as np
import requests
from matplotlib import dates as md
from matplotlib import pyplot as plt
from numba import njit
FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=FORMAT, level=logging.INFO)
DATE_FORMAT = "%Y-%m-%d"
DATA_URL = "https://data.drees.solidarites-sante.gouv.fr/api/records/1.0/search/?dataset=covid-19-resultats-par-age-issus-des-appariements-entre-si-vic-si-dep-et-vac-si&q=&rows=-1&facet=date&facet=vac_statut&facet=age"
DATA_REPOSITORY = "data"
OUTPUT_REPOSITORY = "output"
@ -21,7 +29,6 @@ class Field(str, Enum):
HC = "hc"
SC = "sc"
DC = "dc"
EFF = "effectif"
class VacStatus(str, Enum):
@ -36,7 +43,7 @@ class VacStatus(str, Enum):
class AgeGroup(str, Enum):
VERY_YOUNG = "[0,19]"
YONG = "[20,39]"
YOUNG = "[20,39]"
MID_OLD = "[40,59]"
OLD = "[60,79]"
VERY_OLD = "[80;+]"
@ -69,7 +76,7 @@ def get_data(
return json.load(open(file_path, "rb"))
def group_by_age_date(data: Dict[str, Any], fields: List[str]) -> Dict[str, Any]:
def group_by_age_date(data: Dict[str, Any], fields: List[str]) -> Dict[dt, Any]:
"""
Group the original dictionary into a more readable one
'date': {
@ -83,10 +90,10 @@ def group_by_age_date(data: Dict[str, Any], fields: List[str]) -> Dict[str, Any]
}
}
"""
dic_data_grouped: Dict[str, Any] = OrderedDict()
dic_data_grouped: Dict[dt, Any] = OrderedDict()
for row in data["records"]:
row_fields = row["fields"]
date = row_fields["date"]
date = dt.strptime(row_fields["date"], DATE_FORMAT)
age = row_fields["age"]
vac_status = row_fields["vac_statut"]
if date not in dic_data_grouped:
@ -100,37 +107,136 @@ def group_by_age_date(data: Dict[str, Any], fields: List[str]) -> Dict[str, Any]
return dic_data_grouped
def plot(dic_data_grouped: Dict[str, Any], age: str, vac_status: str) -> None:
def cumulate_array(array: np.ndarray) -> np.ndarray:
    """
    Return the running total of a one-dimensional array.

    Element ``i`` of the result is the sum of ``array[0]`` .. ``array[i]``.
    An empty input gives an empty output.

    The accumulation is delegated to ``np.cumsum``, which already runs in
    NumPy's native code, so no JIT decorator is applied to this wrapper.

    :param array: one-dimensional numeric array
    :return: array of the same length holding the cumulative sums
    """
    return np.cumsum(array)
def get_plot_fig(
    grid: Optional[bool] = True, date_format: Optional[str] = DATE_FORMAT
) -> plt.Figure:
    """
    Create a pyplot figure whose x axis is configured to display dates.

    The x axis gets an automatic date locator and a formatter built from
    ``date_format``; tick labels are laid out with ``fig.autofmt_xdate()``.

    :param grid: value forwarded to ``ax.grid``
    :param date_format: strftime pattern for the x tick labels; ``None``
        falls back to ``DATE_FORMAT``
    :return: the newly created figure
    """
    fig, ax = plt.subplots()
    ax.grid(grid)
    # A formatter built from None would fail when the ticks are rendered.
    tick_format = DATE_FORMAT if date_format is None else date_format
    ax.xaxis.set_major_locator(md.AutoDateLocator())
    ax.xaxis.set_major_formatter(md.DateFormatter(tick_format))
    fig.autofmt_xdate()
    return fig
def save_and_close_fig(
    fig: plt.Figure, output_path: str, has_legend: Optional[bool] = True
) -> None:
    """
    Write a figure to disk, then close it.

    The legend and the save are applied to ``fig`` itself rather than to
    whatever figure pyplot currently considers active, so the function does
    what its signature says even if another figure was created in between.

    :param fig: figure to save and close
    :param output_path: destination file path handed to ``savefig``
    :param has_legend: when truthy, draw a legend on the figure's current axes
    """
    if has_legend:
        fig.gca().legend()
    fig.savefig(output_path)
    plt.close(fig)
def get_cumulative_field_by_age(
    dic_data_grouped: Dict[dt, Any], age: str, field: Field
) -> Tuple[np.ndarray, List[dt]]:
    """
    Cumulate the values of one field for one age group over the data period.

    For every date, the value of ``field`` is collected from each vaccine
    status entry of the age group; ``None`` or absent values are skipped.
    The collected values are then turned into a running total.

    NOTE(review): a date is appended once per vaccine status entry that
    carries a value, so ``dates`` contains repeated dates and has the same
    length as the returned array.

    :param dic_data_grouped: data grouped as date -> age -> vaccine status
        -> field value
    :param age: age group key to look up under each date
    :param field: field whose values are cumulated
    :return: ``(cumulated values, dates)``, both of the same length
    """
    values: List[float] = []
    dates: List[dt] = []
    for date, dic_age_grouped in dic_data_grouped.items():
        if (dic_age := dic_age_grouped.get(age)) is None:
            logging.error("%s not found in grouped ages", age)
            continue
        for dic_vac_status in dic_age.values():
            # .get keeps a missing field from aborting the whole series,
            # matching how get_values_by_age_vac_field reads the same data.
            if (field_value := dic_vac_status.get(field.value)) is not None:
                values.append(field_value)
                dates.append(date)
    return cumulate_array(np.array(values)), dates
def get_values_by_age_vac_field(
    dic_data_grouped: Dict[dt, Any], age: AgeGroup, vac_status: VacStatus, field: Field
) -> Tuple[List[dt], List[float]]:
    """
    Extract the time series of one field for an age group and vaccine status.

    A date is kept only when the age group, the vaccine status and a
    non-``None`` field value are all present for it.

    :return: ``(dates, values)`` as two lists of equal length
    """
    kept_dates: List[dt] = []
    kept_values: List[float] = []
    for day, by_age in dic_data_grouped.items():
        by_status = by_age.get(age.value)
        if by_status is None:
            continue
        by_field = by_status.get(vac_status.value)
        if by_field is None:
            continue
        value = by_field.get(field.value)
        if value is None:
            continue
        kept_values.append(value)
        kept_dates.append(day)
    return kept_dates, kept_values
def plot_cumulative_field(dic_data_grouped: Dict[dt, Any], field: Field) -> None:
    """
    Plot the cumulated values of ``field``, one curve per age group.

    The figure is saved as ``cumulative_<field>.pdf`` under
    ``OUTPUT_REPOSITORY`` and then closed.
    """
    figure = get_plot_fig()
    for group in AgeGroup:
        cumulated, days = get_cumulative_field_by_age(
            dic_data_grouped, group.value, field
        )
        plt.plot(days, cumulated, label=group.value)

    plt.title(
        f"nombre de {field.value} cumulé par age (status vaccinal non pris en compte)"
    )
    plt.xlabel("date")

    output_path = os.path.join(OUTPUT_REPOSITORY, f"cumulative_{field.value}.pdf")
    save_and_close_fig(figure, output_path)
def plot_data(
    dic_data_grouped: Dict[dt, Any], age: AgeGroup, vac_status: VacStatus, field: Field
) -> None:
    """
    Plot one field over time for a given age group and vaccine status.

    The figure is saved as ``<age>_<vac_status>_<field>.pdf`` under
    ``OUTPUT_REPOSITORY`` and then closed.

    Enum members are written out through ``.value`` so that the title and
    the file name do not depend on how the running Python version formats
    enums that mix in ``str``.

    :param dic_data_grouped: data grouped as date -> age -> vaccine status
        -> field value
    :param age: age group to plot
    :param vac_status: vaccine status to plot
    :param field: indicator to plot
    """
    fig = get_plot_fig()
    dates, fields = get_values_by_age_vac_field(
        dic_data_grouped, age, vac_status, field
    )
    plt.plot(dates, fields, label=field.value)
    plt.xlabel("date")
    plt.ylabel("nombre")
    plt.title(f"{age.value}ans - {vac_status.value}")
    save_and_close_fig(
        fig,
        os.path.join(
            OUTPUT_REPOSITORY,
            f"{age.value}_{vac_status.value}_{field.value}.pdf",
        ),
    )
def build_data_pool_args() -> List[Tuple[AgeGroup, VacStatus, Field]]:
    """
    Build every (age group, vaccine status, field) combination.

    Each tuple is one argument set for a plotting task dispatched to the
    process pool. Age group varies slowest and field fastest.
    """
    return [
        (age_group, vac_status, field)
        for age_group in AgeGroup
        for vac_status in VacStatus
        for field in Field
    ]
if __name__ == "__main__":
"""
This script aims to plot DRESS data with vaccine status and ages grouped
This script aims to plot DREES data
Available plots:
- cumulative deaths by age
- indicators by vaccine status and age
Main indicators are :
- hospitalisations
- criticals
@ -153,8 +259,14 @@ if __name__ == "__main__":
dic_data: Dict[str, Any] = get_data(
file_path=os.path.join(DATA_REPOSITORY, "dress.json"), refresh=args.refresh
)
dic_data_grouped: Dict[str, Any] = group_by_age_date(
dic_data_grouped: Dict[dt, Any] = group_by_age_date(
dic_data, [x.value for x in Field]
)
plot(dic_data_grouped, AgeGroup.YONG, VacStatus.NC.value)
plot(dic_data_grouped, AgeGroup.YONG, VacStatus.CM3MSR.value)
plot_data_pool_args = build_data_pool_args()
f = partial(plot_data, dic_data_grouped)
with Pool() as pool:
pool.starmap(f, plot_data_pool_args)
for field in Field:
plot_cumulative_field(dic_data_grouped, field)