拦截pytorch算子,dump输入输出

拦截pytorch算子,dump输入输出

  • 一.代码
  • 二.输出

希望dump出pytorch每个算子的输入输出,但pytorch普通的hook机制只能拦截module.以下提供一种方法可以拦截torch.add,torch.Tensor.add这类算子.原理是通过模板替换,劫持torch和torch.Tensor中的算子.遍历next_functions调用register_hook拦截backward.

一.代码

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import time
import os
import threading
import base64
from jinja2 import Template

device="cuda"

class Attention(nn.Module):
    def __init__(self,max_seq_len,head_dim,flash):
        super().__init__()
        self.flash = flash #hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        self.dropout=0
        self.attn_dropout = nn.Dropout(self.dropout)
        self.head_dim=head_dim
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf")).to(device)
            mask = torch.triu(mask, diagonal=1).half().to(device)
            self.register_buffer("mask", mask)		
    def forward(
            self,xq: torch.Tensor,xk: torch.Tensor,xv: torch.Tensor):
        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv,
                                                                       attn_mask=None, 
                                                                       dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        else:
            _xk=xk.clone()
            t=_xk.transpose(2, 3)
            scores = torch.matmul(xq,t)
            scores = scores/math.sqrt(self.head_dim)
            a=self.mask[:, :, :seqlen, :seqlen]
            scores = torch.add(scores,a)
            scores = F.softmax(scores.float(), dim=-1)
            scores = scores.type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)  
        return output

lock=threading.Lock()
gindex=0
def save_tensor(name,args,index=0):
    if isinstance(args,torch.Tensor):
        print(name,index,args.shape)
        global gindex
        lock.acquire()
        torch.save(args,"{}_{}_{}_{}.pt".format(device,gindex,name,index))
        gindex+=1
        lock.release()
    if isinstance(args,tuple):
        for idx,x in enumerate(args):
            save_tensor(name,x,index+idx)

op_template=Template('''      
native1_{{new_name}}=getattr(torch.Tensor,'{{name}}')
def {{new_name}}(*args, **kwargs):
    save_tensor("{{name}}-input",args)    
    global native1_{{new_name}}             
    ret=native1_{{new_name}}(*args, **kwargs)
    save_tensor("{{name}}-output",ret)   
    return ret
setattr(torch.Tensor, '{{name}}', {{new_name}})
''')

for op in dir(torch.Tensor):
    if op in ["__iter__","shape","dim","unbind","normal_","data",
                "item","numel","save","has_names","data_ptr","untyped_storage",
                "storage_offset","size","stride","triu","half","is_floating_point",
                "to","ones","randint","ones_like"]:
        continue
    if getattr(torch.Tensor,op).__class__.__name__ not in ["method_descriptor"]:
        continue
    new_name=base64.b64encode(str(f"torch.Tensor.{op}").encode('utf-8')).decode("utf-8").replace("=","")
    exec(op_template.render(name=op,new_name=new_name))

op_template=Template('''      
native2_{{new_name}}=getattr(torch,'{{name}}')
def {{new_name}}(*args, **kwargs):
    save_tensor("{{name}}-input",args)    
    global native2_{{new_name}}             
    ret=native2_{{new_name}}(*args, **kwargs)
    save_tensor("{{name}}-output",ret) 
    return ret
setattr(torch, '{{name}}', {{new_name}})
''')

for op in dir(torch):
    if op in ["is_grad_enabled","__iter__","save","has_names","data_ptr",
              "untyped_storage","storage_offset","size","stride","triu",
              "is_floating_point","to","ones","randint","full","reshape","ones_like"]:
        continue
    if getattr(torch,op).__class__.__name__ not in ["builtin_function_or_method"]:
        continue
    new_name=base64.b64encode(str(f"torch.{op}").encode('utf-8')).decode("utf-8").replace("=","")
    exec(op_template.render(name=op,new_name=new_name))

def hook_backwards(loss, cached):
    if loss is None:
        return    
    def posthook(*args,**kwargs):
        save_tensor(loss.__class__.__name__,args)
    def prehook(*args,**kwargs):
        pass
    loss.register_prehook(prehook)
    loss.register_hook(posthook)
    cached.add(loss)
    for _, child in enumerate(loss.next_functions):
        if child[0] not in cached:
          hook_backwards(child[0],cached)

