如何在 Python 使用 multithread 下載檔案

最近在工作上遇到需要從 JFrog 搬移多個檔案的狀況,原本 single thread 且 consequence 依序下載的架構,隨著檔案數量的增多,耗費的時間太多,已經不敷使用,因此來修改一下吧

原本的寫法:

from artifactory import ArtifactorySaaSPath

for obj in download_obj_list:
    obj_target = os.path.join(output_build_dir, str(obj).split('/')[-1])
    logger.info('Downloading ' + obj + ' to ' + obj_target + ' ...')
    
    path = ArtifactorySaaSPath(
        obj, auth=(self.JFROG_USER, self.JFROG_API_KEY))
    with open(obj_target, "wb") as out:
        path.writeto(out, chunk_size=256, progress_func=None)

主要的功能有:

  1. 如果下載失敗可以有重試的機制
  2. 可以調整開的 thread 數量
  3. 針對單一下載事件,如果(重試之後)有錯誤,可以查看 error message

主要思路是:

引入 ArtifactorySaaSPath library 下載 JFrog artifactory 中的檔案

接著我們從 thread.Thread 建立 thread pool ,在 thread 中只有兩種使用的方式,一是 pass a callable object to the constructor,也就是 override __init__() ,二則是 override run() method。

我們此處採用後者,在 run() method 中使用從 thread.Semaphore 建立的 semaphore 鎖,並給予 MAX_THREADS 的 thread 數量,使其利用 acquire() 以及 release(),來管理 atomic counter,原文詳述如下

class threading.Semaphore(value=1)

This class implements semaphore objects. A semaphore manages an atomic counter representing the number of release() calls minus the number of acquire() calls, plus an initial value. The acquire() method blocks if necessary until it can return without making the counter negative. If not given, value defaults to 1.

我們使用 with self.semaphore 的做法來對應 acquire() & release() ,並為每一個取得 semaphore 的 thread 設計 retry 的機制,在 retry 達 MAX_RETRIES 且下載失敗時,將 error message 放進配置好的 error queue 中。

後續可以透過檢查 error queue 的長度即可知道,這次下載有無任意檔案失敗

改用 multithread 的寫法:

class Worker(threading.Thread):
    MAX_THREADS = 3  # Maximum number of threads in the downloading pool
    MAX_RETRIES = 3  # Max number of retries for a single failed download

    # Semaphore to control access to the downloading pool
    semaphore = threading.Semaphore(MAX_THREADS)

    def __init__(self, pkg_name, pkg_path, error_queue, downloader):
        self.pkg_name = pkg_name
        self.pkg_path = pkg_path
        self.error_queue = error_queue
        self.JFROG_USER = downloader.JFROG_USER
        self.JFROG_API_KEY = downloader.JFROG_API_KEY

    def run(self):
        with self.semaphore:      
            try:
                if not self.retry_runner():
                    raise Exception

            except Exception as e:
                err_msg = f"Error downloading {self.pkg_name}: {str(e)}"
                logger.error(err_msg)
                self.error_queue.put(err_msg)
            
    def retry_runner(self):
        import shutil
        _chunk_size = 256
        thread_name = threading.current_thread().name
        logger.info(f"Downloading {self.pkg_name} to {self.pkg_path} ... in thread {thread_name}")
        for attempt in range(self.MAX_RETRIES):
            try:
                path = ArtifactorySaaSPath(
                    self.pkg_name, auth=(self.JFROG_USER, self.JFROG_API_KEY))
                with path.open() as fd, open(self.pkg_path, "wb") as out:
                    shutil.copyfileobj(fd, out, _chunk_size)
                return True
            except Exception as e:
                error_message = f"Error downloading {self.pkg_name} in thread {thread_name}"
                logger.warning(error_message)

                if attempt < self.MAX_RETRIES:
                    logger.info(f"Retrying download for downloading {self.pkg_name} in thread {thread_name}, attempt {attempt + 1}")
                    continue

                logger.error(f"All attempts failed for downloading {self.pkg_name} in thread {thread_name}")
                return False 
        return False

Reference:

https://docs.python.org/3/library/threading.html


留言

發表迴響