~science-computing/butido

e833100d29b774ab30daf18dfb791f6aabf3085c — Matthias Beyer 5 months ago ab25227 + 262793d
Merge branch 'optimize-mass-download'
5 files changed, 269 insertions(+), 170 deletions(-)

A src/commands/source/download.rs
R src/commands/{source.rs => source/mod.rs}
M src/config/not_validated.rs
M src/main.rs
M src/util/progress.rs
A src/commands/source/download.rs => src/commands/source/download.rs +262 -0
@@ 0,0 1,262 @@
//
// Copyright (c) 2020-2021 science+computing ag and other contributors
//
// This program and the accompanying materials are made
// available under the terms of the Eclipse Public License 2.0
// which is available at https://www.eclipse.org/legal/epl-2.0/
//
// SPDX-License-Identifier: EPL-2.0
//

use std::convert::TryFrom;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;

use anyhow::Context;
use anyhow::Error;
use anyhow::Result;
use anyhow::anyhow;
use clap::ArgMatches;
use log::{debug, trace};
use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;

use crate::config::*;
use crate::package::PackageName;
use crate::package::PackageVersionConstraint;
use crate::repository::Repository;
use crate::source::*;
use crate::util::progress::ProgressBars;

const NUMBER_OF_MAX_CONCURRENT_DOWNLOADS: usize = 100;

/// A wrapper around the indicatif::ProgressBar
///
/// A wrapper around the indicatif::ProgressBar that is used to synchronize status information from
/// the individual download jobs to the progress bar that is used to display download progress to
/// the user.
///
/// The problem this helper solves is that we only have one status bar for all downloads, and all
/// download tasks must be able to increase the number of bytes received, for example, (that is
/// displayed in the status message) but in a sync way.
#[derive(Clone)]
struct ProgressWrapper {
    download_count: u64,
    finished_downloads: u64,
    current_bytes: usize,
    sum_bytes: u64,
    bar: Arc<Mutex<indicatif::ProgressBar>>,
}

impl ProgressWrapper {
    fn new(bar: indicatif::ProgressBar) -> Self {
        Self {
            download_count: 0,
            finished_downloads: 0,
            current_bytes: 0,
            sum_bytes: 0,
            bar: Arc::new(Mutex::new(bar))
        }
    }

    async fn inc_download_count(&mut self) {
        self.download_count += 1;
        self.set_message().await;
        let bar = self.bar.lock().await;
        bar.set_length(bar.length() + 1);
    }

    async fn inc_download_bytes(&mut self, bytes: u64) {
        self.sum_bytes += bytes;
        self.set_message().await;
    }

    async fn finish_one_download(&mut self) {
        self.finished_downloads += 1;
        self.bar.lock().await.inc(1);
        self.set_message().await;
    }

    async fn add_bytes(&mut self, len: usize) {
        self.current_bytes += len;
        self.set_message().await;
    }

    async fn set_message(&self) {
        let bar = self.bar.lock().await;
        bar.set_message(format!("Downloading ({current_bytes}/{sum_bytes} bytes, {dlfinished}/{dlsum} downloads finished)",
                current_bytes = self.current_bytes,
                sum_bytes = self.sum_bytes,
                dlfinished = self.finished_downloads,
                dlsum = self.download_count));
    }

    async fn success(&self) {
        let bar = self.bar.lock().await;
        bar.finish_with_message(format!("Succeeded {}/{} downloads", self.finished_downloads, self.download_count));
    }

    async fn error(&self) {
        let bar = self.bar.lock().await;
        bar.finish_with_message(format!("At least one download of {} failed", self.download_count));
    }
}

