PyTorch Lightning 专门为机器学习研究者开发的PyTorch轻量包装器

lightning_logo

PyTorch Lightning

专门为机器学习研究者开发的PyTorch轻量包装器(wrapper)。缩放您的模型。写更少的模板代码。

68747470733a2f2f62616467652e667572792e696f2f70792f7079746f7263682d6c696768746e696e672e737667
2
3
4
5
6
7
8


持续集成

系统/ PyTorch版本 1.3(最低标准) 1.4 1.5 (最新)
Linux py3.6 [CPU] 1 1 1
Linux py3.7 [GPU] 222
Linux py3.6 / py3.7 / py3.8 badge badge
OSX py3.6 / py3.7 / py3.8 badge badge
Windows py3.6 / py3.7 / py3.8 badge badge

使用PyPI进行轻松安装

pip install pytorch-lightning

文档

重构您的PyTorch代码+好处+完整演练

tutorial_cover

演示

这是一个没有验证或测试循环的最少代码的示例。

# 这只是一个简单的有一些结构的nn.Module

class LitClassifier(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

# 训练!
train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

model = LitClassifier()
trainer = pl.Trainer(gpus=8, precision=16)    
trainer.fit(model, train_loader) 

其他示例
GAN
BERT
DQN
MNIST on TPUs

它是什么?

阅读这个快速入门的网页

Lightning是一种组织PyTorch代码,以使科学代码(science code)与工程分离的方法。它不仅仅是框架,而是PyTorch样式指南。

在Lightning中,您可以将代码分为3个不同的类别:

  1. 研究代码(位于LightningModule中)。
  2. 工程代码(您删除并由trainer进行处理)。
  3. 不必要的研究代码(日志等,这些可以放在回调中)。

这是一个如何将研究代码重构为LightningModule的示例。

pt_to_pl

其余的代码由Trainer自动执行!
pt_trainer

严格测试(Testing Rigour)

每个新的PR都会自动测试Trainer的所有代码。

实际上,我们还使用vanilla PyTorch循环训练了一些模型,并与使用Trainer训练的同一模型进行比较,以确保我们获得完全相同的结果。在此处检查奇偶校验测试

总体而言,Lightning保证了自动化零件经过严格测试,改正,是现代的最佳实践。

它有多灵活?

如您所见,您只是在组织PyTorch代码-没有抽象。

对于Trainer所提取的内容,您可以覆盖任何您想做的事情,例如实现自己的分布式训练,16位精度甚至是自定义的反向传递梯度。

例如,在这里您可以自己进行向后传递梯度

class LitModel(LightningModule):
  def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
                     second_order_closure=None):
      optimizer.step()
      optimizer.zero_grad()

对于您可能需要的其他任何内容,我们都有一个广泛的回调系统,您可以使用它来添加trainer中未实现的任意功能。

Lightning是专门为了谁?

  • 专业研究人员
  • 博士学生
  • 企业生产团队

如果您只是要学习深度学习,我们建议您先学习PyTorch!一旦实现了模型,请回来并使用Lightning的所有高级功能:)

lightning能为我控制什么?

一切都是蓝色的!

这就是lightning将科学(红色)与工程(蓝色)分开的方式。

pl_overview

转换成lightning代码需要多少耗费多少时间?

如果您的代码不是一团糟,那么您应该可以在不到1小时的时间内将其组织成一个LightningModule。如果您的代码一团糟,那么您无论如何都要清理您的代码;)

请查看此分步指南
或观看此视频

开始一个新项目?

使用我们种子项目!该项目的目的是重现代码

为什么要使用lightning?

尽管您的研究/生产项目可能开始很简单,但是一旦添加了GPU和TPU训练,16位精度等功能,最终您将花费比研究更多的时间进行工程设计。Lightning会自动为您进行严格测试。

支持

  • 8位核心贡献者均由专业工程师,研究科学家,和来自顶级AI实验室的博士生组成。
  • 100多个社区贡献者。

Lightning也是PyTorch生态系统的一部分,该生态系统要求项目具有可靠的测试,文档和支持。


README目录


实际例子

这是您将真实的PyTorch项目组织到Lightning中的方法。

pt_to_pl

LightningModule定义了一个系统,例如seq-2-seq,GAN等。它还可以定义一个简单的分类器。

总结来说,您只需要做:

  1. 定义一个LightningModule
    class LitSystem(pl.LightningModule):

        def __init__(self):
            super().__init__()
            # 不是最好的模型...
            self.l1 = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))

        def training_step(self, batch, batch_idx):
            ...
  1. 给它配备一个Trainer
    from pytorch_lightning import Trainer
    
    model = LitSystem()
    
    # 最基本的trainer, 使用默认值
    trainer = Trainer()
    trainer.fit(model)
    

点击此处查看COLAB演示

哪些类型的研究有效?

任何! 请记住,这只是组织的PyTorch代码。训练步骤定义了训练循环中发生的核心复杂度。

可能像seq2seq一样复杂

