import hashlib
import os
import pathlib
import sys
import urllib
import urllib.request
from typing import Any, Optional, Union
from rich.console import Console
from rich.progress import (
BarColumn,
DownloadColumn,
Progress,
TextColumn,
TimeElapsedColumn,
TransferSpeedColumn,
)
USER_AGENT = "afnio"
def _urlretrieve(
url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32
) -> None:
console = Console(force_jupyter=False)
with urllib.request.urlopen(
urllib.request.Request(url, headers={"User-Agent": USER_AGENT})
) as response:
total = response.length
with (
open(filename, "wb") as fh,
Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
DownloadColumn(),
TransferSpeedColumn(),
TimeElapsedColumn(),
TextColumn(
"\033[K\033[K\033[K"
), # This workarounds a Rich bug that leaves ghost text
auto_refresh=False, # We refresh manually to avoid the Rich bug
transient=False,
console=console,
) as progress,
):
task = progress.add_task("Downloading", total=total, visible=True)
while chunk := response.read(chunk_size):
fh.write(chunk)
progress.update(task, advance=len(chunk))
progress.refresh()
[docs]
def calculate_md5(
fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024
) -> str:
# Setting the `usedforsecurity` flag does not change anything about the
# functionality, but indicates that we are not using the MD5 checksum for
# cryptography. This enables its usage in restricted environments like FIPS.
if sys.version_info >= (3, 9):
md5 = hashlib.md5(usedforsecurity=False)
else:
md5 = hashlib.md5()
with open(fpath, "rb") as f:
while chunk := f.read(chunk_size):
md5.update(chunk)
return md5.hexdigest()
[docs]
def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:
return md5 == calculate_md5(fpath, **kwargs)
[docs]
def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool:
if not os.path.isfile(fpath):
return False
if md5 is None:
return True
return check_md5(fpath, md5)
[docs]
def download_url(
url: str,
root: Union[str, pathlib.Path],
filename: Optional[Union[str, pathlib.Path]] = None,
md5: Optional[str] = None,
max_redirect_hops: int = 3,
) -> None:
"""Download a file from a url and place it in root.
Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the
basename of the URL
md5 (str, optional): MD5 checksum of the download. If None, do not check
max_redirect_hops (int, optional): Maximum number of redirect hops allowed
"""
root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.fspath(os.path.join(root, filename))
os.makedirs(root, exist_ok=True)
# check if file is already present locally
if check_integrity(fpath, md5):
print(f"Using downloaded and verified file: {fpath}")
return
# download the file
try:
print(f"Downloading {url} to {fpath}")
_urlretrieve(url, fpath)
except (urllib.error.URLError, OSError) as e:
if url[:5] == "https":
url = url.replace("https:", "http:")
print(
f"Failed download. Trying https -> http instead. "
f"Downloading {url} to {fpath}"
)
_urlretrieve(url, fpath)
else:
raise e
# check integrity of downloaded file
if not check_integrity(fpath, md5):
raise RuntimeError("File not found or corrupted.")
[docs]
def download(
url: str,
download_root: Union[str, pathlib.Path],
extract_root: Optional[Union[str, pathlib.Path]] = None,
filename: Optional[Union[str, pathlib.Path]] = None,
md5: Optional[str] = None,
remove_finished: bool = False,
) -> None:
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if not filename:
filename = os.path.basename(url)
download_url(url, download_root, filename, md5)
# TODO: Handle unpacking of the archive if needed
# archive = os.path.join(download_root, filename)
# print(f"Extracting {archive} to {extract_root}")