def main(flash,bs, n_local_heads, seqlen, head_dim):
    torch.random.manual_seed(1)

    q = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)
    k = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)
    v = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)

    q.data.normal_(0, 0.1)
    k.data.normal_(0, 0.1)
    v.data.normal_(0, 0.1)

    q=Variable(q, requires_grad=True).to(device)
    k=Variable(k, requires_grad=True).to(device)
    v=Variable(v, requires_grad=True).to(device)

    gt= torch.randint(0,head_dim,(bs*n_local_heads*seqlen,1)).reshape(-1).to(device)
    loss_func=nn.CrossEntropyLoss().to(device)

    model=Attention(seqlen,head_dim,flash).half().to(device)
    optim = torch.optim.SGD([q,k,v], lr=1.1)

    for i in range(1):
        output = model(q,k,v)
        loss=loss_func(output.reshape(-1,head_dim),gt)
        hook_backwards(loss.grad_fn, cached=set())
        loss.backward()  
        optim.step()
        print("{:.5f},{:.5f},{:.5f},{:.5f}".format(q.sum().item(),k.sum().item(),v.sum().item(),loss.item()))

bs, n_local_heads, seqlen, head_dim = 8, 8, 512, 64
main(False,bs, n_local_heads, seqlen, head_dim)

二.输出