# 在这里定义训练的过程
def training_step(self, batch, batch_idx):
    x, y = batch

    # 定义自己的前向传播和计算损失
    hidden_states = self.encoder(x)

    # 甚至与seq-2-seq + attn模型一样复杂
    # (这只是一个用于说明的小示例)
    start_token = '<SOS>'
    last_hidden = torch.zeros(...)
    loss = 0
    for step in range(max_seq_len):
        attn_context = self.attention_nn(hidden_states, start_token)
        pred = self.decoder(start_token, attn_context, last_hidden)
        last_hidden = pred
        pred = self.predict_nn(pred)
        loss += self.loss(last_hidden, y[step])

    # 小示例
    loss = loss / max_seq_len
    return {'loss': loss}

或像CNN图像分类一样

# 在这里定义验证代码
def validation_step(self, batch, batch_idx):
    x, y = batch

    # 或者像CNN分类一样
    out = self(x)
    loss = my_loss(out, y)
    return {'loss': loss}

而且,无需更改一行代码,您就可以在CPU上运行

trainer = Trainer(max_epochs=1)

或者在GPU上运行

# 8个GPU
trainer = Trainer(max_epochs=1, gpus=8)

# 256个GPU
trainer = Trainer(max_epochs=1, gpus=8, num_nodes=32)

或者在TPU上运行

# 分发给TPU进行训练
trainer = Trainer(tpu_cores=8)

# 单个TPU进行训练
trainer = Trainer(tpu_cores=[1])

当您完成训练后,测试准确度

trainer.test()

可视化

Lightning具有流行的日志记录/可视化框架的开箱即用的集成

tf_loss

Lightning使40多个DL / ML研究的部分自动化

  • GPU训练
  • 分布式GPU(集群)训练
  • TPU训练
  • 提前停止
  • 记录日志/可视化
  • 检查点
  • 实验管理
  • 完整清单在这里

例子

查看这份很棒的研究论文列表以及使用Lightning实现的例子。

教程

查看我们的入门指南来开始使用。或直接进入我们的教程


寻求帮助

欢迎来到Lightning社区!

如有任何疑问,请随时:
1. 阅读文档
2. 搜索问题
3. 使用pytorch-lightning标签在stackoverflow进行提问。
4. 加入我们的slack


常见问题

如何使用Lightning进行快速研究?
Here’s a walk-through

为什么创建Lightning?
Lightning有3个目标:

  1. 最大限度地提高灵活性,同时在整个研究项目中抽象出通用样板代码。
  2. 重现性。如果所有项目都使用LightningModule模板,则将更容易了解正在发生的事情以及发生事情的地方!这也意味着每种实现都遵循标准格式。
  3. 使PyTorch高级用户功能民主化。分布式训练?16位?知道您需要它们,但又不想花时间实现?很好…这些功能都内置在Lightning中。

Lightning跟Ignite和fast.ai相比如何?
这是一个彻底的比较

这是我必须要去学习的另一个库吗?
不!我们在任何地方都使用纯的Pytorch代码,并且不会添加不必要的抽象!

有计划要支持Python 2吗?
不。

有计划要支持virtualenv吗?
不。请使用anaconda或miniconda。

conda activate my_env
pip install pytorch-lightning

自定义安装

最前沿

如果您迫不及待想安装下一个发布版本,请使用以下命令安装最新版本:
* 使用GIT(在本地克隆具有完整历史记录的整个仓库)
bash
pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade

* 使用即时zip(没有git历史记录的仓库的最后状态)
bash
pip install https://github.com/PytorchLightning/pytorch-lightning/archive/master.zip --upgrade

安装任何发行版

您还可以从存储库中安装任何以前的0.X.Y发行版:

pip install https://github.com/PytorchLightning/pytorch-lightning/archive/0.X.Y.zip --upgrade

Lightning 团队

领导

核心维护者

资金

仅由几个兼职人员构建开源软件是非常困难!我们已获得资金,以确保我们可以聘请专职人员,参加会议并实现您要求的功能,以便更快地前行。

我们的目标是建立一个令人难以置信的研究平台和庞大的社区支持。许多开源项目已经通过诸如对大公司的支持和特殊帮助之类的活动来为运营筹集资金!

如果您是这些公司之一,请随时与will@pytorchlightning.ai进行联系!

Bibtex

如果您想引用框架,请随时引用它(但前提是您喜欢它 :)):

@article{falcon2019pytorch,
  title={PyTorch Lightning},
  author={Falcon, WA},
  journal={GitHub. Note: https://github. com/williamFalcon/pytorch-lightning Cited by},
  volume={3},
  year={2019}
}

原创文章,作者:pytorch,如若转载,请注明出处:https://pytorchchina.com/2020/06/15/pytorch-lightning-%e4%b8%93%e9%97%a8%e4%b8%ba%e6%9c%ba%e5%99%a8%e5%ad%a6%e4%b9%a0%e7%a0%94%e7%a9%b6%e8%80%85%e5%bc%80%e5%8f%91%e7%9a%84pytorch%e8%bd%bb%e9%87%8f%e5%8c%85%e8%a3%85%e5%99%a8/

QR code