PyTorch Lighting框架学习

最近在搞一个挑战赛,但是感觉自己写的代码好难看,很混乱,而且大部分还是复用别人的东西,所以打算系统学习一下这个简洁的框架。

parser

先把这个学了,然后去看那个参数转实例的
参考链接

dataset

看一个知乎的帖子,讲到的一种项目组织方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
root-
|-data
|-__init__.py
|-data_interface.py
|-xxxdataset1.py
|-xxxdataset2.py
|-...
|-model
|-__init__.py
|-model_interface.py
|-xxxmodel1.py
|-xxxmodel2.py
|-...
|-main.py

我在root下还加了个config.py,上面这种方法,在data_interface和model_interface中,分别写数据集和模型的wrapper,然后在main中用trainer去训练

在mixmatch半监督方法中,会用到标记和未标记两部分数据集,框架支持定义dataloader时返回一个列表或数组:

1
2
3
4
5
6
7
8
9
10
11
12
from pytorch_lightning.trainer.supporters import CombinedLoader

def train_dataloader(self):
loader_a = DataLoader()
loader_b = DataLoader()
loaders = {"a": loader_a, "b": loader_b}
combined_loader = CombinedLoader(loaders, mode="max_size_cycle")
return combined_loader

def training_step(self, batch, batch_idx):
batch_a = batch["a"]
batch_b = batch["b"]

在我这个任务中会涉及到好几个数据集,lightning框架的LightningDataModule提供了一个setup的hook接口,所有的数据集定义,条件选择都在这里做

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if self.dataset == 'TAU':
if stage == 'fit':
self.trainset = DeltaDataset(self.train_csv, self.fea_path)
self.valset = DeltaDataset(self.val_csv, self.fea_path)
if stage == 'test':
self.testset = DeltaDataset(self.test_csv, self.fea_path)
if self.dataset == 'CAS':
if stage == 'fit':
self.trainset = CASDeltaDataset(self.train_csv, self.fea_path)
self.valset = valdataset(self.val_csv, self.fea_path)
self.unlabelset = unlabeled_CASDeltaDataset(self.unlabel_csv, self.fea_path)
self.iteration = self.unlabelset.__len__()//self.batch_size
if stage == 'test':
self.testset = valdataset(self.test_csv, self.fea_path)

model

model_interface需要做的不仅是包装模型,训练和验证的过程也包含在内,作为LightningModule的子类MInterface的类函数。
该类中的几个hook函数如下(不全)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class MInterface(pl.LightningModule):
def __init__(self, model_name:str, lr:float, mode:str, **kargs):
super().__init__()
# 这个可以保存每次训练的超参数到yaml文件,很好用
self.save_hyperparameters()
self.init_model()
self.configure_loss()
def init_model(self):
# 初始化模型,看参数然后给self.model一个模型实例
def forward(self, x):
return self.model(x)['logits']
def configure_loss(self):
# 配置损失函数,因为训练用到了很多个损失,就搞了个字典存
def configure_optimizers(self):
# 看上面的链接,这里可以返回一个优化器,或者优化器和scheduler的列表
def training_step(self, batch, batch_idx):
# mixmatch和pretrain的过程的混合,用ifelse区分
# 其实我觉得分别再写两个类函数区分两种训练过程
# 然后在这里调用两个类函数会更好
def validation_step(self, batch, batch_idx):
# 验证过程

对于训练过程中需要记录的loss等信息,可以使用自带的log,会调用定义Trainer时传入的logger,有两种方法,单值log和字典log

1
2
self.log_dict(log_info, on_step=True, prog_bar=True, on_epoch=True, logger=True)
self.log("pretrain_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

当参数 prog_bar=True时,会把这些信息打印到进度条后面,但是有一说一挺丑的。

main

定义好两个interface时候,在main函数只需要先根据参数定义两个interface的实例,然后:

1
2
3
4
5
data_loder = ...
model = MInterface(model_name=args.model, **vars(args))
logger = CSVLogger('./', 'logs')
trainer = Trainer(accelerator='cuda', devices=[1], fast_dev_run=False, max_epochs=max_epochs, logger=logger)
trainer.fit(model, data_loader)

有几个地方要注意

  • Trainer的参数中,accelerator就是训练用的device,后面的devices有如下几种情况
    • 使用k个设备训练devices=k(很重要,是大坑)
    • 使用第k个设备训练devices=[k],也可以在列表中定义多个设备
  • Trainer中的fast_dev_run很好用,设为True后会把训练验证测试先按照batch=1跑一轮,验证程序准确性,防止跑完训练验证代码出问题这种很傻逼的情况。