迹忆客 专注技术分享

当前位置:主页 > 学无止境 >

如何使用 PyTorch torch.max()

作者:迹忆客 最近更新:2022/09/21 浏览次数:

在本文中,我们将了解如何使用 PyTorch torch.max() 函数。

正如大家所料,这是一个非常简单的功能,但有趣的是,它的功能比想象的要多。

让我们通过一些简单的例子来看看如何使用这个函数。

注意 :在撰写本文时,使用的 PyTorch 版本是 PyTorch 1.5.0


PyTorch torch.max() - 基本语法

要使用 PyTorch torch.max(),首先导入 torch

import torch

现在,此函数返回 Tensor 中元素的最大值。

PyTorch torch.max() 的默认行为

默认行为是返回单个元素和一个索引,对应于全局最大元素。

max_element = torch.max(input_tensor)

下面是一个例子:

p = torch.randn([2, 3])
print(p)
max_element = torch.max(p)
print(max_element)

输出

tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688, -1.0376,  1.4443]])
tensor(2.7976)

事实上,这给了我们 Tensor 中的全局最大元素!

沿维度使用 torch.max()

但是,大家可能希望获得沿特定维度的最大值,作为张量,而不是单个元素。

要指定维度(轴 - 在 numpy 中),还有另一个可选的关键字参数,称为 dim

这代表了我们取最大值的方向。

这将返回一个元组 max_elementsmax_indices

  • max_elements -> Tensor的所有最大元素。
  • max_indices -> 对应于最大元素的索引。
max_elements, max_indices = torch.max(input_tensor, dim)

这将返回一个 Tensor,它具有沿维度 dim 的最大元素。

现在让我们看一些例子。

p = torch.randn([2, 3])
print(p)

# 沿 dim = 0 (axis = 0) 获取最大值
max_elements, max_idxs = torch.max(p, dim=0)
print(max_elements)
print(max_idxs)

输出如下所示

tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688, -1.0376,  1.4443]])
tensor([0.0688, 2.7976, 1.4443])
tensor([1, 0, 1])

如你所见,我们找到了沿维度 0 的最大值(沿列的最大值)。

此外,我们得到与元素对应的索引。 例如,0.0688 在第 0 列的索引为 1

同样,如果要沿行查找最大值,请使用 dim=1

# 沿 dim = 1(axis = 1)获取最大值
max_elements, max_idxs = torch.max(p, dim=1)
print(max_elements)
print(max_idxs)

输出如下所示

tensor([2.7976, 1.4443])
tensor([1, 2])

实际上,我们得到了沿行的最大元素,以及相应的索引(沿行)。


使用 torch.max() 进行比较

我们还可以使用 torch.max() 来获取两个 Tensor 之间的最大值。

output_tensor = torch.max(a, b)

在这里,a 和 b 必须具有相同的维度,或者必须是“可广播的” Tensor。

这是一个比较两个具有相同维度的Tensor的简单示例。

p = torch.randn([2, 3])
q = torch.randn([2, 3])

print("p =", p)
print("q =",q)

# 比较 p 和 q 的元素并得到最大值
max_elements = torch.max(p, q)

print(max_elements)

结果输出如下所示

p = tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688, -1.0376,  1.4443]])
q = tensor([[-0.0678,  0.2042,  0.8254],
        [-0.1530,  0.0581, -0.3694]])
tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688,  0.0581,  1.4443]])

实际上,我们得到的输出 Tensor 在 p 和 q 之间具有最大元素。


总结

在本文中,我们学习了如何使用 torch.max() 函数来找出 Tensor 的最大元素。

我们还使用这个函数来比较两个 Tensor 并获得其中的最大值。

转载请发邮件至 1244347461@qq.com 进行申请,经作者同意之后,转载请以链接形式注明出处

本文地址:

相关文章

Python 中的第一类函数

发布时间:2023/04/25 浏览次数:113 分类:Python

第一类函数是被语言视为对象或变量的函数。 我们可以将它们分配给变量或将它们作为对象传递给其他函数。Python 支持第一类函数的功能。

Python 函数参数类型

发布时间:2023/04/25 浏览次数:140 分类:Python

在这篇 Python 文章中,我们将学习 Python 中使用的函数参数类型。 我们还将学习如何编写不带参数的 Python 函数。

Python 生成器中的 send 函数

发布时间:2023/04/25 浏览次数:111 分类:Python

本教程将介绍如何在 Python 中使用生成器的 send() 函数。我们可以创建一个像迭代器一样运行的函数,并且可以通过 Python 生成器函数在 for 循环中使用。

Python Functools 偏函数

发布时间:2023/04/25 浏览次数:80 分类:Python

本文介绍了我们如何使用分部函数,该函数随 functools 库一起提供,并附有示例。 这显示了调用时如何传递属性和部分函数。

Python main() 函数中的参数

发布时间:2023/04/25 浏览次数:157 分类:Python

在本教程结束时,我们应该了解Python 中在 main() 中使用参数是否是一种好的做法。

Python 中的内置 identity 函数

发布时间:2023/04/25 浏览次数:88 分类:Python

identity 函数只是一个返回其参数的函数。 当我们定义一个恒等函数并赋值时,它会返回该值。在本教程结束时,我们将了解 Python 是否具有内置的 identity 函数。

在 Python 中拟合阶跃函数

发布时间:2023/04/25 浏览次数:177 分类:Python

阶跃函数是带有看起来像一系列步骤的图形的方法。 它们由一系列中间有间隔的水平线段组成,也可以称为阶梯函数。本文给出了阶跃函数的简单演示。

在 Python 中创建双向链表

发布时间:2023/04/25 浏览次数:54 分类:Python

双向链表是指由称为节点的顺序链接的记录集组成的链接数据结构。 每个节点包含一个前一个指针、一个下一个指针和一个数据字段。

将 Python 类对象序列化为 JSON

发布时间:2023/04/25 浏览次数:152 分类:Python

本教程介绍序列化过程。 它还说明了我们如何使用 toJSON() 方法使 JSON 类可序列化,并包装 JSON 以转储到其类中。

扫一扫阅读全部技术教程

社交账号
  • https://www.github.com/onmpw
  • qq:1244347461

最新推荐

教程更新

热门标签

扫码一下
查看教程更方便