async fn perform_download(source: &SourceEntry, progress: Arc<Mutex<ProgressWrapper>>, timeout: Option<u64>) -> Result<()> {
    trace!("Creating: {:?}", source);
    let file = source.create().await.with_context(|| {
        anyhow!(
            "Creating source file destination: {}",
            source.path().display()
        )
    })?;

    let mut file = tokio::io::BufWriter::new(file);
    let client_builder = reqwest::Client::builder()
        .redirect(reqwest::redirect::Policy::limited(10));

    let client_builder = if let Some(to) = timeout {
        client_builder.timeout(std::time::Duration::from_secs(to))
    } else {
        client_builder
    };

    let client = client_builder.build().context("Building HTTP client failed")?;

    let request = client.get(source.url().as_ref())
        .build()
        .with_context(|| anyhow!("Building request for {} failed", source.url().as_ref()))?;

    let response = match client.execute(request).await {
        Ok(resp) => resp,
        Err(e) => {
            return Err(e).with_context(|| anyhow!("Downloading '{}'", source.url()))
        }
    };

    progress.lock()
        .await
        .inc_download_bytes(response.content_length().unwrap_or(0))
        .await;

    let mut stream = response.bytes_stream();
    while let Some(bytes) = stream.next().await {
        let bytes = bytes?;
        tokio::try_join!(
            file.write_all(bytes.as_ref()),
            async {
                progress.lock()
                    .await
                    .add_bytes(bytes.len())
                    .await;
                Ok(())
            }
        )?;
    }

    file.flush()
        .await
        .map_err(Error::from)
        .map(|_| ())
}


// Implementation of the 'source download' subcommand
pub async fn download(
    matches: &ArgMatches,
    config: &Configuration,
    repo: Repository,
    progressbars: ProgressBars,
) -> Result<()> {
    let force = matches.is_present("force");
    let timeout = matches.value_of("timeout")
        .map(u64::from_str)
        .transpose()
        .context("Parsing timeout argument to integer")?;
    let cache = PathBuf::from(config.source_cache_root());
    let sc = SourceCache::new(cache);
    let pname = matches
        .value_of("package_name")
        .map(String::from)
        .map(PackageName::from);
    let pvers = matches
        .value_of("package_version")
        .map(PackageVersionConstraint::try_from)
        .transpose()?;

    let matching_regexp = matches.value_of("matching")
        .map(crate::commands::util::mk_package_name_regex)
        .transpose()?;

    let progressbar = Arc::new(Mutex::new(ProgressWrapper::new(progressbars.bar())));

    let download_sema = Arc::new(tokio::sync::Semaphore::new(NUMBER_OF_MAX_CONCURRENT_DOWNLOADS));

    let r = repo.packages()
        .filter(|p| {
            match (pname.as_ref(), pvers.as_ref(), matching_regexp.as_ref()) {
                (None, None, None)              => true,
                (Some(pname), None, None)       => p.name() == pname,
                (Some(pname), Some(vers), None) => p.name() == pname && vers.matches(p.version()),
                (None, None, Some(regex))       => regex.is_match(p.name()),

                (_, _, _) => {
                    panic!("This should not be possible, either we select packages by name and (optionally) version, or by regex.")
                },
            }
        })
        .map(|p| {
            sc.sources_for(p).into_iter().map(|source| {
                let download_sema = download_sema.clone();
                let progressbar = progressbar.clone();
                async move {
                    let source_path_exists = source.path().exists();
                    if !source_path_exists && source.download_manually() {
                        return Err(anyhow!(
                            "Cannot download source that is marked for manual download"
                        ))
                        .context(anyhow!("Creating source: {}", source.path().display()))
                        .context(anyhow!("Downloading source: {}", source.url()))
                        .map_err(Error::from);
                    }

                    if source_path_exists && !force {
                        Err(anyhow!("Source exists: {}", source.path().display()))
                    } else {
                        if source_path_exists /* && force is implied by 'if' above*/ {
                            if let Err(e) = source.remove_file().await {
                                return Err(e)
                            }
                        }

                        progressbar.lock().await.inc_download_count().await;
                        {
                            let permit = download_sema.acquire_owned().await?;
                            perform_download(&source, progressbar.clone(), timeout).await?;
                            drop(permit);
                        }
                        progressbar.lock().await.finish_one_download().await;
                        Ok(())
                    }
                }
            })
        })
        .flatten()
        .collect::<futures::stream::FuturesUnordered<_>>()
        .collect::<Vec<Result<()>>>()
        .await
        .into_iter()
        .collect::<Result<()>>();

    if r.is_err() {
        progressbar.lock().await.error().await;
    } else {
        progressbar.lock().await.success().await;
    }

    debug!("r = {:?}", r);
    r
}