reshape-input 0 torch.Size([32768, 1])
reshape-output 0 torch.Size([32768])
clone-input 0 torch.Size([8, 8, 512, 64])
clone-output 0 torch.Size([8, 8, 512, 64])
transpose-input 0 torch.Size([8, 8, 512, 64])
transpose-output 0 torch.Size([8, 8, 64, 512])
matmul-input 0 torch.Size([8, 8, 512, 64])
matmul-input 1 torch.Size([8, 8, 64, 512])
matmul-output 0 torch.Size([8, 8, 512, 512])
__truediv__-input 0 torch.Size([8, 8, 512, 512])
__truediv__-output 0 torch.Size([8, 8, 512, 512])
add-input 0 torch.Size([8, 8, 512, 512])
add-input 1 torch.Size([1, 1, 512, 512])
add-output 0 torch.Size([8, 8, 512, 512])
float-input 0 torch.Size([8, 8, 512, 512])
float-output 0 torch.Size([8, 8, 512, 512])
softmax-input 0 torch.Size([8, 8, 512, 512])
softmax-output 0 torch.Size([8, 8, 512, 512])
type_as-input 0 torch.Size([8, 8, 512, 512])
type_as-input 1 torch.Size([8, 8, 512, 64])
type_as-output 0 torch.Size([8, 8, 512, 512])
matmul-input 0 torch.Size([8, 8, 512, 512])
matmul-input 1 torch.Size([8, 8, 512, 64])
matmul-output 0 torch.Size([8, 8, 512, 64])
reshape-input 0 torch.Size([8, 8, 512, 64])
reshape-output 0 torch.Size([32768, 64])
NllLossBackward0 0 torch.Size([32768, 64])
NllLossBackward0 1 torch.Size([])
LogSoftmaxBackward0 0 torch.Size([32768, 64])
LogSoftmaxBackward0 1 torch.Size([32768, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([32768, 64])
UnsafeViewBackward0 0 torch.Size([64, 512, 64])
UnsafeViewBackward0 1 torch.Size([8, 8, 512, 64])
BmmBackward0 0 torch.Size([64, 512, 512])
BmmBackward0 1 torch.Size([64, 512, 64])
BmmBackward0 1 torch.Size([64, 512, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([64, 512, 64])
ExpandBackward0 0 torch.Size([8, 8, 512, 64])
ExpandBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 512])
ViewBackward0 1 torch.Size([64, 512, 512])
ExpandBackward0 0 torch.Size([8, 8, 512, 512])
ExpandBackward0 1 torch.Size([8, 8, 512, 512])
ToCopyBackward0 0 torch.Size([8, 8, 512, 512])
ToCopyBackward0 1 torch.Size([8, 8, 512, 512])
SoftmaxBackward0 0 torch.Size([8, 8, 512, 512])
SoftmaxBackward0 1 torch.Size([8, 8, 512, 512])
ToCopyBackward0 0 torch.Size([8, 8, 512, 512])
ToCopyBackward0 1 torch.Size([8, 8, 512, 512])
AddBackward0 0 torch.Size([8, 8, 512, 512])
AddBackward0 1 torch.Size([8, 8, 512, 512])
DivBackward0 0 torch.Size([8, 8, 512, 512])
DivBackward0 1 torch.Size([8, 8, 512, 512])
UnsafeViewBackward0 0 torch.Size([64, 512, 512])
UnsafeViewBackward0 1 torch.Size([8, 8, 512, 512])
BmmBackward0 0 torch.Size([64, 512, 64])
BmmBackward0 1 torch.Size([64, 64, 512])
BmmBackward0 1 torch.Size([64, 512, 512])
ReshapeAliasBackward0 0 torch.Size([8, 8, 64, 512])
ReshapeAliasBackward0 1 torch.Size([64, 64, 512])
ExpandBackward0 0 torch.Size([8, 8, 64, 512])
ExpandBackward0 1 torch.Size([8, 8, 64, 512])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([64, 512, 64])
ExpandBackward0 0 torch.Size([8, 8, 512, 64])
ExpandBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
TransposeBackward0 0 torch.Size([8, 8, 512, 64])
TransposeBackward0 1 torch.Size([8, 8, 64, 512])
CloneBackward0 0 torch.Size([8, 8, 512, 64])
CloneBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
45.56250,-12.76562,121.68750,4.16016

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/578000.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LMDeploy量化部署LLMVLM实践-笔记五

本次课程由西北工业大学博士生、书生浦源挑战赛冠军队伍队长、第一期书生浦语大模型实战营优秀学员【安泓郡】讲解【OpenCompass 大模型评测实战】课程 课程视频:https://www.bilibili.com/video/BV1tr421x75B/ 课程文档:https://github.com/InternLM/…

Redis高级篇详细讲解

0.今日菜单 Redis持久化【理解】 Redis主从 Redis哨兵 Redis分片集群【运维】 单点Redis的问题 数据丢失问题:Redis是内存存储,服务重启可能会丢失数据 并发能力问题:单节点Redis并发能力虽然不错,但也无法满足如618这样的高…

有什么因素会影响IP稳定性?

IP稳定性指的是IP地址在一段时间内保持不变的能力,对于网络连接的安全性和可靠性至关重要。以下是一些可能影响IP稳定性的主要因素: 网络服务提供商(ISP)的政策:不同的ISP对于IP地址的管理和使用有不同的政策。一些IS…

视频滚动字幕一键批量轻松添加,解锁高效字幕编辑,提升视频质量与观众体验

视频已成为我们获取信息、娱乐休闲的重要渠道。一部成功的视频作品,除了画面精美、音质清晰外,字幕的添加也是至关重要的一环。字幕不仅能增强视频的观感,还能提升信息的传达效率,让观众在享受视觉盛宴的同时,更加深入…

16.Blender 基础渲染工作流程及安装ACES

安装插件和菜单栏设置 在菜单栏的编辑里打开偏好设置 里面的插件界面 搜索node 给第三个打勾 点击安装,导入cat插件 安装完后,一定要打勾,选择上cat插件 这样N窗口才会显示MMD选项 导入场景 点击打开 把输出模式的帧率改为30fps 按…

【redis】Redis数据类型(一)——String类型(包含redis通用命令)

目录 Redis通用命令String类型常用的操作命令一些特殊命令详解setnx示例使用 setrange示例 mset示例 msetnx示例 append示例 getset示例 incr示例使用1.计数器2.限速器 bitcount示例使用:使用 bitmap 实现用户上线次数统计性能 String类型String类型简介String类型的…

LeetCode57. 插入区间

LeetCode57.插入区间 题目思路: 代码 /* 前置知识&#xff1a; vector<vector<int>> a,b; 二维vector数组是可以将二维中的一维vector数组给push_back的&#xff0c; 不是只有单个元素才可以&#xff0c;整个一维的vector数组也可以 b[0] {1,2,3},b[1] {4,5,6}…

AI大模型探索之路-训练篇2:大语言模型预训练基础认知

文章目录 前言一、预训练流程分析二、预训练两大挑战三、预训练网络通信四、预训练数据并行五、预训练模型并行六、预训练3D并行七、预训练代码示例总结 前言 在人工智能的宏伟蓝图中&#xff0c;大语言模型&#xff08;LLM&#xff09;的预训练是构筑智慧之塔的基石。预训练过…

优先级队列(堆)和四个比较(==,equals,Comparable,Comparator)

本节具体阐述堆的概念和自己如何实现堆(底层) 掌握PriorityQueue的使用 优先级概念 前面介绍过队列&#xff0c;队列是一种先进先出(FIFO)的数据结构&#xff0c;但有些情况下&#xff0c;操作的数据可能带有优 先级&#xff0c;一般出队列时&#xff0c;可能需要优先级高的…

Docker共享Nginx配置文件

先去一个容器中&#xff0c;找到Nginx.conf配置文件的目录 去创建一个容器&#xff0c;将容器中存放nginx.conf的目录挂载到宿主机存放nginx.conf目录上 去宿主机中找到nginx/html/index.html目录位置 进入宿主机的index.html中修改页面内容 curl 192.168.91.106访问一下 进入…

Tomcat安装和配置以及多实例部署(附脚本)

TOMCAT详细部署 Tomcat服务器简介核心组件Tomcat 各组件及关系工作流程 Tomcat server.xml 配置详解serverserviceConnectorEngineHostContextValve 阀门 Tomcat部署与安装部署脚本主要目录说明 Tomcat多实例部署扩展和优化 Tomcat 的 catalina.sh 文件以调整 JVM 参数 Tomcat服…

lt Redis变慢的原因及排查解决方法

前言 Redis 作为优秀的内存数据库&#xff0c;其拥有非常高的性能&#xff0c;单个实例的 OPS 能够达到 10W 左右(5-10W)。但也正因此如此&#xff0c;当我们在使用 Redis 时&#xff0c;如果发现操作延迟变大的情况&#xff0c;就会与我们的预期不符。 你也许或多或少地&…

LLaMA-Factory参数的解答(命令,单卡,预训练)

前面这个写过&#xff0c;但觉得写的不是很好&#xff0c;这次是参考命令运行脚本&#xff0c;讲解各个参数含义。后续尽可能会更新&#xff0c;可以关注一下专栏&#xff01;&#xff01; *这是个人写的参数解读&#xff0c;我并非该领域的人如果那个大佬看到有参数解读不对或…

一文扫盲:数据中台,可不是搞几个报表就叫中台。

Hi&#xff0c;我是贝格前端工场&#xff0c;相比大家会经常听说数据中台这个词汇&#xff0c;很多老铁会想当然的人为数据中台就是各种报表&#xff0c;本文给大家纠正和普及一下。 一、什么是数据中台 数据中台是指一个企业内部的数据管理和分发平台&#xff0c;它通过集中…

​解析什么是物联网接入网关?-天拓四方

随着物联网技术的飞速发展&#xff0c;越来越多的设备、传感器和系统被连接到互联网&#xff0c;形成了一个庞大的、相互连接的智能网络。在这个网络中&#xff0c;物联网接入网关扮演着至关重要的角色&#xff0c;它不仅是连接物联网设备和云平台的桥梁&#xff0c;还是实现设…

excel一列同乘同一个数

excel一列同乘同一个数 第一种方法&#xff08;excel本身功能&#xff09; 在空白区域输入要乘以的数&#xff0c;比如0.5 右键选择复制 选中需要乘以的单元格&#xff0c;选择性粘贴 点击乘&#xff0c;选择确定 删除0.5后也不会改变值 第二种方法&#xff08;方方格子…

STM32自己从零开始实操01:原理图

在听完老师关于 STM32 物联网项目的所有硬件课程之后&#xff0c;就是感觉自己云里雾里&#xff0c;明明课程都认真听完了&#xff0c;笔记也认真记录&#xff0c;但是就是感觉学到的知识还不是自己。 遂决定站在老师的肩膀上自己开始设计项目&#xff0c;将知识变成自己的&am…

Lagent AgentLego 智能体应用搭建-笔记六

本次课程由Lagent&AgentLego 核心贡献者樊奇老师讲解【Lagent & AgentLego 智能体应用搭建】课程 课程视频&#xff1a;https://www.bilibili.com/video/BV1Xt4217728/ 课程文档&#xff1a;https://github.com/InternLM/Tutorial/tree/camp2/agent 大语言模型的局限…

MybatisPlus开发业务接口

&#x1f49f;&#x1f49f;前言 ​ 友友们大家好&#xff0c;我是你们的小王同学&#x1f617;&#x1f617; 今天给大家打来的是 MybatisPlus开发业务接口 希望能给大家带来有用的知识 觉得小王写的不错的话麻烦动动小手 点赞&#x1f44d; 收藏⭐ 评论&#x1f4c4; 小王的主…

(mac)Promethues监控之mysqld_exporter(MySQL监控)

搭建Mysqld_exporterPrometheusGrafana监控系统 普罗米修斯是后端数据监控平台&#xff0c;通过Mysqld_exporter收集mysql数据&#xff0c;Grafana将数据用图形的方式展示出来 前提&#xff1a;已安装grafana和promethues 1.下载安装Mysql &#xff08;1&#xff09;启动MySQL…
最新文章