Tensors in Python: A Simple Neural Network Example


Tensors in Python

1. Tensor Data Types

  • torch.FloatTensor
  • torch.IntTensor
  • torch.rand
  • torch.randn
  • torch.arange (torch.range is deprecated)
  • torch.zeros
import torch

a = torch.FloatTensor(2, 3)          # uninitialized 2x3 float tensor (contents are arbitrary)
b = torch.FloatTensor([2, 3, 4, 5])  # 1-D float tensor built from a list

print(a)
print(b)

import torch
a = torch.IntTensor(2, 3)          # uninitialized 2x3 integer tensor
b = torch.IntTensor([2, 3, 4, 5])  # 1-D integer tensor built from a list

print(a)
print(b)

import torch
a = torch.rand(2, 3)  # uniform distribution on [0, 1)
print(a)

import torch
a = torch.randn(2, 3)  # standard normal distribution

print(a)

import torch
a = torch.arange(1, 21, 1)  # torch.range is deprecated; arange excludes the end point, so 21 gives 1..20
print(a)

import torch
a = torch.zeros(2, 3)  # 2x3 tensor filled with zeros
print(a)
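
As an aside, newer PyTorch code usually builds tensors with torch.tensor plus an explicit dtype argument instead of the type-specific constructors above; a minimal sketch of the equivalents, using only standard torch APIs:

import torch

# torch.tensor copies data from a Python list and lets you pick the dtype explicitly
a = torch.tensor([2, 3, 4, 5], dtype=torch.float32)  # like torch.FloatTensor([2, 3, 4, 5])
b = torch.tensor([2, 3, 4, 5], dtype=torch.int32)    # like torch.IntTensor([2, 3, 4, 5])

# the factory functions accept a dtype as well
c = torch.zeros(2, 3, dtype=torch.float64)
d = torch.arange(1, 21)  # step defaults to 1

print(a, b, c, d, sep="\n")

When no dtype is given, torch.tensor infers one from the data.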

2. Tensor Operations

  • torch.abs
  • torch.add
  • torch.clamp
  • torch.div
  • torch.mul
  • torch.pow
  • torch.mm
  • torch.mv
import torch
a = torch.randn(2, 3)
print(a)

b = torch.abs(a)  # element-wise absolute value
print(b)
import torch
a = torch.randn(2, 3)
print(a)

b = torch.randn(2, 3)
print(b)

c = torch.add(a, b)  # element-wise sum of two tensors
print(c)

d = torch.randn(2, 3)
print(d)

e = torch.add(d, 10)  # add a scalar to every element
print(e)
# clamping
import torch
a = torch.randn(2, 3)
print(a)

b = torch.clamp(a, -0.1, 0.1)  # clip every element into [-0.1, 0.1]
print(b)
# division
import torch
a = torch.randn(2, 3)
print(a)

b = torch.randn(2, 3)
print(b)

c = torch.div(a, b)  # element-wise quotient of two tensors
print(c)

d = torch.randn(2, 3)
print(d)

e = torch.div(d, 10)  # divide every element by a scalar
print(e)

# multiplication
import torch
a = torch.randn(2, 3)
print(a)

b = torch.randn(2, 3)
print(b)

c = torch.mul(a, b)  # element-wise product of two tensors
print(c)

d = torch.randn(2, 3)
print(d)

e = torch.mul(d, 10)  # multiply every element by a scalar
print(e)
# exponentiation
import torch
a = torch.randn(2, 3)
print(a)

b = torch.pow(a, 2)  # square every element
print(b)
# matrix multiplication
import torch
a = torch.randn(2, 3)
print(a)

b = torch.randn(3, 2)
print(b)

c = torch.mm(a, b)  # (2x3) times (3x2) gives 2x2; inner dimensions must match
print(c)
# matrix-vector multiplication
import torch
a = torch.randn(2, 3)
print(a)

b = torch.randn(3)
print(b)

c = torch.mv(a, b)  # (2x3) matrix times a length-3 vector gives a length-2 vector
print(c)
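
Most of these functions also have operator or method forms; a short sketch of the correspondences, all standard torch behavior:

import torch

a = torch.randn(2, 3)
b = torch.randn(2, 3)
m = torch.randn(3, 2)

print(torch.equal(torch.add(a, b), a + b))   # add <-> +
print(torch.equal(torch.div(a, b), a / b))   # div <-> /
print(torch.equal(torch.mul(a, b), a * b))   # mul <-> *
print(torch.equal(torch.pow(a, 2), a ** 2))  # pow <-> **
print(torch.equal(torch.mm(a, m), a @ m))    # mm  <-> @
print(torch.equal(torch.abs(a), a.abs()))    # function form <-> method form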

3. Building a Simple Neural Network

The network below has 1000 input features, one hidden layer of 100 units with a ReLU activation, and 10 outputs. It is trained on random data for 20 epochs, with both the forward pass and the gradients written out by hand using the tensor operations from section 2.

import torch

batch_n = 100       # samples per batch
hidden_layer = 100  # hidden units
input_data = 1000   # input features
output_data = 10    # output features

# random data and random weight initialization
x = torch.randn(batch_n, input_data)
y = torch.randn(batch_n, output_data)

w1 = torch.randn(input_data, hidden_layer)
w2 = torch.randn(hidden_layer, output_data)

epoch_n = 20
learning_rate = 1e-6

for epoch in range(epoch_n):
    # forward pass
    h1 = x.mm(w1)
    h1 = h1.clamp(min=0)  # ReLU: set values below 0 to 0
    y_pred = h1.mm(w2)    # network output

    loss = (y_pred - y).pow(2).sum()  # sum of squared errors
    print("Epoch:{}, Loss:{:.4f}".format(epoch, loss))

    # backward pass: gradients computed by hand
    grad_y_pred = 2 * (y_pred - y)
    grad_w2 = h1.t().mm(grad_y_pred)

    grad_h = grad_y_pred.clone()
    grad_h = grad_h.mm(w2.t())
    grad_h[h1 <= 0] = 0  # ReLU gradient: zero wherever the activation was zero
    grad_w1 = x.t().mm(grad_h)

    # gradient-descent weight update
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2
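
For comparison, autograd can compute the same gradients automatically; here is a minimal sketch of the same training loop written with standard torch autograd features (requires_grad, backward, torch.no_grad):

import torch

batch_n, input_data, hidden_layer, output_data = 100, 1000, 100, 10

x = torch.randn(batch_n, input_data)
y = torch.randn(batch_n, output_data)

# requires_grad=True makes autograd track operations on the weights
w1 = torch.randn(input_data, hidden_layer, requires_grad=True)
w2 = torch.randn(hidden_layer, output_data, requires_grad=True)

learning_rate = 1e-6
for epoch in range(20):
    y_pred = x.mm(w1).clamp(min=0).mm(w2)  # same forward pass as above
    loss = (y_pred - y).pow(2).sum()
    print("Epoch:{}, Loss:{:.4f}".format(epoch, loss.item()))

    loss.backward()  # autograd fills w1.grad and w2.grad

    with torch.no_grad():  # update weights without recording the update in the graph
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
        w1.grad.zero_()  # gradients accumulate, so reset them every epoch
        w2.grad.zero_()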
