PyTorch Lighting框架学习
最近在搞一个挑战赛,但是感觉自己写的代码好难看,很混乱,而且大部分还是复用别人的东西,所以打算系统学习一下这个简洁的框架。
parser
先把这个学了,然后去看那个参数转实例的
参考链接
dataset
看一个知乎的帖子,讲到的一种项目组织方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| root- |-data |-__init__.py |-data_interface.py |-xxxdataset1.py |-xxxdataset2.py |-... |-model |-__init__.py |-model_interface.py |-xxxmodel1.py |-xxxmodel2.py |-... |-main.py
|
我在root下还加了个config.py,上面这种方法,在data_interface和model_interface中,分别写数据集和模型的wrapper,然后在main中用trainer去训练
在mixmatch半监督方法中,会用到标记和未标记两部分数据集,框架支持定义dataloader时返回一个列表或数组:
1 2 3 4 5 6 7 8 9 10 11 12
| from pytorch_lightning.trainer.supporters import CombinedLoader
def train_dataloader(self): loader_a = DataLoader() loader_b = DataLoader() loaders = {"a": loader_a, "b": loader_b} combined_loader = CombinedLoader(loaders, mode="max_size_cycle") return combined_loader
def training_step(self, batch, batch_idx): batch_a = batch["a"] batch_b = batch["b"]
|
在我这个任务中会涉及到好几个数据集,lightning框架的LightningDataModule
提供了一个setup的hook接口,所有的数据集定义,条件选择都在这里做
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| def setup(self, stage=None): if self.dataset == 'TAU': if stage == 'fit': self.trainset = DeltaDataset(self.train_csv, self.fea_path) self.valset = DeltaDataset(self.val_csv, self.fea_path) if stage == 'test': self.testset = DeltaDataset(self.test_csv, self.fea_path) if self.dataset == 'CAS': if stage == 'fit': self.trainset = CASDeltaDataset(self.train_csv, self.fea_path) self.valset = valdataset(self.val_csv, self.fea_path) self.unlabelset = unlabeled_CASDeltaDataset(self.unlabel_csv, self.fea_path) self.iteration = self.unlabelset.__len__()//self.batch_size if stage == 'test': self.testset = valdataset(self.test_csv, self.fea_path)
|
model
model_interface需要做的不仅是包装模型,训练和验证的过程也包含在内,作为LightningModule
的子类MInterface
的类函数。
该类中的几个hook函数如下(不全)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| class MInterface(pl.LightningModule): def __init__(self, model_name:str, lr:float, mode:str, **kargs): super().__init__() self.save_hyperparameters() self.init_model() self.configure_loss() def init_model(self): def forward(self, x): return self.model(x)['logits'] def configure_loss(self): def configure_optimizers(self): def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
|
对于训练过程中需要记录的loss等信息,可以使用自带的log,会调用定义Trainer时传入的logger,有两种方法,单值log和字典log
1 2
| self.log_dict(log_info, on_step=True, prog_bar=True, on_epoch=True, logger=True) self.log("pretrain_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
当参数 prog_bar=True
时,会把这些信息打印到进度条后面,但是有一说一挺丑的。
main
定义好两个interface时候,在main函数只需要先根据参数定义两个interface的实例,然后:
1 2 3 4 5
| data_loder = ... model = MInterface(model_name=args.model, **vars(args)) logger = CSVLogger('./', 'logs') trainer = Trainer(accelerator='cuda', devices=[1], fast_dev_run=False, max_epochs=max_epochs, logger=logger) trainer.fit(model, data_loader)
|
有几个地方要注意
- Trainer的参数中,accelerator就是训练用的device,后面的devices有如下几种情况
- 使用k个设备训练
devices=k
(很重要,是大坑)
- 使用第k个设备训练
devices=[k]
,也可以在列表中定义多个设备
- Trainer中的
fast_dev_run
很好用,设为True
后会把训练验证测试先按照batch=1
跑一轮,验证程序准确性,防止跑完训练验证代码出问题这种很傻逼的情况。