博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Pytorch学习(一)基础语法篇
阅读量:6357 次
发布时间:2019-06-23

本文共 2499 字,大约阅读时间需要 8 分钟。

how to use pytorch

1.Tensor

we can create a tensor just like creating a matrix the default type of a tensor is float

import torch as ta = t.Tensor([[1,2],[3,4],[5,6]])a
tensor([[1., 2.],        [3., 4.],        [5., 6.]])

we can also change the datatype of a tensor

b = t.LongTensor([[1,2],[3,4],[5,6]])b
tensor([[1, 2],        [3, 4],        [5, 6]])

we can also create a tensor filled with zero or random values

c = t.zeros((3,2))d = t.randn((3,2))print(c)print(d)
tensor([[0., 0.],        [0., 0.],        [0., 0.]])tensor([[ 1.2880, -0.1640],        [-0.2654,  0.7187],        [-0.3156,  0.4489]])

we can change the value in a tensor we've created

a[0,1] = 100a
tensor([[  1., 100.],        [  3.,   4.],        [  5.,   6.]])

numpy and tensor can transfer from each other

import numpy as npe = np.array([[1,2],[3,4],[5,6]])torch_e = t.from_numpy(e)torch_e
tensor([[1, 2],        [3, 4],        [5, 6]])

2.Variable

Variable consists of data, grad, and grad_fn

data为Tensor中的数值

grad是反向传播梯度

grad_fn是得到该Variable的操作 例如加减乘除

from torch.autograd import Variablex = Variable(t.Tensor([1]),requires_grad = True)w = Variable(t.Tensor([2]),requires_grad = True)b = Variable(t.Tensor([3]),requires_grad = True)y = w*x+by.backward()print(x.grad)print(w.grad)print(b.grad)
tensor([2.])tensor([1.])tensor([1.])

we can also calculate the grad of a matrix

x = t.randn(3)x = Variable(x,requires_grad=True)y = x*2print(y)y.backward(t.FloatTensor([1,1,1]))print(x.grad)
tensor([-2.4801,  0.6291, -0.4250], grad_fn=
)tensor([2., 2., 2.])

3.dataset

you can define the function len and getitem to write your own dataset

import pandas as pdfrom torch.utils.data import Datasetclass myDataset(Dataset):    def __init__(self, csv_file, txt_file, root_dir, other_file):        self.csv_data = pd.read_csv(csv_file)        with open(txt_file, 'r') as f:            data_list = f.readlines()        self.txt_data = data_list        self.root_dir = root_dir            def __len__(self):        return len(self.csv_data)        def __getitem(self,idx):        data = (self.csv_data[idx],self.txt_data[idx])        return data

4.nn.Module

from torch import nnclass net_name(nn.Module):    def __init(self,other_arguments):        super(net_name, self).__init__()            def forward(self,x):        x = self.convl(x)        return x

5.Optim

1.一阶优化算法

常见的是梯度下降法\(\theta = \theta-\eta\times \frac{\partial J(\theta)}{\partial\theta}\)

2.二阶优化算法

Hessian法

转载于:https://www.cnblogs.com/ChetTlittilebread/p/10293150.html

你可能感兴趣的文章
用 Go 写一个轻量级的 ssh 批量操作工具
查看>>
网站设计之合理架构CSS 架构CSS
查看>>
OTP 22.0 RC3 发布,Erlang 编写的应用服务器
查看>>
D语言/DLang 2.085.1 发布,修复性迭代
查看>>
感觉JVM的默认异常处理不够好,既然不好那我们就自己来处理异常呗!那么如何自己处理异常呢?...
查看>>
Java 基础 之 算数运算符
查看>>
Windows下配置安装Git(二)
查看>>
一个最简单的基于Android SearchView的搜索框
查看>>
铁路开通WiFi“钱景”不明
查看>>
Facebook申请专利 或让好友及陌生人相互拼车
查看>>
电力“十三五”规划:地面光伏与分布式的分水岭
查看>>
美联社再告FBI:要求公开请黑客解锁iPhone花费
查看>>
三星电子出售希捷和夏普等四家公司股份
查看>>
任志远:当云计算遇上混合云
查看>>
思科联手发那科 用物联网技术打造无人工厂
查看>>
智慧城市首要在政府利用大数据的智慧
查看>>
2015年物联网行业:巨头展开专利大战
查看>>
以自动化测试撬动遗留系统
查看>>
网络安全初创公司存活之道
查看>>
《图解CSS3:核心技术与案例实战》——1.2节浏览器对CSS3的支持状况
查看>>