代码片段-CenterNet的训练部分多线程代码分析

发布 : 2019-05-26 浏览 :

CenterNet的代码在train.py文件中自己定义了多线程的数据处理方式,而不是用的pytorch自带的,所以这里来分析一下这部分的实现。

头文件导入多线程处理包

1
2
import queue
from torch.multiprocessing import Process, Queue

queue主要用来创建队列,阻塞型队列,当queue满的时候再执行put操作,则暂时阻塞直到有空位置腾出来再继续进行,同样当queue为空的时候执行get操作也会暂时阻塞。

multiprocessing中的Queue和queue.Queue功能差不多,但是他能实现多进程之间数据的共享。

Process多进程包。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# train
def train(training_dbs, validation_db, start_iter=0):
...
training_queue = Queue(system_configs.prefetch_size)
validation_queue = Queue(5)
# 创建用于多个进程共享的队列
pinned_training_queue = queue.Queue(system_configs.prefetch_size)
pinned_validation_queue = queue.Queue(5)
# 创建用于当前主线程的数据存储队列
data_file = "sample.{}".format(training_dbs[0],data) # 读取数据的函数
sample_data = importlib.import_module(data_file).sample_data # 动态导入模块

training_tasks = init_parallel_jobs(training_dbs, training_queue, sample_data, True)
# 在init_parallel_jobs函数中是创建了多个进程,并行的读取数据
# def init_parallel_jobs(dbs, queue, fn, data_aug):
# tasks = [Process(target=prefetch_data, args=(db, queue, fn, data_aug)) for db in dbs] # 每个进程的处理函数是prefetch_data, 参数中传递了数据集, 保存结果的队列, 还有读取数据的函数和是否增广的标志。
# for task in tasks:
# task.daemon = True # 将进程设置为保护进程,当主线程结束时,该进程自动结束。
# task.start() # 启动不同的进程
# return tasks
if val_iter:
validation_tasks = init_parallel_jobs([validation_db], validation_queue, dample_data, False)

training_pin_semaphore = threading.Semaphore()
validation_pin_semaphore = threading.Semaphore()
training_pin_semaphore.acquire()
validation_pin_semaphore.acquire()
# 这两句给规整数据的线程申请了信号量

training_pin_args = (training_queue, pinned_training_queue, training_pin_semaphore)
training_pin_thread = threading.Thread(target=pin_memory, args=training_pin_args)
# 创建的线程执行函数是pin_memory
# def pin_memory(data_queue, pinned_data_queue, sema):
# while True: # 注意这部分是死循环的,知道sema申请到信号量,也就是线程释放了结束
# data = data_queue.get() # 多进程去填满data_queue队列,然后再该线程中从data_queue中取数据
# data["xs"] = [x.pin_memory() for x in data["xs"]]
# data["ys"] = [y.pin_memory() for y in data["ys"]]
# pinned_data_queue.put(data) # 将多进程得到的数据存放在当前线程队列中
# if sema.acquire(blocking=False): # 没申请到信号量就继续执行循环, 而这个信号量看下面的代码发现是训练结束的时候才会释放,也就是说在训练没结束时,该线程一直在尝试从data_queue中拷贝数据到pinned_data_queue中供使用
# return
training_pin_thread.daemon = True
training_pin_thread.start()
...

with stdout_to_tqdm() as save_stdout:
for iteration in tqdm(range(start_iter+1, max_iteration+1), file=save_stdout, ncols=80):
training = pinned_training_queue.get(block=True)
# 从当前的线程中拿数据,拿不到就阻塞等待
...
training_pin_semaphore.release()
validation_pin_semaphore.release()
# 当训练完成后,释放信号量,这时候准备数据的线程获得信号量,退出循环进而回收资源

for training_task in training_tasks:
training_task.terminate()
for validation_task in validation_tasks:
validation_task.termination()
# 终止数据读取的进程,注意这里进程的终止需要晚于准备数据线程的时间,因为准备数据线程是从进程贡献的队列中拿数据的,如果进程过早的退出, 线程中的队列拿不到数据就会发生阻塞。

这部分代码并行处理大致可以分为3部分。以下面的图示来讲三个阶段大体上是并行的,通过两个队列来进行传递数据。另外与主线程的归属公国daemon和信号量实现:

multiprocessing

本文作者 : zhouzongwei
原文链接 : http://yoursite.com/2019/05/26/centernet-train/
版权声明 : 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明出处!

知识 & 情怀 | 赏或者不赏,我都在这,不声不响

微信扫一扫, 以资鼓励

微信扫一扫, 以资鼓励

支付宝扫一扫, 再接再厉

支付宝扫一扫, 再接再厉

留下足迹