R src/commands/source.rs => src/commands/source/mod.rs +5 -156
@@ 10,19 10,17 @@

//! Implementation of the 'source' subcommand

use std::convert::TryFrom;
use std::io::Write;
use std::path::PathBuf;
use std::convert::TryFrom;
use std::str::FromStr;

use anyhow::anyhow;
use anyhow::Context;
use anyhow::Error;
use anyhow::Result;
use anyhow::anyhow;
use clap::ArgMatches;
use colored::Colorize;
use log::{info, trace};
use tokio::io::AsyncWriteExt;
use tokio_stream::StreamExt;

use crate::config::*;


@@ 33,6 31,8 @@ use crate::repository::Repository;
use crate::source::*;
use crate::util::progress::ProgressBars;

mod download;

/// Implementation of the "source" subcommand
pub async fn source(
    matches: &ArgMatches,


@@ 44,7 44,7 @@ pub async fn source(
        Some(("verify", matches)) => verify(matches, config, repo, progressbars).await,
        Some(("list-missing", matches)) => list_missing(matches, config, repo).await,
        Some(("url", matches)) => url(matches, repo).await,
        Some(("download", matches)) => download(matches, config, repo, progressbars).await,
        Some(("download", matches)) => crate::commands::source::download::download(matches, config, repo, progressbars).await,
        Some(("of", matches)) => of(matches, config, repo).await,
        Some((other, _)) => return Err(anyhow!("Unknown subcommand: {}", other)),
        None => Err(anyhow!("No subcommand")),


@@ 218,157 218,6 @@ pub async fn url(matches: &ArgMatches, repo: Repository) -> Result<()> {
        })
}

pub async fn download(
    matches: &ArgMatches,
    config: &Configuration,
    repo: Repository,
    progressbars: ProgressBars,
) -> Result<()> {
    async fn perform_download(source: &SourceEntry, bar: &indicatif::ProgressBar, timeout: Option<u64>) -> Result<()> {
        trace!("Creating: {:?}", source);
        let file = source.create().await.with_context(|| {
            anyhow!(
                "Creating source file destination: {}",
                source.path().display()
            )
        })?;

        let mut file = tokio::io::BufWriter::new(file);
        let client_builder = reqwest::Client::builder()
            .redirect(reqwest::redirect::Policy::limited(10));

        let client_builder = if let Some(to) = timeout {
            client_builder.timeout(std::time::Duration::from_secs(to))
        } else {
            client_builder
        };

        let client = client_builder.build().context("Building HTTP client failed")?;

        let request = client.get(source.url().as_ref())
            .build()
            .with_context(|| anyhow!("Building request for {} failed", source.url().as_ref()))?;

        let response = match client.execute(request).await {
            Ok(resp) => resp,
            Err(e) => {
                bar.finish_with_message(format!("Failed: {}", source.url()));
                return Err(e).with_context(|| anyhow!("Downloading '{}'", source.url()))
            }
        };

        if let Some(len) = response.content_length() {
            bar.set_length(len);
        }

        let mut stream = reqwest::get(source.url().as_ref()).await?.bytes_stream();
        let mut bytes_written = 0;
        while let Some(bytes) = stream.next().await {
            let bytes = bytes?;
            file.write_all(bytes.as_ref()).await?;
            bytes_written += bytes.len();

            bar.inc(bytes.len() as u64);
            if let Some(len) = response.content_length() {
                bar.set_message(format!("Downloading {} ({}/{} bytes)", source.url(), bytes_written, len));
            } else {
                bar.set_message(format!("Downloading {} ({} bytes)", source.url(), bytes_written));
            }
        }

        file.flush()
            .await
            .map_err(Error::from)
            .map(|_| ())
    }

    let force = matches.is_present("force");
    let timeout = matches.value_of("timeout")
        .map(u64::from_str)
        .transpose()
        .context("Parsing timeout argument to integer")?;
    let cache = PathBuf::from(config.source_cache_root());
    let sc = SourceCache::new(cache);
    let pname = matches
        .value_of("package_name")
        .map(String::from)
        .map(PackageName::from);
    let pvers = matches
        .value_of("package_version")
        .map(PackageVersionConstraint::try_from)
        .transpose()?;
    let multi = {
        let mp = indicatif::MultiProgress::new();
        if progressbars.hide() {
            mp.set_draw_target(indicatif::ProgressDrawTarget::hidden());
        }
        mp
    };

    let matching_regexp = matches.value_of("matching")
        .map(crate::commands::util::mk_package_name_regex)
        .transpose()?;

    let r = repo
        .packages()
        .filter(|p| {
            match (pname.as_ref(), pvers.as_ref(), matching_regexp.as_ref()) {
                (None, None, None)              => true,
                (Some(pname), None, None)       => p.name() == pname,
                (Some(pname), Some(vers), None) => p.name() == pname && vers.matches(p.version()),
                (None, None, Some(regex))       => regex.is_match(p.name()),

                (_, _, _) => {
                    panic!("This should not be possible, either we select packages by name and (optionally) version, or by regex.")
                },
            }
        })
        .map(|p| {
            sc.sources_for(p).into_iter().map(|source| {
                let bar = multi.add(progressbars.spinner());
                bar.set_message(format!("Downloading {}", source.url()));
                async move {
                    let source_path_exists = source.path().exists();
                    if !source_path_exists && source.download_manually() {
                        return Err(anyhow!(
                            "Cannot download source that is marked for manual download"
                        ))
                        .context(anyhow!("Creating source: {}", source.path().display()))
                        .context(anyhow!("Downloading source: {}", source.url()))
                        .map_err(Error::from);
                    }

                    if source_path_exists && !force {
                        Err(anyhow!("Source exists: {}", source.path().display()))
                    } else {
                        if source_path_exists /* && force is implied by 'if' above*/ {
                            if let Err(e) = source.remove_file().await {
                                bar.finish_with_message(format!("Failed to remove existing file: {}", source.path().display()));
                                return Err(e)
                            }
                        }


                        if let Err(e) = perform_download(&source, &bar, timeout).await {
                            bar.finish_with_message(format!("Failed: {}", source.url()));
                            Err(e)
                        } else {
                            bar.finish_with_message(format!("Finished: {}", source.url()));
                            Ok(())
                        }
                    }
                }
            })
        })
        .flatten()
        .collect::<futures::stream::FuturesUnordered<_>>()
        .collect::<Vec<Result<()>>>();

    let multibar_block = tokio::task::spawn_blocking(move || multi.join());
    let (r, _) = tokio::join!(r, multibar_block);
    r.into_iter().collect()
}

async fn of(
    matches: &ArgMatches,
    config: &Configuration,

M src/config/not_validated.rs => src/config/not_validated.rs +1 -0
@@ 51,6 51,7 @@ pub struct NotValidatedConfiguration {
    /// The format of the spinners in the CLI
    #[serde(default = "default_spinner_format")]
    #[getset(get = "pub")]
    #[allow(unused)]
    spinner_format: String,

    /// The format used to print a package

M src/main.rs => src/main.rs +0 -1
@@ 139,7 139,6 @@ async fn main() -> Result<()> {
    let hide_bars = cli.is_present("hide_bars") || crate::util::stdout_is_pipe();
    let progressbars = ProgressBars::setup(
        config.progress_format().clone(),
        config.spinner_format().clone(),
        hide_bars,
    );


M src/util/progress.rs => src/util/progress.rs +1 -13
@@ 14,17 14,15 @@ use getset::CopyGetters;
#[derive(Clone, Debug, CopyGetters)]
pub struct ProgressBars {
    bar_template: String,
    spinner_template: String,

    #[getset(get_copy = "pub")]
    hide: bool,
}

impl ProgressBars {
    pub fn setup(bar_template: String, spinner_template: String, hide: bool) -> Self {
    pub fn setup(bar_template: String, hide: bool) -> Self {
        ProgressBars {
            bar_template,
            spinner_template,
            hide,
        }
    }


@@ 38,14 36,4 @@ impl ProgressBars {
            b
        }
    }

    pub fn spinner(&self) -> ProgressBar {
        if self.hide {
            ProgressBar::hidden()
        } else {
            let bar = ProgressBar::new_spinner();
            bar.set_style(ProgressStyle::default_spinner().template(&self.spinner_template));
            bar
        }
    }
}