Mooncake Joins PyTorch Ecosystem
12 Feb 2026, 10:38 pm
We are thrilled to announce that Mooncake has officially joined the PyTorch Ecosystem! By integrating Mooncake’s high-performance KVCache transfer and storage capabilities with PyTorch-native inference engines such as SGLang, vLLM, and TensorRT-LLM, we are unlocking new levels of throughput and scalability for large language model deployments.
To view the PyTorch Ecosystem, see the PyTorch Landscape. Learn more about how projects can join the PyTorch Ecosystem.
About Mooncake
Mooncake is designed to solve the “memory wall” in LLM serving. As context lengths grow and models scale, the static binding of Key-Value (KV) cache to specific GPU workers becomes a primary bottleneck.
Mooncake empowers inference engines to break this binding, unlocking five critical capabilities:
- (Encoder) Prefill-Decode Disaggregation: Mooncake’s high-performance Transfer Engine lets engines separate heavy computation (prefill/encoder) from latency-sensitive generation (decoding) into distinct clusters.
- Global KVCache Reuse: By acting as a distributed shared memory for KV blocks, Mooncake Store enables valid cache to be reused globally across different requests and engine instances.
- Elastic Expert Parallelism: By decoupling experts from specific workers, Mooncake-EP enables elastic and resilient serving where experts of Mixture-of-Experts (MoE) models can be dynamically routed or recovered, ensuring high availability even during partial node failures.
- PyTorch Distributed Backend: Mooncake Backend serves as a fault-tolerant PyTorch distributed backend. It provides robust collective communication primitives capable of continuing operation seamlessly in the presence of rank failures.
- Weight Updating: Mooncake Store enables rapid weight updates for RL and checkpoint scenarios by storing weights internally. It offers tensor-native and zero-copy APIs.
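To make the Global KVCache Reuse idea concrete, the toy model below keys fixed-size KV blocks by a hash of the entire token prefix, so a later request (possibly on a different engine instance) that shares a prefix can skip recomputing those blocks. This is a sketch under stated assumptions, not Mooncake Store's API: the block size, the chained SHA-256 keys, and the names `ToyKVStore`, `block_keys`, and `prefill` are all hypothetical, and strings stand in for GPU KV tensors.

```python
from __future__ import annotations

import hashlib
from dataclasses import dataclass, field

BLOCK_SIZE = 4  # tokens per KV block (hypothetical; real engines use larger pages)


def block_keys(token_ids: list[int], block_size: int = BLOCK_SIZE) -> list[str]:
    """One key per *full* block. Each key hashes the whole prefix up to and
    including that block, so equal keys imply identical prefixes."""
    keys: list[str] = []
    running = hashlib.sha256()
    full = len(token_ids) - len(token_ids) % block_size
    for start in range(0, full, block_size):
        block = token_ids[start:start + block_size]
        running.update((",".join(map(str, block)) + ";").encode())
        keys.append(running.copy().hexdigest())
    return keys


@dataclass
class ToyKVStore:
    """In-process stand-in for a globally shared KV block store (hypothetical)."""
    blocks: dict[str, str] = field(default_factory=dict)

    def cached_prefix_blocks(self, keys: list[str]) -> int:
        """Count the leading blocks that are already stored."""
        hits = 0
        for key in keys:
            if key not in self.blocks:
                break
            hits += 1
        return hits


def prefill(store: ToyKVStore, instance: str, token_ids: list[int]) -> tuple[int, int]:
    """Reuse cached leading blocks, 'compute' the rest, and publish them.
    Returns (reused_blocks, computed_blocks)."""
    keys = block_keys(token_ids)
    reused = store.cached_prefix_blocks(keys)
    for i, key in enumerate(keys[reused:], start=reused):
        # Placeholder for the real KV tensors produced by this instance.
        store.blocks.setdefault(key, f"kv-block-{i}-from-{instance}")
    return reused, len(keys) - reused
```

With one `ToyKVStore` shared by two engines, `prefill(store, "engine-A", list(range(8)) + [100, 101, 102, 103])` computes all three blocks and returns `(0, 3)`; a follow-up `prefill(store, "engine-B", list(range(8)) + [200, 201, 202, 203])` reuses the two shared-prefix blocks and returns `(2, 1)`. Hashing the whole prefix, rather than each block in isolation, is what makes a key hit safe to reuse.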
Wide Industry Adoption
Mooncake originated from a research collaboration between Moonshot AI and Tsinghua University. It was born from the need to solve the “memory wall” in serving massive-scale models like Kimi. Since open-sourcing, it has evolved into a thriving community-driven project.
Mooncake’s architecture has been battle-tested in some of the world’s most demanding production environments. Its ability to decouple compute from memory has led to wide adoption across leading organizations, including Moonshot AI (Kimi), Alibaba Cloud, Ant Group, JD.com, Tencent, Meituan, Approaching.AI and LightSeek Foundation.
These organizations utilize Mooncake to maximize GPU utilization and ensure smooth serving for millions of concurrent users.
In Action: A Joint Solution
To demonstrate the full potential of this architecture, we present a joint solution that combines Mooncake with the ecosystem’s leading inference engines and orchestration tools.
In this architecture, we will use RoleBasedGroup (RBG, https://github.com/sgl-project/rbg) to orchestrate the entire topology, defining the relationships and startup order of the cluster. It deploys Shepherd Model Gateway (SMG, https://github.com/lightseekorg/smg) as the critical routing layer, which intelligently directs incoming requests to the appropriate workers based on cache locality and system load. The heavy lifting is then performed by SGLang (https://github.com/sgl-project/sglang) or vLLM (https://github.com/vllm-project/vllm) instances serving as compute workers, while Mooncake functions as the high-speed data plane: its Transfer Engine pushes prefilled KV cache via RDMA/NVLink, and its Store persists that cache for global reuse by decoding nodes.
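The routing decision described above can be sketched as a simple policy: send a request to the worker whose recent prompts share the longest prefix with it (the worker most likely to hold reusable KV cache), unless that worker is carrying too much more load than the least-busy one. This is a hypothetical illustration, not SMG's algorithm: the `Worker`/`route` names, the in-flight counter used as the load signal, the string-prefix match, and the `max_imbalance` threshold are all assumptions; a production router would typically track prefixes in a more efficient structure such as a radix tree.

```python
from __future__ import annotations

from dataclasses import dataclass, field


def common_prefix_len(a: str, b: str) -> int:
    """Length of the shared leading substring of a and b."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


@dataclass
class Worker:
    name: str
    load: int = 0  # in-flight requests (hypothetical load signal)
    history: list[str] = field(default_factory=list)  # prompts routed here (unbounded in this toy)

    def prefix_match(self, prompt: str) -> int:
        return max((common_prefix_len(prompt, seen) for seen in self.history), default=0)


def route(workers: list[Worker], prompt: str, max_imbalance: int = 4) -> Worker:
    """Prefer cache locality, but fall back to the least-loaded worker when the
    cache-preferred worker is more than `max_imbalance` requests busier."""
    least_loaded = min(workers, key=lambda w: w.load)
    best_cache = max(workers, key=lambda w: (w.prefix_match(prompt), -w.load))
    if best_cache.load - least_loaded.load <= max_imbalance:
        chosen = best_cache
    else:
        chosen = least_loaded
    chosen.load += 1
    chosen.history.append(prompt)
    return chosen
```

With two idle workers, two prompts that share a system prompt land on the same worker; once that worker's load exceeds the other's by more than `max_imbalance`, new requests spill over to the idle one. The threshold is the knob that trades cache hits against load balance.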
1. Deployment with SGLang + Mooncake + SMG
Below is the RBG configuration that deploys a complete SGLang architecture. It enables both Prefill-Decode Disaggregation and Global KVCache Reuse: the Prefill instances use Mooncake TE to transfer KVCache to the Decode instances, while Mooncake Store lets KVCache be reused across different requests within the Prefill instances (more details in KEP-74 Mooncake Integration and pd-disaggregated-with-mooncake.yaml).
```yaml
# Joint Solution: RBG + SMG + SGLang + Mooncake (Production Ready)
apiVersion: workloads.x-k8s.io/v1alpha1
kind: RoleBasedGroup
metadata:
  name: sglang-mooncake-smg-v2
spec:
  roles:
    # 1. Mooncake Master: Centralized Metadata Server for TE and Store
    - name: mooncake-master
      replicas: 1
      template:
        spec:
          containers:
            - name: master
              image: lmsysorg/sglang:latest
              env:
                - name: POD_IP
                  valueFrom:
                    fieldRef:
                      fieldPath: status.podIP
              command: ["mooncake_master"]
              args:
                - --enable_http_metadata_server=true
                - --rpc_address=$(POD_IP)
                - --rpc_port=50051
                - --http_metadata_server_host=$(POD_IP)
                - --http_metadata_server_port=8080
                - --metrics_port=9003
    # 2. Mooncake Store: Distributed KVCache Storage Nodes
    - name: mooncake-store
      replicas: 3
      dependencies: ["mooncake-master"]
      template:
        spec:
          containers:
            - name: store-node
              image: lmsysorg/sglang:latest
              env:
                - name: MOONCAKE_MASTER
                  value: "s-sglang-mooncake-smg-v2-mooncake-master:50051"
                - name: MOONCAKE_TE_META_DATA_SERVER
                  value: "http://s-sglang-mooncake-smg-v2-mooncake-master:8080/metadata"
                - name: MOONCAKE_GLOBAL_SEGMENT_SIZE
                  value: "45gb"
                - name: MOONCAKE_PROTOCOL
                  value: "rdma"  # Use RDMA for zero-copy KVCache transfer
              command: ["python3", "-m", "mooncake.mooncake_store_service"]
              resources:
                limits:
                  memory: "50Gi"
                  rdma/hca: 1  # Required for high-speed TE transfer
                requests:
                  memory: "50Gi"
                  rdma/hca: 1
    # 3. Prefill Worker (SGLang): High-throughput Prefill with Mooncake Push
    - name: prefill-worker
      replicas: 1
      dependencies: ["mooncake-master", "mooncake-store"]
      template:
        spec:
          containers:
            - name: sglang-prefill
              image: lmsysorg/sglang:latest
              env:
                - name: MOONCAKE_MASTER
                  value: "s-sglang-mooncake-smg-v2-mooncake-master:50051"
                - name: MOONCAKE_TE_META_DATA_SERVER
                  value: "http://s-sglang-mooncake-smg-v2-mooncake-master:8080/metadata"
                - name: MOONCAKE_PROTOCOL
                  value: "rdma"
              # Each flag and its value are separate list items: `command` is passed
              # as argv without a shell, so "--tp 4" in one item would not parse.
              command:
                - python3
                - -m
                - sglang.launch_server
                - --model-path
                - /models/Qwen3
                - --tp
                - "4"
                - --disaggregation-mode
                - prefill
                - --disaggregation-transfer-backend
                - mooncake  # Activates Mooncake TE for KVCache Push
                - --enable-hierarchical-cache  # Enables KVCache offloading
                - --hicache-storage-backend
                - mooncake  # Uses Mooncake as the L2/L3 cache backend
              resources:
                limits:
                  nvidia.com/gpu: "4"
                  rdma/hca: 1
    # 4. Decode Worker (SGLang): Low-latency Generation with Mooncake Pull
    - name: decode-worker
      replicas: 2
      dependencies: ["mooncake-master", "prefill-worker"]
      template:
        spec:
          containers:
            - name: sglang-decode
              image: lmsysorg/sglang:latest
              command:
                - python3
                - -m
                - sglang.launch_server
                - --model-path
                - /models/Qwen3
                - --tp
                - "4"
                - --disaggregation-mode
                - decode  # Pulls shared KVCache from Mooncake Store
              resources:
                limits:
                  nvidia.com/gpu: "4"
                  rdma/hca: 1
    # 5. Shepherd Model Gateway (SMG): Intelligent PD-Disaggregation Router
    - name: smg-router
      replicas: 1
      dependencies: ["prefill-worker", "decode-worker"]
      template:
        spec:
          containers:
            - name: router
              image: lightseekorg/smg:latest
              command:
                - smg
                - --pd-disaggregation
                - --prefill
                - http://s-sglang-mooncake-smg-v2-prefill-worker:8000
                - --decode
                - http://s-sglang-mooncake-smg-v2-decode-worker:8000
                - --host
                - 0.0.0.0
                - --port
                - "8000"
```
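RBG uses the `dependencies` field of each role to decide startup order. As a quick, offline sanity check of a manifest like the SGLang one above, the dependency graph can be topologically sorted with Python's standard-library `graphlib`, which raises `graphlib.CycleError` on circular dependencies. This is our own illustration of the ordering constraint, not how RBG computes it, and the role map is copied by hand from the manifest rather than parsed from YAML.

```python
from __future__ import annotations

from graphlib import TopologicalSorter

# role -> roles it depends on, copied from the `dependencies` fields of the
# SGLang RoleBasedGroup manifest (mooncake-master has none).
ROLE_DEPENDENCIES: dict[str, list[str]] = {
    "mooncake-master": [],
    "mooncake-store": ["mooncake-master"],
    "prefill-worker": ["mooncake-master", "mooncake-store"],
    "decode-worker": ["mooncake-master", "prefill-worker"],
    "smg-router": ["prefill-worker", "decode-worker"],
}


def startup_order(deps: dict[str, list[str]]) -> list[str]:
    """Return an order in which every role starts after all of its dependencies.
    Raises graphlib.CycleError if the dependencies are circular."""
    return list(TopologicalSorter(deps).static_order())


if __name__ == "__main__":
    print(" -> ".join(startup_order(ROLE_DEPENDENCIES)))
```

Here the dependencies form a single chain, so the only valid order is mooncake-master, mooncake-store, prefill-worker, decode-worker, smg-router.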
2. Deployment with vLLM + Mooncake
vLLM has also integrated Mooncake support, allowing users to use Mooncake connectors for seamless KV transfer. Below is the equivalent RBG configuration for deploying vLLM in a disaggregated setup with the Mooncake connector.
```yaml
# Joint Solution: RBG + vLLM + Mooncake Connector
apiVersion: workloads.x-k8s.io/v1alpha1
kind: RoleBasedGroup
metadata:
  name: vllm-pd-with-mooncake-demo
spec:
  roles:
    # 1. Gateway: Routing to vLLM instances (SMG or vLLM Proxy)
    - name: proxy
      dependencies: ["prefill", "decode"]
      replicas: 1
      template:
        spec:
          containers:
            - name: proxy
              image: lightseekorg/smg:latest
              command:
                - smg
                - --prefiller-host
                - http://vllm-pd-with-mooncake-demo-prefill-0.s-vllm-pd-with-mooncake-demo-prefill
                - --prefiller-port
                - "8000"
                - --decoder-host
                - http://vllm-pd-with-mooncake-demo-decode-0.s-vllm-pd-with-mooncake-demo-decode
                - --decoder-port
                - "8000"
    # 2. Prefill Worker (vLLM): Producer role
    - name: prefill
      replicas: 1
      template:
        spec:
          volumes:
            - name: model
              persistentVolumeClaim:
                claimName: qwen2.5-7b
            - name: dshm
              emptyDir:
                medium: Memory
                sizeLimit: 30Gi
          containers:
            - name: prefill
              image: vllm/vllm-openai:latest
              command:
                - sh
                - -c
                - |
                  pip install mooncake-transfer-engine && \
                  vllm serve /models/Qwen2.5-7B-Instruct \
                    --port 8000 \
                    --tensor-parallel-size 4 \
                    --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_producer"}'
              ports:
                - containerPort: 8000
                  name: http
              readinessProbe:
                initialDelaySeconds: 30
                periodSeconds: 10
                tcpSocket:
                  port: 8000
              resources:
                limits:
                  nvidia.com/gpu: "4"
                  rdma/hca: 1
                  memory: "100Gi"
                requests:
                  nvidia.com/gpu: "4"
                  rdma/hca: 1
                  memory: "100Gi"
              volumeMounts:
                - mountPath: /models/Qwen2.5-7B-Instruct
                  name: model
                - mountPath: /dev/shm
                  name: dshm
    # 3. Decode Worker (vLLM): Consumer role
    - name: decode
      replicas: 1
      template:
        spec:
          volumes:
            - name: model
              persistentVolumeClaim:
                claimName: qwen2.5-7b
            - name: dshm
              emptyDir:
                medium: Memory
                sizeLimit: 30Gi
          containers:
            - name: decode
              image: vllm/vllm-openai:latest
              command:
                - sh
                - -c
                - |
                  pip install mooncake-transfer-engine && \
                  vllm serve /models/Qwen2.5-7B-Instruct \
                    --port 8000 \
                    --tensor-parallel-size 4 \
                    --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}'
              ports:
                - containerPort: 8000
                  name: http
              readinessProbe:
                initialDelaySeconds: 30
                periodSeconds: 10
                tcpSocket:
                  port: 8000
              resources:
                limits:
                  nvidia.com/gpu: "4"
                  rdma/hca: 1
                  memory: "100Gi"
                requests:
                  nvidia.com/gpu: "4"
                  rdma/hca: 1
                  memory: "100Gi"
              volumeMounts:
                - mountPath: /models/Qwen2.5-7B-Instruct
                  name: model
                - mountPath: /dev/shm
                  name: dshm
---
apiVersion: v1
kind: Service
metadata:
  labels:
    app: vllm-pd-with-mooncake-demo
  name: vllm-pd-with-mooncake-demo
  namespace: default
spec:
  ports:
    - name: http
      port: 8000
      protocol: TCP
      targetPort: 8000
  selector:
    rolebasedgroup.workloads.x-k8s.io/name: vllm-pd-with-mooncake-demo
    rolebasedgroup.workloads.x-k8s.io/role: proxy
  type: ClusterIP
```
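The prefill and decode containers differ only in `kv_role`, and the JSON passed to `--kv-transfer-config` is easy to mis-quote inside a shell string. A small helper like the one below can generate both command lines from one place; it is a hypothetical sketch whose function names are ours and not part of vLLM. It emits only the two keys that appear in the manifest and deliberately accepts only the two roles used here.

```python
import json
import shlex


def kv_transfer_config(role: str, connector: str = "MooncakeConnector") -> str:
    """Serialize the --kv-transfer-config JSON in the compact form used above."""
    if role not in ("kv_producer", "kv_consumer"):
        raise ValueError(f"unexpected kv_role for this setup: {role!r}")
    return json.dumps({"kv_connector": connector, "kv_role": role}, separators=(",", ":"))


def vllm_serve_cmd(model_path: str, role: str, port: int = 8000, tp: int = 4) -> str:
    """Build one worker's `vllm serve` line; shlex.join quotes the JSON safely."""
    args = [
        "vllm", "serve", model_path,
        "--port", str(port),
        "--tensor-parallel-size", str(tp),
        "--kv-transfer-config", kv_transfer_config(role),
    ]
    return shlex.join(args)
```

For example, `vllm_serve_cmd("/models/Qwen2.5-7B-Instruct", "kv_producer")` reproduces the prefill container's `vllm serve` line, including the single-quoted JSON.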
Conclusion
Mooncake adds a vital layer of memory virtualization to the open-source AI stack. By enabling PyTorch engines, whether SGLang, vLLM, or TensorRT-LLM, to adopt KVCache-centric architectures, we are paving the way for more efficient, scalable, and lower-latency LLM services.
We invite you to explore the project and start building:
- Mooncake GitHub: https://github.com/kvcache-ai/Mooncake
- Mooncake Project Doc: https://kvcache-ai.github.io/Mooncake/
Pyrefly Now Type Checks PyTorch
12 Feb 2026, 8:30 pm
We’re excited to share that PyTorch now leverages Pyrefly to power type checking across our core repository, along with a number of projects in the PyTorch ecosystem: Helion, TorchTitan and Ignite. For a project the size of PyTorch, leveraging typing and type checking has long been essential for ensuring consistency and preventing common bugs that often go unnoticed in dynamic code. Migrating to Pyrefly brings a much-needed upgrade to these development workflows, with lightning-fast, standards-compliant type checking and a modern IDE experience. With Pyrefly, our maintainers and contributors can catch bugs earlier, benefit from consistent results between local and CI runs, and take advantage of advanced typing features. In this blog post, we’ll share why we made this transition and highlight the improvements PyTorch has already experienced since adopting Pyrefly.
Why Switch to Pyrefly?
To support the future development of PyTorch, we wanted a type checker that is fast, easy to use, consistent across developer environments, and actively maintained. These factors ultimately influenced the decision to move forward with Pyrefly.
Balancing Speed with Accuracy
In a recent round of benchmarking, type checking PyTorch took 50.6 seconds with MyPy, whereas Pyrefly (v44.1) took only 5.5 seconds, roughly 9x faster. This is a significant speed improvement over PyTorch’s existing tooling while still maintaining robust type safety. We wanted an alternative that not only delivered fast results but would also help our contributors catch bugs early and identify gaps in our type coverage. Pyrefly appears to strike the right balance for us, being fast enough to keep up with our development speed without compromising on the quality of type safety.
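For readers who want to run a similar comparison on their own code, a minimal wall-clock harness is sketched below. It is not the setup behind the numbers above, which the post does not describe: the median-of-N timing, the helper names, and the example invocations `mypy .` and `pyrefly check` are our assumptions, and results depend heavily on hardware and configuration.

```python
from __future__ import annotations

import statistics
import subprocess
import time


def time_command(cmd: list[str], runs: int = 3) -> float:
    """Median wall-clock seconds for `cmd` over `runs` runs.

    The exit code is ignored on purpose: type checkers exit non-zero when they
    report errors, and such a run still counts as a timing sample.
    """
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False)
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def speedup(baseline_seconds: float, candidate_seconds: float) -> float:
    """How many times faster the candidate is than the baseline."""
    return baseline_seconds / candidate_seconds
```

Run from a project root, `speedup(time_command(["mypy", "."]), time_command(["pyrefly", "check"]))` gives the ratio. MyPy keeps an incremental cache in `.mypy_cache`, so repeated runs can be much faster than a cold run; clear it first if you want cold-start numbers.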
That said, we see this as just the beginning; there is still room for Pyrefly to become even faster, and we expect to benefit from even greater speed gains as the tool continues to evolve. We’ll be closely following Pyrefly’s ongoing development and look forward to integrating future performance enhancements as they become available.
Simplified Configuration
Previously, our reliance on MyPy required contributors to juggle multiple configuration files to manage coverage and strictness levels across the codebase. This made it difficult to determine exactly which files were being checked and under which rules. Transitioning to Pyrefly has helped address these challenges. With direct support from the Pyrefly team, PyTorch now uses a single, unified Pyrefly configuration, with the required suppressions in place, making it much easier for our maintainers to understand which files are being type checked and how.
Consistency across Development Environments
Previously, developers often encountered discrepancies between their IDE, local CLI, and the CI environment because a different type-checking engine was used at each stage. MyPy was used in PyTorch CI jobs, but in their IDEs developers often preferred other type checkers that behaved slightly differently. In other cases, developers ran MyPy locally with a strictness mode that differed from the one used in CI. These inconsistencies led to unpredictable feedback loops and a frustrating experience in which code that passed a local type check would fail in CI. By adopting Pyrefly, which provides a high-quality IDE experience alongside robust CLI and CI functionality, PyTorch developers now get consistent results across all their development environments.
| Environment | Before | After |
|---|---|---|
| CI | MyPy (full project run) | Pyrefly |
| CLI | MyPy (only on select files) | Pyrefly |
| IDE | Pyright or other | Pyrefly |
Active Maintenance and Rapid Development
Another major reason for migrating is that Pyrefly is actively maintained and evolving quickly, with significant room for continued performance improvements. We’ve appreciated the responsiveness to user feedback and the rapid development cycle, which includes a new minor release every Monday. It’s not uncommon for a bug to be reported and resolved in time for the very next release, ensuring that issues are addressed and new features are delivered promptly. One example is described in a recent Pyrefly blog post, where a performance bottleneck was identified and promptly resolved, resulting in an 18x speedup in IDE responsiveness across the PyTorch codebase.
Throughout this migration, and as we continue using Pyrefly, our priority is to avoid regressions in type safety or developer experience. Maintaining a regular line of communication with the Pyrefly team has been essential for quickly addressing edge cases and enabling a smooth transition for our contributors.
Additional Benefits for PyTorch Contributors
PyTorch contributors and maintainers have already experienced meaningful improvements since moving to Pyrefly. Beyond the initial motivations for the transition, other benefits include the following:
Improved code quality
The rollout of Pyrefly has already led to the discovery and resolution of numerous bugs in the PyTorch codebase. One reason is that Pyrefly runs in a single, consistent mode across PyTorch. Take the code example below: unless MyPy runs in strict mode (or with --check-untyped-defs), it doesn’t type check the bodies of untyped functions, so errors like this can go unnoticed. Pyrefly, on the other hand, checks these function bodies everywhere in the codebase and catches such errors.
```python
def foo():
    return 1 + ""  # pyrefly error
```
Seamless IDE Experience
Pyrefly integrates natively with many major IDEs, bringing real-time type feedback, hover documentation, and instant diagnostics into the editor, with results that match your local and CI runs. Now PyTorch contributors using a diverse range of IDEs can spot type errors as they code and be confident their results are consistent, reducing context-switching and making it easier to maintain high code quality. VSCode users can download our IDE extension here. Once enabled, it will automatically find the configuration file in the PyTorch project.
Advanced Typing Capabilities
Pyrefly brings advanced typing features to PyTorch, including robust support for complex typing patterns and strict adherence to Python typing specifications. This empowers contributors to write safer and more expressive code, while maintaining performance and a smooth developer experience.
Pyrefly’s inference capabilities can also enable developers to detect type errors even in code that lacks explicit type annotations. This means that legacy code, experimental modules, and fast-moving prototypes can benefit from increased type safety, without requiring a massive upfront investment in annotation. It can also help identify areas of code that could benefit from more explicit type annotations, helping us move forward with our goals of increasing type coverage in the codebase. Currently, return type inference is not enabled by default in PyTorch, but we are actively working to add annotations and fix type issues in order to un-gate this feature in the near future.
```python
def foo():
    return 1


foo() + "hello"  # mypy: no error, # pyrefly: error [unsupported-operation]
```
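The flip side is that an explicit annotation makes the error visible without any inference. In the sketch below, annotating `foo` as returning `int` means MyPy's default mode also checks the addition and reports it; the `greet` wrapper is hypothetical, added only so that importing the snippet does not execute the failing expression. At runtime the call still raises `TypeError`, which is exactly what the type checker is warning about.

```python
def foo() -> int:
    # Explicit return annotation: every checker now knows foo() is an int,
    # so no return-type inference is needed to catch the misuse below.
    return 1


def greet() -> str:
    # Wrapped in a function so importing this snippet doesn't raise.
    # With the annotation above, MyPy in its default mode reports an
    # unsupported-operand error for int + str here, as does Pyrefly.
    return foo() + "hello"
```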
Get Started with Pyrefly
Contributors to PyTorch can get started using Pyrefly by installing the extension in their editors, and can start using it for local type checking quickly and easily using lintrunner:
lintrunner init
lintrunner
Contributors to Helion can also get started by installing the IDE extension, and can run a local type check with the repository’s lint.sh script:
./lint.sh install && ./lint.sh
Pyrefly is also integrated into our CI suite under the lint job to ensure consistency across the codebase. This ensures that the same rules applied during local development are enforced on every PR. When you open a pull request, you can find the Pyrefly results by navigating to the “Checks” tab and selecting the lint job.
If you’re not a PyTorch contributor but still want to check out Pyrefly on your own project, you can get the VSCode extension here or check out the Pyrefly documentation.
Future Work
Switching to Pyrefly marks a practical and meaningful advancement for the PyTorch project. Developers are already seeing the benefits of faster and more consistent type checking, and the initial rollout has helped uncover and resolve a substantial number of bugs. This transition has streamlined workflows and laid the foundation for ongoing improvements in both code quality and developer experience.
Looking ahead, we hope to continue seeing performance improvements from Pyrefly as the tool matures. We’re also excited to partner with the Pyrefly team to further improve typing across the codebase. Strengthening type annotations in one of the most widely used AI/ML libraries will enable maintainers and the broader community to more confidently leverage PyTorch in production environments. Deploying a newer, faster type checker with Pyrefly is only the first step of that journey.
As always, community feedback is invaluable. We encourage PyTorch contributors and users to share their experiences, report issues, and suggest improvements as we continue refining the type checking workflow. If you have questions or wish to provide feedback to the Pyrefly team, you can do so in Discord, or submit bug reports by opening a GitHub issue in the Pyrefly repository.
Finally, we want to extend our sincere thanks to both the PyTorch and Pyrefly teams, as well as the community, for their feedback and testing throughout this transition.
Why I’m Joining the PyTorch Foundation
11 Feb 2026, 4:55 pm
I want to start by thanking Matt White for everything he has built over the past two years. The growth of the PyTorch Foundation speaks for itself. What began as a single-project foundation is now a multi-project home for some of the most critical infrastructure in AI. That did not happen by accident. It is the result of real technical leadership, genuine community investment, and a clear belief in open collaboration. Matt is now stepping into the role of Global CTO of AI at the Linux Foundation and will transition to the role of CTO at the PyTorch Foundation, where he will focus on the technical strategy and direction that will define what’s possible next.
I’m thrilled to be joining the PyTorch Foundation as its new Executive Director. Here’s why.
The Most Important Open Source Projects in the World
There is not a more important open source project in the world right now than PyTorch. The daily onslaught of new state-of-the-art models proves it. When you hear about models writing compilers from scratch capable of compiling the Linux kernel, you’re getting a glimpse of the future that PyTorch makes possible.
But here’s what I think people outside of our community are only beginning to understand: the PyTorch Foundation is no longer just about PyTorch.
vLLM has become the inference engine of choice for the industry. When a new model drops, it runs on vLLM on day one, which tells us where the center of gravity lives. Inference is the largest workload in human history, and it runs on a PyTorch Foundation project.
DeepSpeed is pushing the boundaries of training efficiency at a scale that was unthinkable a few years ago. Ray is powering the orchestration and scaling layer that lets AI workloads run across the industry. These are foundational technologies with massive communities of their own, and they chose to make their home here.
Training. Inference. Orchestration. The critical layers of the AI stack live under one roof.
Every Innovation Story Is an Infrastructure Story
I’ve spent my career finding the infrastructure layer of emerging technology waves and building open source ecosystems around them. I co-founded OpenStack in 2010 and built the OpenStack Foundation (now OpenInfra Foundation), spending over a decade helping create the open source cloud. Last year we merged the OpenInfra Foundation with the Linux Foundation, and I became General Manager of AI and Infrastructure and Executive Director of the LF AI and Data Foundation. Now I get to put that experience into action with the PyTorch Foundation.
If there’s one thing I’ve learned across all of that, it’s that every innovation story is an infrastructure story if you know where to look. AI is going to reshape every aspect of the lives of every human being on earth, and it is going to do so at a speed that makes previous technological transitions look slow. The industrial revolution played out over generations. The internet transformed society over decades. AI is compressing that arc into years. The infrastructure that makes all of this possible is being built right now, in the open, by the communities in this foundation.
We don’t want any one company or country to dominate such critical technologies. They have to be built together by communities that trust each other enough to do the hard work side by side. The best open source foundations foster the conditions that let communities lead. They keep the path open for the widest possible participation and the largest possible impact. That’s what we need to do again, and I’m here to do that work with all of you.
The Energy Is Real
I had the opportunity to attend PyTorchCon in San Francisco last October, and I was in awe of the community energy in that place. That’s not easy to pull off in Moscone, and it’s not something you’ll find at just any open source conference. I’ve been to many of them. It reminded me deeply of the early OpenStack days when our summits were doubling every year, and people were genuinely having fun while changing the world.
If you’re part of this community, whether you contribute to PyTorch, vLLM, DeepSpeed, Ray, or the ecosystem around them, you may not fully realize it yet, but that’s exactly what you’re doing. Enjoy the ride.
What Comes Next
My prime directive is clear. Serve the communities that make this foundation what it is. Advocate for the open path that leads to the most innovation, the widest impact, and the largest number of people served by this technology. And make sure that every community that calls this foundation home knows that it belongs here and that its work matters.
If you’re headed to a PyTorch Conference, a PyTorch Day, or anywhere else this community gathers, come find me. I want to meet the people doing this amazing work. The best part of open source has always been the people, and I can’t wait to get to know more of you.
Let’s go build the future.
Mark Collier is Executive Director of the PyTorch Foundation, General Manager of AI and Infrastructure at the Linux Foundation, and Executive Director of the LF AI and Data Foundation. He co-founded the OpenStack project in 2010 and spent 13 years building the OpenStack Foundation and open source cloud community.
PyTorch Foundation: The Next Chapter, Together
11 Feb 2026, 4:55 pm
For nearly two years, I’ve had the privilege of serving as Executive Director of the PyTorch Foundation. As I look back on what we have accomplished together, one thing stands out clearly: our momentum is not accidental. It is the result of a global community of maintainers, contributors, researchers, practitioners, member organizations, and volunteers who have chosen collaboration, openness, and technical rigor as the path to progress.
This post is both a thank you and a transition update, shared first and foremost with the PyTorch community.
What we built in a short time
In a relatively short period, the PyTorch Foundation has evolved from a single-project foundation centered on PyTorch into a multi-project home for critical components across the AI development lifecycle. Today, the Foundation proudly hosts four major projects: PyTorch, vLLM, DeepSpeed, and most recently Ray. Alongside these hosted projects, the broader PyTorch ecosystem has expanded to more than 100 projects, including Unsloth, verl, SGLang, FEAST, and many other high-quality open source efforts that are pushing the state of the art forward.
At the same time, our membership has grown to 33 organizations, nearly doubling, and we updated our membership tiers to better reflect the scale and maturity of our ecosystem. Those member commitments matter, because they translate into real investment in the shared infrastructure and community programs that enable open source AI to thrive.
Stronger governance and deeper technical collaboration
As our technical scope expanded, so did our governance. We launched the initial Technical Advisory Council and supported its growth into a more active forum for cross-project alignment. We also established five core working groups: CI Infrastructure, Multi-Cloud, Ecosystem, Accelerators, and Security.
These groups are where hard, practical problems get solved: keeping CI reliable and scalable, improving portability and cost efficiency, coordinating cross-project priorities, strengthening security posture, and making it easier for developers and organizations to adopt and deploy PyTorch and related projects. The result has been measurably increased technical engagement, clearer project roadmaps, and more consistent collaboration patterns across the Foundation’s hosted projects and the broader ecosystem.
A bigger global footprint, powered by the community
The growth of PyTorch is global, and our community programs have expanded accordingly.
We grew from a conference of roughly 300 attendees to a flagship PyTorch Conference in San Francisco that welcomed more than 3,000 participants. We successfully launched PyTorch Days with events in Paris and Beijing, and we are continuing to expand our global presence. In 2026, we will hold three PyTorch Conferences: Europe in Paris (April), China in Shanghai (September), and our flagship event, North America in San Jose (October). These will be complemented by additional PyTorch Days, starting in Bengaluru this past weekend, with more events in development, including Beijing, Seoul, and others.
We also launched the PyTorch Ambassadors program, now approaching 50 ambassadors, with another cohort planned. This is one of the most important community programs we run, because it scales something no single team can manufacture: local leadership. Ambassadors host meetups, welcome new contributors, and help PyTorch show up meaningfully in regions and communities around the world. In parallel, we’ve been building a speaker bureau to connect domain experts from the community with events seeking credible technical speakers.
Academic outreach, research engagement, and education
Another area of focus has been strengthening ties between research, education, and open source practice.
We kicked off an Academic and OSPO outreach program to engage academic labs and university Open Source Program Offices, with early work involving UC Berkeley, UC Santa Cruz, Stanford, the University of Vermont, and Caltech. The goal is to help students build practical open source skills, create clearer pathways from research to production, and identify emerging open source AI projects that could benefit from Foundation support.
We also increased the Foundation’s participation in major research and practitioner venues, supporting workshops, posters, and talks at MLSys, ICML, NeurIPS, and UC Berkeley’s AgentX program. Across the year, I joined many leaders from the PyTorch community in speaking at more than 100 events worldwide to advocate for PyTorch, the Foundation, and open source AI as a durable strategy for innovation.
Finally, the educational output from the community has been exceptional. In 2025, we published more than 130 pieces of educational content, including tutorials, webinars, and blogs, averaging nearly one substantive item every three days. That pace reflects both the depth of expertise across the community and the rate at which the ecosystem continues to evolve.
We also made meaningful progress toward scalable professional development. At the last PyTorch Conference, we kicked off onsite training for the PyTorch Certified Associate program with strong participation. In the coming months, we expect to publish the corresponding exam and online course, and then begin building the content pathway toward a PyTorch Certified Professional designation. The intent is to support developers who want to demonstrate practical PyTorch fluency, while giving employers a clearer signal for hiring and workforce development.
Infrastructure that scales with the ecosystem
Behind every reliable open source ecosystem is infrastructure that works. Over the past two years, we continued strengthening CI reliability and observability, expanded monitoring and logging, and progressed the migration of our download site to the Cloudflare CDN.
Just as importantly, the Foundation’s CI would not be sustainable without the support of member organizations and partners who contribute engineering effort, hardware, and operational expertise. Contributions, current and in progress, from Meta, AWS, AMD, Intel, Microsoft, and NVIDIA have been critical. We have also advanced a multi-cloud strategy so we can diversify our footprint across hyperscalers and neo-clouds, manage cost, and maintain the performance and scale that developers and production users depend on.
What comes next
Even with this progress, the next phase demands more. Key priorities ahead include:
- Expanding the hosted project portfolio, including adjacent domains such as agentic AI, environments, and reinforcement learning
- Further diversifying and optimizing CI architecture and costs
- Onboarding additional project CI workloads where shared accelerator access unlocks faster iteration
- Expanding training and certification into a durable revenue stream that strengthens Foundation sustainability
- Deepening community programs, including initiatives such as mentorship and stronger global enablement
As the scope grows, there is a straightforward operational reality: leadership capacity must scale so that organizational throughput, not leadership bandwidth, sets our pace.
A leadership transition to support the next stage
To support this next stage, I’m sharing a leadership transition that takes effect immediately.
I will be stepping into the role of Chief Technology Officer for the PyTorch Foundation, alongside my new role as Global CTO of AI at the Linux Foundation. At the same time, Mark Collier will join the PyTorch Foundation as our new Executive Director.
Mark brings deep experience building and scaling open infrastructure ecosystems, including founding OpenStack and the OpenInfra Foundation. As Executive Director, he will lead the operational and business execution of the Foundation, working closely with the Governing Board. His responsibilities include oversight of Foundation committees (including Finance and Marketing), community programs such as Ambassadors, Foundation-led events, staff management, finances, and membership development. Ultimately, he will be accountable for the overall direction and operations of the Foundation in partnership with the Governing Board.
As CTO, I will focus on technical strategy and execution across the Foundation: supporting the TAC and working groups; advancing our hosted projects and ecosystem alignment; strengthening CI and multi-cloud infrastructure; and driving technical programs, including Academic and OSPO outreach and PyTorch Certified. This structure is intended to increase clarity, accountability, and speed, while preserving community-led technical governance.
Quotes
“It’s great to see the PyTorch Foundation enter a new phase, just months after it evolved into an umbrella foundation. With Mark as the Executive Director and Matt as the CTO, the foundation acquires the level of maturity required by its ambitions. I can’t wait to help build the future of PyTorch with the new leadership and the rest of the TAC.”
– Luca Antiga, CTO, Lightning AI and Chair, PyTorch Foundation Technical Advisory Council (TAC)
“Watching the PyTorch Foundation grow into an umbrella ecosystem has been inspiring—it’s set PyTorch up not only for the short term, but for a long arc of impact foundational to AI. Congrats to Matt on an incredible chapter, and a warm welcome to Mark. I’m excited for where we take PyTorch next!”
– Joe Spisak, Product Director, Meta Superintelligence Labs & PyTorch Core Maintainer
“The growth of the PyTorch Foundation speaks for itself. Thanks to Matt White for everything he has built. What began as a single-project foundation is now a multi-project home for some of the most critical infrastructure in AI. That did not happen by accident. It is the result of real technical leadership, genuine community investment, and a clear belief in open collaboration. I’m excited to keep that momentum going that will define what’s possible next.”
– Mark Collier, Executive Director, PyTorch Foundation
Thank you
I want to close with an explicit note of appreciation. The PyTorch Foundation’s progress is not the product of any single organization or individual. It is the result of thousands of community members: maintainers, contributors, reviewers, working group participants, event organizers, speakers, educators, and member company teams who consistently choose collaboration over fragmentation and long-term stewardship over short-term advantage.
Thank you for the trust, the effort, and the standards you bring to this community.
I’m excited for what comes next, and I’m particularly looking forward to working with Mark as he steps into the Executive Director role. Please join me in welcoming him and supporting him as he begins this next chapter with us.
We have built something strong. Now we scale it.
Matt White
CTO, PyTorch Foundation
Global CTO of AI, Linux Foundation
PyTorch Day India 2026: A builder-focused milestone for open source AI in Bengaluru
10 Feb 2026, 10:35 pm
On February 7, 2026, the inaugural PyTorch Day India brought the open source AI community to Bengaluru for a full day of technical talks, discussions, and community connection. Co-organized by IBM, NVIDIA, and Red Hat, the event reinforced a clear theme: India is not only adopting AI at scale, it is helping define how production-grade, open AI systems are built.
The in-person event in Bengaluru drew 460 attendees and placed PyTorch Day India at the center of one of India’s most active engineering ecosystems. Raghu Ganti (IBM) emceed the event, guiding the day’s flow and keeping the program cohesive and engaging.
Keynotes that framed the day: open platforms shaping the future of open source AI
A strong keynote trio anchored the event, reflecting the co-organizers’ complementary strengths across enterprise platforms, infrastructure software, and accelerated computing.
- Steve Watt (Red Hat): “Any Model, Any Accelerator, Any Cloud: How Open Source AI Unlocks the World’s Potential”
- Sriram Raghavan (IBM): “The Ubiquitous AI Platform: Lessons from Linux, Vision for PyTorch”
- Niket Agarwal (NVIDIA): “Full Stack AI Innovation: PyTorch + NVIDIA From Edge to Data Center”
Taken together, these talks pointed to the operational reality of modern AI and to how platforms like PyTorch and vLLM are laying the foundation for fast innovation from research to production.
First, heterogeneous compute has become the default. Teams increasingly mix CPUs, GPUs, and specialized accelerators across cloud and on-premises environments, and they need frameworks and tooling that work consistently across those targets.
Second, AI platforms are increasingly treated as foundational infrastructure. The Linux comparison is instructive because long-lived platforms succeed when they provide stable interfaces, clear governance, and predictable behavior. That stability enables fast iteration above the platform and efficient optimization below it.
Third, end-to-end performance is now a primary product requirement, not an optional enhancement. “Edge to data center” captures the range of deployment patterns that organizations must support, from constrained inference at the edge to large-scale training, fine-tuning, and high-throughput serving in the data center.
What builders came for: kernels, compilers, inference, and real systems work
PyTorch Day India was deliberately technical and builder-oriented. The program emphasized performance work from low-level kernels up to full systems, along with optimization, training efficiency, inference, and deployment. This reflects where the field is heading: the hardest problems are now about building robust, dependable AI systems under real constraints like latency, cost, reliability, security, and governance.
That builder emphasis also showed up in the ecosystem representation and talk topics. For example, Aritra Roy Gosthipaty and Sayak Paul (Hugging Face) highlighted kernel-level work in the Transformers ecosystem, signaling that practical performance engineering is now a first-class conversation for mainstream ML teams.
This focus matches how organizations deploy AI today. Most production AI is not a single model running in isolation. It is a workflow that connects data pipelines, distributed execution, training and evaluation, inference and serving, monitoring, and governance. As these components become more interdependent, open and composable building blocks become essential.
A keynote message worth carrying forward: open source is how AI becomes dependable
In his address, Matt White, PyTorch Foundation CTO and former Executive Director, emphasized a shift that is now common across enterprises. AI is moving from prototype to operational capability. That transition forces teams to prioritize engineering fundamentals, including reproducible training and evaluation, scalable inference, distributed compute and data pipelines, and security and supply-chain hygiene.
He also underscored a broader architectural trend: AI systems are becoming “systems of systems,” where models connect to retrieval, tooling, deployment, monitoring, and governance. In that environment, open source becomes a practical necessity because production adoption benefits from transparency, inspectability, and integration flexibility across complex infrastructure.
Why India matters to the PyTorch Foundation and the global ecosystem
India’s importance to the PyTorch ecosystem is structural.
It has developer scale, talent density, and a strong builder culture that translates research into production systems. It also has broad industry diversity, spanning global capability centers, fast-growing startups, academic institutions, and large enterprises serving both local and international markets. That mix accelerates feedback on what matters most in real deployments.
India also has a mature relationship with open source collaboration. That matters because open ecosystems thrive when communities do more than consume software. They improve it, document it, test it, build extensions, and create learning pathways that expand participation. Events like PyTorch Day India strengthen those pathways by turning knowledge-sharing into sustained contribution.
What comes next: build locally, contribute globally
The most practical takeaway from the inaugural PyTorch Day India is that open source AI maturity is being shaped in many places at once, and India is clearly one of those places. Bengaluru was an appropriate setting, with its dense overlap of research, infrastructure engineering, product development, and startup execution.
For attendees, the next step is straightforward and high leverage: turn one idea from the day into an artifact that others can use. That might be a reproducible benchmark, a tutorial, a bug fix, a performance investigation, a documentation improvement, or a small but meaningful contribution to a project you rely on.
PyTorch Day India 2026 kept the focus where it belongs: on builders, on systems, and on the open technologies that make AI usable across industries and across the world. We hope to launch more PyTorch Day events in India and work together to build a strong, talent-rich, diverse, and cohesive open source AI ecosystem with the community in India.
Accelerating Mamba2 with Kernel Fusion
6 Feb 2026, 10:48 pm
Summary
In this post, we discuss how we optimized the Mamba-2 State-Space Dual (SSD) module with a fused Triton kernel that yields speedups of 1.50x-2.51x on NVIDIA A100 and H100 GPUs. To achieve this, we fused all five SSD kernels into a single Triton kernel with careful synchronization. To our knowledge, this is the first end-to-end Triton fusion of all five SSD kernels. Fusion reduces launch overhead and avoids redundant memory operations, making the kernel faster across all input sizes we tested. The rest of this blog covers how we fused the SSD kernels, what bottlenecks remain, benchmark results, and our plans to open-source the kernel so the community can benefit.

Figure 1. Fused SSD Triton Kernel A100 and H100 Speedups
Background
Mamba-2, an optimized successor to the original Mamba model, is a sequence model based on the state-space duality (SSD) framework, which connects structured state-space models (SSMs) with attention-based transformers. One key advantage of Mamba-style models is scalability to long sequences: Mamba’s state-space mechanism scales linearly with context length. In practice, doubling the input sequence length roughly doubles Mamba’s compute and memory needs, whereas it would roughly quadruple self-attention’s compute. This makes Mamba-2 especially attractive for extremely long contexts, such as 128K tokens and beyond.
IBM’s Granite 4.0 model family recently adopted a hybrid architecture that combines Mamba-2 blocks with transformer blocks. In Granite 4.0, nine Mamba-2 layers are used for every attention layer to handle long-range context efficiently. With Mamba-2 becoming integral to such models, optimizing its performance is critical for faster inference. The core of Mamba-2’s computation is the SSD module, which replaces the attention mechanism in each layer. The original Mamba2 SSD implementation is largely bound by memory bandwidth and latency, in part because it writes intermediate results to global memory and reads them back between kernels, which leaves room for improvement. In this blog, we focus on accelerating the SSD prefill operation with an optimized fused kernel.
Mamba2 Operations
The operations that make up a typical Mamba2 block are listed in Table 1. We focused on fusing the five SSD kernels because they behave as one conceptual SSD operation, though further fusion (e.g., convolution and layernorm) may be possible as discussed later.
| Operation | Purpose |
|---|---|
| Layernorm | Helps with numerical stability |
| In Projection | Projects input to SSD channels/size |
| Depthwise Convolution | Mixes the last few tokens |
| SSD Chunk Cumsum | Computes the dt per token and cumulative decay within a chunk |
| SSD Chunk State | Computes the state at the end of this chunk in isolation |
| SSD State Passing | Computes the global states at the end of each chunk |
| SSD BMM | Computes how each chunk of input x affects the corresponding chunk of output y |
| SSD Chunk Scan | Computes each chunk of y from the corresponding chunk of x and previous chunk’s global state |
| Layernorm | Helps with numerical stability |
| Out Projection | Projects output to the model’s hidden dim |
Table 1. Mamba2 operations
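To make the five SSD steps in Table 1 concrete, here is a minimal NumPy reference sketch for a single head with a scalar decay parameter `A`. It assumes no batch dimension, no groups, no `D` skip connection, and no gating, and the function names and shapes are our own; it is not the Triton kernel or the original implementation. It checks the chunked computation against a token-by-token recurrence.

```python
import numpy as np


def ssd_sequential(x, dt, A, B, C):
    """Token-by-token recurrence: h_t = exp(dt_t*A) h_{t-1} + dt_t B_t x_t^T, y_t = C_t^T h_t."""
    T, P = x.shape
    h = np.zeros((B.shape[1], P))
    y = np.zeros((T, P))
    for t in range(T):
        h = np.exp(dt[t] * A) * h + dt[t] * np.outer(B[t], x[t])
        y[t] = C[t] @ h
    return y


def ssd_chunked(x, dt, A, B, C, chunk_size):
    """Same output as ssd_sequential, computed with the five chunked SSD steps.
    Shapes (assumed): x (T, P), dt (T,), B and C (T, N), A a negative scalar."""
    T, P = x.shape
    N = B.shape[1]
    L = chunk_size
    assert T % L == 0, "sketch assumes T is a multiple of chunk_size"
    nc = T // L
    xs, dts = x.reshape(nc, L, P), dt.reshape(nc, L)
    Bs, Cs = B.reshape(nc, L, N), C.reshape(nc, L, N)

    # 1. Chunk Cumsum: cumulative log-decay within each chunk.
    dA_cs = np.cumsum(dts * A, axis=1)                                # (nc, L)
    # 2. BMM: C_i . B_j for every token pair inside a chunk.
    CB = np.einsum("cin,cjn->cij", Cs, Bs)                            # (nc, L, L)
    # 3. Chunk State: state at the end of each chunk, assuming a zero initial state.
    decay_to_end = np.exp(dA_cs[:, -1:] - dA_cs)                      # (nc, L)
    chunk_states = np.einsum("cl,cln,clp->cnp", decay_to_end * dts, Bs, xs)
    # 4. State Passing: serial over chunks; prev_states[c] is the state entering chunk c.
    prev_states = np.zeros((nc, N, P))
    S = np.zeros((N, P))
    for c in range(nc):
        prev_states[c] = S
        S = np.exp(dA_cs[c, -1]) * S + chunk_states[c]
    # 5. Chunk Scan: intra-chunk term plus the decayed contribution of the incoming state.
    causal = np.tril(np.ones((L, L), dtype=bool))
    seg = dA_cs[:, :, None] - dA_cs[:, None, :]                       # (nc, L, L)
    decay = np.where(causal, np.exp(np.where(causal, seg, 0.0)), 0.0)
    y_intra = np.einsum("cij,cj,cjp->cip", CB * decay, dts, xs)
    y_inter = np.exp(dA_cs)[:, :, None] * np.einsum("cin,cnp->cip", Cs, prev_states)
    return (y_intra + y_inter).reshape(T, P)
```

Only step 4 carries a dependency from one chunk to the next: steps 1-3 are independent per chunk, and step 5 needs only the incoming state produced by step 4.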
Why Do We Need Kernel Fusion?
During prefill, which is the forward pass over the prompt or input sequence before token generation, Mamba-2’s SSD module executes as a pipeline of five GPU kernels. In the original implementation, these five kernels run sequentially on the GPU.
However, launching multiple small kernels in sequence incurs significant overhead and prevents the GPU from reusing data between stages efficiently. By applying kernel fusion we can get several key benefits:
- Eliminating Kernel Launch Overheads: One launch instead of five reduces CPU-GPU synchronization and scheduling delays.
- Improving Cache Locality: Data produced in one stage is immediately consumed by the next within the same threadblock, increasing cache hits and reducing global memory traffic.
- Overlapping Computation: Different parts of the fused kernel can execute in parallel (where independent), better utilizing GPU resources.
Our solution fuses all five kernels into a single Triton kernel, so that the entire SSD prefill computation for a layer happens within one GPU launch.
Efficient Kernel Fusion Technique
Unlike a simple matmul + activation fusion, SSD fusion is complex because the computation spans multiple steps with complicated dependencies. The original implementation relied on implicit synchronization across kernels, which disappears when we fuse everything. In this section, we discuss why that matters and our approach to making fusion work in practice.
The five steps of the Mamba2 SSD were originally implemented as five separate kernels: Chunk Cumsum, BMM, Chunk State, State Passing, and Chunk Scan, which operate on fixed-size chunks of tokens. The figure below illustrates the dependencies between these kernels.

Figure 2. Mamba2 SSD Prefill Kernel Graph
The State Passing step has dependencies between chunks, and the original State Passing kernel handled this by looping over chunks within threadblocks and splitting the state’s channels across threadblocks for parallelism. With this State Passing loop and the implicit global synchronization between kernel launches, all dependencies were handled in the original kernels.
The real technical challenge comes when we try to fuse all five kernels into a single launch. Once fused, we lose the implicit global synchronization that the original kernels relied on, so we must explicitly manage both within-chunk and across-chunk dependencies. Most of the dependencies are between different steps but the same chunk, so for the three largest kernels, Chunk State, State Passing, and Chunk Scan, these intra-chunk dependencies could be handled by running all steps of a particular chunk on the same threadblock. This would also give us the ability to keep intermediate data between steps in registers or L1 cache (private to each SM) since the data will be used on the same threadblock.
However, this approach is neither possible nor correct. The original State Passing kernel has the aforementioned loop, which makes its threadblock grid not match the original Chunk State and Chunk Scan kernels. Furthermore, having separate threadblocks for each chunk would remove the natural synchronization and correctness provided by looping over chunks within a single threadblock.
To make fusion possible, we split the iterations of the State Passing loop across chunks into separate threadblocks so the threadblock grids match. We get correctness by ordering these threadblocks with atomics, a form of serialization that looks quite inefficient on the surface but can be mitigated by overlapping with the other two parts.
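The ordering idea can be simulated on the CPU. In this sketch, Python threads stand in for threadblocks, a lock-protected counter stands in for an atomic ticket counter, and `threading.Event` stands in for a flag that a threadblock spins on. Because chunk indices are handed out in order, a worker only ever waits on a chunk that another running worker has already claimed, so the simulation cannot deadlock. The function name and structure are illustrative assumptions, not the kernel's code.

```python
import itertools
import threading

import numpy as np


def fused_state_passing(chunk_states, chunk_decays, num_workers=4):
    """CPU simulation of ordered State Passing inside a fused kernel.

    chunk_states[c]: state at the end of chunk c assuming a zero initial state
    chunk_decays[c]: total decay across chunk c (a scalar)
    Returns global_states[c] = chunk_decays[c] * global_states[c - 1] + chunk_states[c].
    """
    nchunks = len(chunk_states)
    global_states = [None] * nchunks
    ready = [threading.Event() for _ in range(nchunks)]  # one "flag" per chunk
    tickets = itertools.count()
    ticket_lock = threading.Lock()

    def worker():
        while True:
            with ticket_lock:  # stands in for an atomic fetch-and-add
                c = next(tickets)
            if c >= nchunks:
                return
            # Chunk State for chunk c would run here, in parallel with other workers.
            if c == 0:
                prev = np.zeros_like(chunk_states[0])
            else:
                ready[c - 1].wait()  # stands in for spinning on chunk c - 1's flag
                prev = global_states[c - 1]
            global_states[c] = chunk_decays[c] * prev + chunk_states[c]
            ready[c].set()  # publish only after the state is written
            # Chunk Scan for chunk c would run here, using prev.

    threads = [threading.Thread(target=worker) for _ in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return global_states
```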
For example, if we ran 8 chunks in parallel, we would expect a ~8x local slowdown from the State Passing serialization. However, the fused State Passing is a small fraction of the three large steps, especially since it no longer has to read the state from global memory (it’s already in the threadblock from the fused Chunk State).
By Amdahl’s law, we would expect the runtime to change to (State Passing fraction) * 8 + (1 – State Passing fraction) * 1. For example, if the State Passing step was only 1/7th of the combined time excluding synchronization, we would get (1/7) * 8 + (6/7) * 1 = 2, implying a 2x overall slowdown. However, this does not account for overlap. Since the synchronization of State Passing can overlap with the Chunk State and Chunk Scan computation, the slowdown would be roughly:
State Passing compute time + max(other compute time, State Passing synchronization time)
= 1/7 + max(6/7, 1/7 * 7) = 1/7 + 1 = 8/7 ≈ 1.14x
If State Passing were a smaller fraction of the total runtime, or if fewer chunks were processed concurrently, we could theoretically avoid any serialization slowdown in all but the first chunks.
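The two back-of-the-envelope estimates above can be written as small functions; `Fraction` keeps the arithmetic exact. These reproduce the text's simple models only and are not measurements. The function names are our own.

```python
from fractions import Fraction


def serialized_slowdown(sp_fraction, parallel_chunks):
    """No-overlap (Amdahl-style) estimate: State Passing becomes `parallel_chunks`
    times slower while the rest of the work is unchanged."""
    return sp_fraction * parallel_chunks + (1 - sp_fraction)


def overlapped_slowdown(sp_fraction, parallel_chunks):
    """Overlap estimate: waiting on the (parallel_chunks - 1) earlier State Passing
    steps hides behind Chunk State / Chunk Scan work when that work is longer."""
    sync_time = sp_fraction * (parallel_chunks - 1)
    return sp_fraction + max(1 - sp_fraction, sync_time)


f = Fraction(1, 7)
print(serialized_slowdown(f, 8))                # 2
print(overlapped_slowdown(f, 8))                # 8/7
print(overlapped_slowdown(Fraction(1, 14), 8))  # 1 (no slowdown)
```

The last line illustrates the closing remark: with State Passing at 1/14 of the runtime, the wait (7/14) is shorter than the overlapping work (13/14), so the estimated slowdown disappears.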

Figure 3. State Passing Overhead Overlap
Figure 3 shows the theoretical synchronization delays, which are high for the first chunks run in parallel, but settle down to a low overhead in all later chunks. We can see that although chunk 8 depends on chunk 7, it only has to busy-wait 1 unit of time instead of 8 since the chunk 0 Chunk Scan and chunk 8 Chunk State overlap with the State Passing of chunks 1-6. In practice, NVIDIA Nsight Compute benchmarks show that fewer than 3% of warp stalls (idle thread time) are caused by the State Passing synchronization, implying that the serialization latency is hidden.
The BMM and Chunk Cumsum steps are extremely fast compared to the other three. BMM splits work along ngroups instead of nheads, and Chunk Cumsum has its threadblocks handle multiple heads for efficiency. For simplicity, we launch separate threadblocks for these two steps (the first few threadblocks work on them) and have the threadblocks for the other three steps await their BMM and Chunk Cumsum dependencies with atomics.
When a threadblock begins executing the kernel, it is assigned to work on the Chunk Cumsum step unless all Chunk Cumsum work has already been assigned. Similarly, if there is no unassigned Chunk Cumsum work, the threadblock would be assigned to the BMM step if available. After both of these fast steps have been fully assigned to threadblocks, later threadblocks each start processing a chunk in Chunk State, process that same chunk in State Passing, and finally output that chunk after Chunk Scan.
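The assignment order above can be sketched as a mapping from a threadblock's ticket to its work item. We assume the ticket comes from an atomic counter incremented when each threadblock starts, and that the (chunk, head) work is issued chunk-major; both the decomposition and the names are hypothetical, and the real kernel's grid also covers batch and other dimensions.

```python
def assign_work(ticket, n_cumsum_blocks, n_bmm_blocks, nheads):
    """Map a threadblock's ticket to its work item (hypothetical layout).

    Chunk Cumsum tickets come first, then BMM, then one ticket per (chunk, head)
    that runs Chunk State -> State Passing -> Chunk Scan for that chunk."""
    if ticket < n_cumsum_blocks:
        return ("chunk_cumsum", ticket)
    ticket -= n_cumsum_blocks
    if ticket < n_bmm_blocks:
        return ("bmm", ticket)
    ticket -= n_bmm_blocks
    # Chunk-major order: every head of chunk c is issued before any head of chunk c + 1,
    # so a threadblock waiting on chunk c - 1 never waits on work that has not started.
    chunk, head = divmod(ticket, nheads)
    return ("state_passing_pipeline", chunk, head)


print(assign_work(0, 2, 3, 4))   # ('chunk_cumsum', 0)
print(assign_work(2, 2, 3, 4))   # ('bmm', 0)
print(assign_work(10, 2, 3, 4))  # ('state_passing_pipeline', 1, 1)
```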
While kernel fusion improves data reuse and speeds up the SSD, additional optimizations are necessary to achieve maximum performance. These include reordering threadblocks to hide serialization latency, adding cache hints to loads/stores to prioritize reused data, separating special cases outside of the fused kernel to reduce register pressure, changing some intermediate datatypes, tuning the chunk size, and restructuring operations for less latency. These optimization techniques are described in more detail in Appendix A.
Remaining Bottlenecks
In this section, we analyze the bottlenecks in the optimized fused SSD kernel using Nsight Compute to examine the final utilization, stall patterns, and resource tradeoffs.
At a high level, we can look at the compute and memory utilization of the fused kernel to get an idea of what limits this kernel.

Figure 4. A100 Nsight Compute Summary

Figure 5. H100 Nsight Compute Summary
We can see that overall fused SSD compute utilization is about 40-50% and memory utilization is about 65-75%. It is not possible to achieve 100% utilization due to the initial load/store latency and other overheads, but it’s usually possible to get at least 80% in a well-optimized kernel. For context, the H100 and A100 matmuls used in Mamba2 get 85-96% compute utilization. Since neither compute nor memory has good utilization in the SSD kernel, the bottlenecks are more complicated than just memory bandwidth or compute throughput.
We can look at the warp state statistics to see what warps are stalled on. “Selected” means that the warp executed a new instruction, but “Stall Long Scoreboard” and “Stall Barrier” indicate that warps are idle waiting for L2/VRAM or synchronizing.

Figure 6. Warp State Statistics for the fused SSD kernel on an H100
There are a few ways to reduce the effect of these stalls and improve the compute or memory utilization:
- Increase occupancy
- Increase instruction-level parallelism
- Optimize the code to use less synchronization and memory ops or cache data better
Occupancy
Modern NVIDIA GPUs can keep 12-16 warps (groups of 32 threads) resident per warp scheduler, and each warp scheduler can issue a new instruction every cycle. If each scheduler has only 1 warp, cycles are wasted every time that warp stalls. With 16 warps per scheduler, however, each warp could be stalled about 15/16 of the time without leaving the hardware idle. Occupancy is the fraction of available warp slots that are actually filled with active warps. Increasing occupancy helps hide memory and instruction latency, increasing GPU utilization.

Figure 7. Occupancy for the fused SSD kernel on an H100
This fused kernel only gets 25% occupancy in the current config, limited by registers and shared memory. Although we can increase the number of warps and reduce the registers per thread to increase occupancy, this reduces performance in practice, likely due to increased synchronization costs and higher register pressure.
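A simplified occupancy calculation shows how registers and shared memory cap the number of resident warps. The defaults are per-SM limits for H100 as we understand them (64 warps, 65,536 registers, 228 KB shared memory, 32 blocks); verify them against NVIDIA's documentation for your GPU. The sketch ignores allocation granularity, and the example configurations are hypothetical, not the fused kernel's actual config.

```python
def occupancy(warps_per_block, regs_per_thread, smem_per_block_bytes,
              max_warps_per_sm=64, regs_per_sm=65536,
              smem_per_sm_bytes=228 * 1024, max_blocks_per_sm=32):
    """Fraction of warp slots filled, taking the tightest of four per-SM limits.
    Simplified: ignores register/shared-memory allocation granularity."""
    threads_per_block = warps_per_block * 32
    limits = [
        max_blocks_per_sm,
        max_warps_per_sm // warps_per_block,
        regs_per_sm // (regs_per_thread * threads_per_block),
    ]
    if smem_per_block_bytes > 0:
        limits.append(smem_per_sm_bytes // smem_per_block_bytes)
    blocks_per_sm = min(limits)
    return blocks_per_sm * warps_per_block / max_warps_per_sm


# Hypothetical configs:
print(occupancy(4, 128, 48 * 1024))  # 0.25: registers and shared memory both allow 4 blocks
print(occupancy(4, 64, 48 * 1024))   # 0.25: halving registers alone does not help
print(occupancy(4, 64, 24 * 1024))   # 0.5
```

The second line mirrors the tradeoff in the text: when two resources both limit occupancy, relaxing only one of them changes nothing.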
Instruction-Level Parallelism
Instruction-level parallelism (ILP) means structuring the code so that instructions have fewer immediate dependencies on one another, allowing a warp to issue later instructions even when earlier ones haven’t finished. This provides the same latency-hiding benefit as increased occupancy, but without requiring more warps.
Reducing Synchronization and Data Transfer
Since the warps are usually waiting on memory loads/stores or a barrier, we can improve performance by reducing the number of barriers or reducing total data transfer through better caching or different block sizes.
Unfortunately, these three optimization techniques can directly clash and introduce tradeoffs. Each SM in the GPU has limited registers and shared memory, so if each threadblock uses too much, occupancy drops. We can increase instruction-level parallelism by loading data in stages, but that requires more registers and shared memory, resulting in lower occupancy. We can also change block sizes to reduce the total data transferred or increase the cache hit rates, but this also requires more resources and reduces occupancy.
This is why the fused kernel does not have very high memory or compute utilization.
Memory Utilization Details

Figure 8. Memory Chart for the fused SSD kernel on an H100
We can see from this chart that the reported 65–75% memory utilization is mostly from reads through the L2 cache. These reads likely include (i) tensors that fit in L2, (ii) tensors reused across multiple threadblocks, (iii) state transfers between threadblocks, and (iv) VRAM reads that naturally pass through L2. Since L1 caches are private to each SM and not coherent across threadblocks, shifting this traffic to L1 is not feasible. Similarly, bypassing L2 for VRAM traffic would not help, as all global memory accesses pass through L2.
This memory chart suggests that, apart from the suboptimal memory utilization, the kernel is effectively L2-bound rather than DRAM-bound. Further optimization would therefore require either (1) increasing memory utilization, (2) tuning the block sizes / config, or (3) making radical algorithmic changes.
Line-by-Line Stalls
Nsight Compute profiling shows warp stalls line-by-line, helping us check that the warp stalls are for legitimate reasons. As expected, most warp stalls in the fused kernel are from loading data, synchronization, and computation, with only minor overheads from atomics and inter-chunk synchronization. See Appendix B for more details.
Benchmarks
We benchmarked our Triton kernel on typical inference scenarios: batch sizes 1-32, sequence lengths from 1K up to 256K tokens, and fp16 states. These graphs show the speedup of our kernel over the baseline unfused kernels.

Figure 9. NVIDIA A100 Fused Kernel Speedup Graph

Figure 10. NVIDIA H100 Fused Kernel Speedup Graph
The fused SSD kernel is 1.50x-2.51x faster than the unfused implementation on the SSD portion. At short sequence lengths (especially with batch=1), the fused kernel benefits most from eliminating kernel launch overheads, but these constant costs are amortized for longer sequences. At longer sequences, the fused kernel’s lower data movement becomes even more beneficial as cache thrashing increases. The SSD speedup translates to roughly an 8-13% end-to-end speedup for a model like Mamba-2 2.7B with batch=1 and seq=128K on NVIDIA A100 and H100 GPUs. At shorter sequence lengths, the end-to-end speedup can reach ~20% at 1K context, likely due to the reduced kernel launch overhead.
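Amdahl's law is a reasonable first approximation for relating an SSD-only speedup to an end-to-end speedup: only the SSD's share of the runtime gets faster. This sketch ignores effects such as overlap between kernels or changes in cache behavior, and the inputs below are hypothetical, not measurements from this post.

```python
def end_to_end_speedup(ssd_fraction, ssd_speedup):
    """Amdahl's law: only the SSD share of the runtime is accelerated."""
    return 1.0 / ((1.0 - ssd_fraction) + ssd_fraction / ssd_speedup)


def implied_ssd_fraction(e2e_speedup, ssd_speedup):
    """Invert the formula: the SSD runtime share that would yield a given
    end-to-end speedup for a given kernel speedup."""
    return (1.0 - 1.0 / e2e_speedup) / (1.0 - 1.0 / ssd_speedup)


# Hypothetical: SSD is 20% of runtime and becomes 2x faster.
print(end_to_end_speedup(0.2, 2.0))  # about 1.11
```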
Accuracy and Correctness
The fused kernel is generally accurate and correct, but there are slight differences in output between the fused kernel and reference solution. These differences depend on the GPU it’s running on and the precisions of some computations. The fused kernel internally uses fp16 for some computations that the original kernels used fp32 for, because this gives a ~16% speedup. Furthermore, the original kernels support either fp32 or fp16 states, but our reported speedups are for fp16 states. The fused kernel still supports the same intermediate datatypes and fp32 states. In this section we explain the tradeoffs in accuracy and performance for these different dtype configs.
In Table 2, we report the accuracy of the output y tensor as percentage of elements that match the original kernels’ output. We test with no threshold (element must exactly match), a small threshold of 1e-3 absolute and relative tolerance, and a medium threshold of 1e-2. In this table, “exact dtypes” refers to using the same dtypes as the original kernel for all calculations, while “relaxed dtypes” refers to using fp16 for a few calculations. Both the fused and original kernels were run with the same state dtype in each column.
| | fp32 states, exact dtypes | fp16 states, exact dtypes | fp32 states, relaxed dtypes | fp16 states, relaxed dtypes |
|---|---|---|---|---|
| Match @ atol,rtol=0 | 99.696% | 99.337% | 67.307% | 66.823% |
| Match @ atol,rtol=1e-3 | 100.000% | 100.000% | 99.819% | 99.743% |
| Match @ atol,rtol=1e-2 | 100.000% | 100.000% | 100.000% | 100.000% |
Table 2. H100 Accuracy Table
Floating point addition is not perfectly associative, so we cannot expect all elements of the output tensor to match with 0 threshold. Even a different Triton launch config can cause very small differences in outputs from the same kernel. For “exact dtypes” (both fp16 and fp32 states), the output is identical for all practical purposes, so this kernel should work with “exact dtypes” even in the most accuracy-sensitive models. For “relaxed dtypes” (which we use in our speedup graphs), we can see that around 1/3 of the elements do not perfectly match the output of the original kernel. However, over 99.7% of the output elements match if we allow the tight threshold of 1e-3. Furthermore, at the commonly-used tolerance of atol=1e-2, rtol=1e-2 (1%), all configurations achieve >99.9995% accuracy, effectively 100%. For practical purposes, we expect the “relaxed dtypes” to have indistinguishable accuracy.
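The match percentages in Table 2 can be computed with an elementwise tolerance check. This sketch assumes a check of the form |out − ref| ≤ atol + rtol·|ref|, which is what NumPy's `isclose` tests; the exact function used for the table is not stated. Note that `isclose` is asymmetric: the relative term uses the reference tensor. The data below is a toy example.

```python
import numpy as np


def match_percentage(out, ref, tol):
    """Percent of elements with |out - ref| <= tol + tol * |ref|,
    i.e. np.isclose with atol = rtol = tol (tol = 0 means exact match)."""
    return 100.0 * np.isclose(out, ref, rtol=tol, atol=tol).mean()


ref = np.array([1.0, 2.0, 3.0, 4.0])
out = ref + np.array([0.0, 1e-4, 5e-3, 0.1])
for tol in (0.0, 1e-3, 1e-2):
    print(tol, match_percentage(out, ref, tol))  # 25.0, then 50.0, then 75.0
```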

Figure 11. H100 fp32 vs fp16 Accuracy Graph
In Figure 11, we show how our speedup changes when states are in fp32 instead of fp16. Both the fused and original kernels are faster with chunk_size=256 when states are in fp32. This represents a tradeoff of higher compute in return for a smaller state tensor. The fused kernel’s speedup is less for fp32 states than fp16 states, likely because of the different balance of compute and data movement.
Other Architectures
The fused SSD kernel is not limited to Mamba-2. It also applies directly to linear attention, since the SSD recurrence reduces to the linear attention update when the per-token decay A is fixed to 1 (no decay). In this special case, the fused kernel could be further simplified and optimized for even better performance.
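A quick check of the linear-attention connection: with the decay fixed to 1 and dt = 1, the SSD-style recurrence (with C, B, and x playing the roles of query, key, and value) produces exactly causal linear attention without softmax. This is a NumPy illustration with our own function names, not the fused kernel.

```python
import numpy as np


def linear_attention_quadratic(q, k, v):
    """Causal linear attention in "attention form": (q k^T masked to j <= i) v, no softmax."""
    scores = np.tril(q @ k.T)
    return scores @ v


def linear_attention_recurrent(q, k, v):
    """SSD-style recurrence with decay 1 and dt = 1 (C -> q, B -> k, x -> v):
    h_t = h_{t-1} + k_t v_t^T, y_t = q_t^T h_t."""
    h = np.zeros((k.shape[1], v.shape[1]))
    out = np.empty((q.shape[0], v.shape[1]))
    for t in range(q.shape[0]):
        h = h + np.outer(k[t], v[t])  # no decay term
        out[t] = q[t] @ h
    return out
```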
New GPU Features
The fused SSD kernel does not currently use newer GPU features such as the Tensor Memory Accelerator (TMA) and thread block clusters on Hopper GPUs, or the Tensor Memory in Blackwell GPUs. These features can greatly reduce register pressure, which would speed up the SSD and could result in faster Triton configs being possible (e.g., larger block sizes). The thread block clusters could especially be useful for broadcast-loading C, B, and CB matrices that are shared across a group of heads in the SSD kernel. This could give further speedups on new GPUs if necessary.
Further Fusion: Convolution and Layernorm
In this fused SSD kernel, we fused the five original SSD kernels. However, the convolution before the SSD and the layernorm after it are appealing candidates for fusion, because fusing each would remove an entire global-memory write and read between kernels. Since the convolution is depthwise (no channel mixing), the SSD kernel could load the d_conv - 1 preceding tokens along the sequence dimension, load the conv weights, and perform the convolution in registers or shared memory.
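A sketch of that halo load: a width-`d_conv` causal depthwise convolution needs the `d_conv - 1` tokens before a chunk, so each chunk can compute its convolution output locally. The weight layout `(d_conv, channels)`, the cross-correlation convention, and the omission of bias and activation are our assumptions; this is a NumPy illustration, not Mamba-2's convolution code.

```python
import numpy as np


def causal_depthwise_conv(x, w):
    """Reference: y[t, d] = sum_k w[k, d] * x[t - (d_conv - 1) + k, d], zero-padded on the left.
    x: (T, D) tokens by channels, w: (d_conv, D) one filter per channel."""
    d_conv = w.shape[0]
    xp = np.concatenate([np.zeros((d_conv - 1, x.shape[1])), x], axis=0)
    return np.stack([(xp[t:t + d_conv] * w).sum(axis=0) for t in range(x.shape[0])])


def conv_for_chunk(x, w, start, chunk_size):
    """What one chunk's threadblock would compute: load the chunk plus a halo of
    d_conv - 1 earlier tokens, then convolve locally. Assumes start + chunk_size <= T."""
    d_conv = w.shape[0]
    lo = max(start - (d_conv - 1), 0)
    window = x[lo:start + chunk_size]
    pad = (d_conv - 1) - (start - lo)  # zero-pad only at the start of the sequence
    window = np.concatenate([np.zeros((pad, x.shape[1])), window], axis=0)
    return np.stack([(window[i:i + d_conv] * w).sum(axis=0) for i in range(chunk_size)])
```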
We have done some experiments with fusing the layernorm, but with limited benefit. There are two methods to fuse this layernorm:
- Launch layernorm threadblocks separately. These threadblocks can wait until the corresponding SSD threadblocks have finished and then read the output y from L2 cache instead of VRAM.
- Sync SSD threadblocks across heads, exchange norm values, and compute the layernorm in registers or shared memory.
Method 2 was very slow because the SSD threadblocks stalled while syncing and had no other work to do while waiting. Method 1 worked, but reading from L2 instead of VRAM doesn’t provide as much benefit as registers/shared memory. So far, the speedup has been far below the theoretical limit, and it’s unclear whether further optimizations would make it worthwhile given the added complexity.
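The reason method 2 needs a cross-head sync can be shown with a small sketch: a threadblock that owns one head's slice of y can only compute a partial sum of squares, and no slice can be normalized until all partials are combined. We use a plain RMS norm as a stand-in for the layernorm and omit any gating; a mean-subtracting LayerNorm would exchange per-head sums as well. The function names are hypothetical.

```python
import numpy as np


def rmsnorm_reference(y, weight, eps=1e-5):
    """Plain RMS normalization over the last (full hidden) dimension."""
    return y / np.sqrt(np.mean(y * y, axis=-1, keepdims=True) + eps) * weight


def rmsnorm_from_head_partials(y_heads, weight, eps=1e-5):
    """Normalize per-head slices using only per-head partial statistics
    plus one exchange step (the cross-head sync in method 2)."""
    partials = [np.sum(h * h, axis=-1) for h in y_heads]  # computed independently per head
    total = np.sum(partials, axis=0)                       # the exchange step
    dim = sum(h.shape[-1] for h in y_heads)
    inv_rms = 1.0 / np.sqrt(total / dim + eps)
    w_heads = np.split(weight, len(y_heads))
    return [h * inv_rms[..., None] * w for h, w in zip(y_heads, w_heads)]
```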
Insights on Model Design
With the optimized fusion of the five SSD kernels, Mamba2 prefill is now even cheaper than before. This shifts the runtime-accuracy tradeoff for Mamba2 layers, which could make scaling up both the size and the number of Mamba2 layers the optimal balance in new LLMs. More design insights include:
- Compute Intensity: The current fused kernel has low compute utilization at the fastest chunk size, so we might be able to afford slightly more complicated operations. Although we could increase compute intensity by increasing the chunk size, that also increases the required registers and other resources, causing an overall slowdown.
- State Precision: In both the fused and original kernels, the State Passing step must be serial instead of parallel. Although parallel scan algorithms with sublinear latency exist, in practice they can be much slower than the serialized version used in Mamba2. Therefore, minimizing the latency of the State Passing computation as a fraction of total latency is vital to hiding the serialization latency. If the states can be held in low precision, such as fp16, this significantly helps the fused kernel. Without a fast State Passing step, we might need to split threadblocks more along other dimensions such as headdim, which would slow down the fused kernel overall.
- VRAM vs L2 tradeoff: Since the fused kernel has higher L2 bandwidth utilization than VRAM bandwidth utilization, sharing less data across threadblocks costs less. If an architecture’s performance benefits greatly from smaller head groups (fewer heads sharing each B and C), the added VRAM reads could hurt performance less than they did with the original kernels. On the other hand, new GPU features such as TMA multicast loads could reduce L2 bandwidth utilization, speeding up the SSD and reducing this imbalance.
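The serial dependency in State Passing can be seen in a simplified sketch: one head, a flattened state vector, and a per-chunk log-decay. The names (`state_passing`, `chunk_log_decay`, `fp16_states`) are hypothetical, and rounding through Python's `struct` half-precision format is only a stand-in for storing states as fp16; the sketch shows the recurrence and the precision effect, not latency.

```python
import math
import struct
from typing import List, Tuple


def to_fp16(v: float) -> float:
    """Round to IEEE half precision via struct's 'e' format (stand-in for an fp16 store)."""
    return struct.unpack("<e", struct.pack("<e", v))[0]


def state_passing(
    chunk_states: List[List[float]],  # state contributed by each chunk alone (Chunk State output)
    chunk_log_decay: List[float],     # total log-decay across each chunk
    initial: List[float],
    fp16_states: bool = False,
) -> Tuple[List[List[float]], List[float]]:
    """Return the state entering each chunk, plus the final state."""
    state = [to_fp16(s) for s in initial] if fp16_states else list(initial)
    entering = []
    for new, log_decay in zip(chunk_states, chunk_log_decay):
        entering.append(state)  # what Chunk Scan for this chunk would read
        scale = math.exp(log_decay)
        # This update needs the previous chunk's result: the serial dependency.
        state = [scale * s + n for s, n in zip(state, new)]
        if fp16_states:
            state = [to_fp16(s) for s in state]  # half the bytes per passed state
    return entering, state
```

Usage: `state_passing([[0.5, 0.5], [1.0, 0.0]], [0.0, math.log(0.5)], [1.0, 2.0])` passes `[1.5, 2.5]` into the second chunk and ends with a final state of approximately `[1.75, 1.25]`.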
vLLM Integration
In order to support variable-length sequences with initial states but without padding, vLLM introduces the idea of “pseudo chunks”. A chunk containing tokens from multiple sequences has multiple pseudo chunks, one for each sequence in that chunk. Most of the five kernels behave the same, with State Passing loading initial states when a new sequence starts. However, Chunk Scan has a larger threadblock grid that iterates over pseudo chunks instead of chunks. To support this in the fused kernel, we use a for loop to process all pseudo chunks in the current chunk. The vLLM Chunk Scan offsets its reads and writes based on where the pseudo chunk starts in the real chunk. We use masking based on the sequence index instead, since masking provides a speedup. Both offsetting and masking read/write the same amount of data at runtime, but masking might be more predictable for the compiler, better aligned, or simply simpler. The vLLM fused kernel is still being integrated, but it shows a similar speedup.
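The pseudo-chunk bookkeeping can be sketched in pure Python, assuming a packed (unpadded) batch described only by its sequence lengths. `pseudo_chunks` lists every (chunk, sequence) overlap, and `pseudo_chunk_mask` shows the masking alternative: instead of offsetting into the chunk, every pseudo chunk covers the full chunk and keeps only the tokens whose sequence index matches. All names here are hypothetical; this is not vLLM's code.

```python
from typing import List, Tuple


def pseudo_chunks(seq_lens: List[int], chunk_size: int) -> List[Tuple[int, int, int, int]]:
    """Return (chunk_idx, seq_idx, start, end) token ranges in the packed sequence."""
    bounds = [0]
    for n in seq_lens:
        bounds.append(bounds[-1] + n)  # cumulative sequence boundaries
    total = bounds[-1]
    out = []
    for chunk_idx, cstart in enumerate(range(0, total, chunk_size)):
        cend = min(cstart + chunk_size, total)
        for seq_idx in range(len(seq_lens)):
            s, e = max(cstart, bounds[seq_idx]), min(cend, bounds[seq_idx + 1])
            if s < e:  # this sequence overlaps this chunk: one pseudo chunk
                out.append((chunk_idx, seq_idx, s, e))
    return out


def seq_idx_per_token(seq_lens: List[int]) -> List[int]:
    """Sequence index of every token in the packed batch."""
    return [i for i, n in enumerate(seq_lens) for _ in range(n)]


def pseudo_chunk_mask(seq_idx: List[int], chunk_start: int, chunk_size: int, target_seq: int) -> List[bool]:
    """Mask over the whole chunk: True where the token belongs to target_seq."""
    return [
        chunk_start + i < len(seq_idx) and seq_idx[chunk_start + i] == target_seq
        for i in range(chunk_size)
    ]
```

For sequence lengths `[3, 6, 2]` and a chunk size of 4, the first chunk splits into two pseudo chunks (tokens 0-2 of sequence 0 and token 3 of sequence 1), and the last chunk splits into one pseudo chunk for each of sequences 1 and 2.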
Conclusion
In summary, we fused the five Triton kernels of the Mamba-2 SSD prefill into one, yielding a 2x speedup for the SSD itself, which translates into a ~8–20% end-to-end inference speedup. This significantly boosts throughput for models using Mamba-2 layers. We are excited to integrate these kernel improvements into open-source projects so that the community can easily leverage faster inference with Mamba-2 based models. Stay tuned for updates as this fused SSD kernel lands in the Mamba codebase and in inference frameworks like vLLM.
Appendix A: Optimization Details
Threadblock Order
The State Passing step causes serialization. For a given head, all but one threadblock stall waiting for the previous chunk to be ready. When our GPU runs about 256-1024 threadblocks concurrently but only one makes progress, we get a significant slowdown. Some of the serialization is hidden by the latency of the Chunk State step, since later chunks could still be computing Chunk State rather than stalling in State Passing, but this is not enough. Both the nheads and batch dimensions represent domain parallelism (independent work) in the SSD. Instead of launching threadblocks for a particular batch and head before moving on to the next, we can launch threadblocks for multiple (batch, head) combinations. If we launch n different (batch, head) combinations for the same chunk before moving on to the next chunk, our serialization drops by a factor of n (instead of only 1 threadblock making progress, n threadblocks make progress). This n must be carefully balanced: if it is too large, we lose L2 cache locality for passing states, and if it is too small, threadblocks stall. As a simple heuristic, we launch threadblocks for all nheads before moving on to the next chunk, but finish all chunks before progressing in the batch dimension. For models with many more or far fewer heads, or significantly different dimensions, a more sophisticated threadblock order could explicitly combine nheads and batch into one dimension and then split it into an inner and an outer dimension, with the inner dimension launching before the next chunk.
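The launch order can be written as a mapping from a linear threadblock id to (batch, chunk, head). The sketch below is an illustration under stated assumptions, not the authors' code: in a Triton kernel the id would come from `tl.program_id(0)`, and both helper names are hypothetical. `block_coords` implements the simple heuristic (heads fastest, then chunks, then batch). `block_coords_grouped` implements the more general order, where (batch, head) is fused into one axis and `inner` of them launch per chunk.

```python
from typing import Tuple


def block_coords(pid: int, nheads: int, nchunks: int) -> Tuple[int, int, int]:
    """Heads vary fastest, then chunks, then batch: all heads of chunk c launch before chunk c + 1."""
    head = pid % nheads
    chunk = (pid // nheads) % nchunks
    batch = pid // (nheads * nchunks)
    return batch, chunk, head


def block_coords_grouped(pid: int, batch: int, nheads: int, nchunks: int, inner: int) -> Tuple[int, int, int]:
    """Fuse (batch, head) into one axis; launch `inner` of them for each chunk before the next chunk."""
    assert (batch * nheads) % inner == 0, "sketch assumes inner divides batch * nheads"
    group, rem = divmod(pid, inner * nchunks)  # outer dimension: which group of (batch, head) pairs
    chunk = rem // inner
    bh = group * inner + rem % inner           # inner dimension: position within the group
    return bh // nheads, chunk, bh % nheads
```

With `inner = nheads`, the grouped order reduces to the simple heuristic. Raising `inner` to `batch * nheads` lets every (batch, head) pair of a chunk make progress together, at the cost of L2 locality for the passed states.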
Cache Hints
The input and output tensors of operations such as the Mamba2 SSD are typically too large to fit in cache. For example, the input and output for 16k context in a Mamba2 SSD with 128 heads of 64 dim each in fp16 will each consume 16k * 128 * 64 * 2B = 256 MiB. Typical GPU L2 caches are 40-50 MiB. Therefore, some data will be evicted from the L2 cache during that kernel.
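The size estimate above can be checked with a few lines of arithmetic. The 50 MiB L2 figure is an assumption taken from the upper end of the quoted range.

```python
def tensor_mib(seqlen: int, nheads: int, headdim: int, bytes_per_elem: int) -> float:
    """Size in MiB of a (seqlen, nheads, headdim) activation tensor."""
    return seqlen * nheads * headdim * bytes_per_elem / 2**20


L2_MIB = 50  # assumption: upper end of the 40-50 MiB range

io_mib = tensor_mib(16 * 1024, 128, 64, 2)  # fp16 input; the output is the same size
print(io_mib)           # 256.0
print(io_mib > L2_MIB)  # True
```

Input plus output come to 512 MiB, roughly ten times the assumed L2 capacity, so most of either tensor is evicted regardless of cache hints.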
Since most of the output tensor does not fit in the L2 cache, it’s not worth using L2 cache capacity for the output to try to speed up the next operation. We can use a cache hint to indicate that the output tensor has the lowest priority for caches. In general, once we access data for the final time in the kernel, we can mark it as low priority for caches. For often reused data, such as CB (which is shared among heads in a group), we can use a high priority cache hint to reduce the chance of eviction.
We can also avoid flushing L1 cache during some sync atomics by specifying “release” semantics. This tells the compiler that previously written data must be globally visible before the atomic operation (e.g. if we are setting a “ready” flag), but this thread does not need to invalidate any caches.
Conditional Separation
In the State Passing step, we have two special cases: reading the initial state instead of the previous chunk’s global state, and writing to the final state instead of to the global states tensor. Although conceptually these special cases should only involve swapping the base pointer to read from or write to, the initial and final state conditionals increase register pressure and slow down the fused kernel. To solve this, we can handle the special cases outside of the fused SSD kernel. If we replace the nchunks dimension in our state tensor with nchunks + 1, we can copy the initial states into the 0th chunk and copy the final states out of the last chunk. These copies are done using PyTorch sliced assignment syntax, which results in small kernels with negligible runtime and launch overhead.
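A pure-Python sketch contrasts the two layouts, with lists standing in for the state tensor and hypothetical function names. `passing_with_branches` needs a special case on the first and last chunk. `passing_padded` uses nchunks + 1 slots, so the loop body (the "kernel") always reads slot c and writes slot c + 1, and the initial and final states are copied in and out outside it. In PyTorch those copies would be sliced assignments along the chunk dimension; the exact tensor layout is not given in the text.

```python
from typing import List, Tuple

State = List[float]


def passing_with_branches(initial: State, chunk_states: List[State], scales: List[float]) -> Tuple[List[State], State]:
    """Branchy version: chunk 0 reads `initial`, the last chunk writes `final`."""
    n = len(chunk_states)
    global_states: List[State] = [[] for _ in range(n - 1)]  # states after chunks 0..n-2
    final: State = []
    for c in range(n):
        prev = initial if c == 0 else global_states[c - 1]  # special case 1
        new = [scales[c] * p + s for p, s in zip(prev, chunk_states[c])]
        if c == n - 1:  # special case 2
            final = new
        else:
            global_states[c] = new
    return global_states, final


def passing_padded(initial: State, chunk_states: List[State], scales: List[float]) -> Tuple[List[State], State]:
    """Padded version: nchunks + 1 slots and a branch-free loop body."""
    n = len(chunk_states)
    states = [[0.0] * len(initial) for _ in range(n + 1)]
    states[0][:] = initial  # copy-in, done outside the "kernel"
    for c in range(n):      # every iteration: read slot c, write slot c + 1
        states[c + 1] = [scales[c] * p + s for p, s in zip(states[c], chunk_states[c])]
    final = list(states[n])  # copy-out, done outside the "kernel"
    return states[1:n], final
```

Both versions compute the same intermediate and final states. The padded one simply moves the two conditionals out of the loop body.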
Intermediate Datatypes
For some computations, such as applying the A decay to B in Chunk Scan, we can use fp16 for the computation instead of fp32. This also replaces upcasting B and downcasting the result with a single downcast of the scale, reducing casting instructions.
Compile-Time Masks
Triton requires that the dimensions of blocks of tensors in a threadblock are powers of 2 known at compile time. This forces all stores and loads to operate on power-of-2 blocks that might not divide the target tensor exactly. We therefore use masks to cover the entire tensor while avoiding reading or writing out-of-bounds data (or the next block of data). These masks have the same dimensions as the tensor block. However, these masks are not always necessary, because model dimensions like headdim are often divisible by the block size and do not change between inputs. Triton supports tl.constexpr compile-time parameters and setting them based on other parameters with @triton.heuristics. Therefore, we can automatically enable or disable the headdim dimension of the mask based on whether the headdim is divisible by the block size. Although this decision is made at “runtime”, it really only happens once, during the initial JIT compilation of the kernel for this model.
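The pattern can be sketched without a GPU. The dict below has the shape `@triton.heuristics(...)` accepts (a meta-parameter name mapped to a function of the kernel's argument dict). `EVEN_HEADDIM` and `BLOCK_SIZE_HD` are hypothetical names that would be declared as `tl.constexpr` kernel parameters. `resolve_constexprs` and `headdim_mask` only emulate, in plain Python, what the compiled kernel would do with the resolved value.

```python
from typing import Any, Callable, Dict, List, Optional

# Shape of the dict passed to @triton.heuristics; parameter names are hypothetical.
HEURISTICS: Dict[str, Callable[[Dict[str, Any]], Any]] = {
    "EVEN_HEADDIM": lambda args: args["headdim"] % args["BLOCK_SIZE_HD"] == 0,
}


def resolve_constexprs(args: Dict[str, Any]) -> Dict[str, Any]:
    """Evaluate each heuristic once from the call arguments (Triton does this before JIT compilation)."""
    return {name: fn(args) for name, fn in HEURISTICS.items()}


def headdim_mask(block_start: int, block_size: int, headdim: int, even_headdim: bool) -> Optional[List[bool]]:
    """Mask along headdim; None means the masked code path is compiled out entirely."""
    if even_headdim:
        return None  # every lane is in bounds, so no mask is needed
    return [block_start + i < headdim for i in range(block_size)]
```

For `headdim = 64` and `BLOCK_SIZE_HD = 32`, the heuristic resolves to `True` and no mask is generated. For `headdim = 80`, the last block keeps 16 of its 32 lanes.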
Chunk Size
The Mamba2 SSD algorithm takes asymptotically constant computation per token (computation scales linearly with sequence length), but it has a base case of some chunk size that is computed quadratically. Between chunks, the linear algorithm is used, but within a chunk, the quadratic algorithm is used. For more details, see https://tridao.me/blog/2024/mamba2-part1-model/#state-space-duality.
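The linear-between-chunks, quadratic-within-chunks structure can be checked on a toy model. The sketch below assumes a scalar state with per-step decay a_t in (0, 1), which is much simpler than Mamba2's multi-head formulation, and the function names are hypothetical. Inside each chunk, every output position interacts with every earlier position in that chunk through a decay exp(cum[t] − cum[s]) computed from log-space cumulative sums. Between chunks, only a single carried state is passed.

```python
import math
from typing import List


def ssm_recurrent(a: List[float], b: List[float], c: List[float], x: List[float]) -> List[float]:
    """Reference recurrence: h_t = a_t * h_{t-1} + b_t * x_t, y_t = c_t * h_t, with h_{-1} = 0."""
    h, ys = 0.0, []
    for at, bt, ct, xt in zip(a, b, c, x):
        h = at * h + bt * xt
        ys.append(ct * h)
    return ys


def ssm_chunked(a: List[float], b: List[float], c: List[float], x: List[float], chunk_size: int) -> List[float]:
    """Quadratic (all pairs) inside each chunk, linear state passing between chunks."""
    h_prev, ys = 0.0, []
    for start in range(0, len(x), chunk_size):
        idx = range(start, min(start + chunk_size, len(x)))
        cum, run = {}, 0.0
        for t in idx:  # log-space cumulative decay within the chunk (requires a_t > 0)
            run += math.log(a[t])
            cum[t] = run
        for t in idx:
            # Intra-chunk term: O(chunk_size^2) pairwise interactions.
            intra = sum(math.exp(cum[t] - cum[s]) * b[s] * x[s] for s in idx if s <= t)
            # Inter-chunk term: the carried-in state, decayed to position t.
            ys.append(c[t] * (intra + math.exp(cum[t]) * h_prev))
        last = idx[-1]
        # State at the end of this chunk, passed to the next chunk.
        h_prev = sum(math.exp(cum[last] - cum[s]) * b[s] * x[s] for s in idx) + math.exp(cum[last]) * h_prev
    return ys
```

Every chunk size gives the same outputs as the plain recurrence. Only the split between quadratic and linear work changes, which is why chunk size is a pure performance knob.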
The optimal chunk size represents a tradeoff between higher computation and resource requirements on one hand, and higher hardware utilization and fewer intermediate states on the other. With the original unfused kernels, the optimal chunk size for Mamba2-2.7B had been 256. However, with the new fused kernel, the optimal chunk size is now 128 for the same model. This smaller chunk size also has the added benefit of reducing register pressure, making the kernel less sensitive to small changes like enabling masks or using higher precision for intermediate results.
Currently, the convention for Mamba2 models is to specify the chunk size in the model’s config. However, since the optimal chunk size varies depending on the original vs fused kernels, it could be better to use a heuristic or autotune the chunk size. This might not be straightforward since the code surrounding the SSD kernels might assume a particular chunk size.
Scale Multiplication Operand
For Chunk State, we can equivalently apply the A decay to X instead of B, since the dimension to be scaled is the inner (contracted) dimension of the matmul of X and B. Essentially, we compute (X * A[None, :]) @ B instead of X @ (A[:, None] * B). This is faster, probably because a more similar layout causes less register data movement. For example, due to the required Tensor Core data layout, each thread might already have the required A values to multiply with its X values, but to scale B, we might have to load in a different layout and shuffle data back to the required Tensor Core layout.
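The equivalence of the two orderings follows from (X · diag(A)) · B = X · (diag(A) · B). The small pure-Python check below uses small integers so the comparison is exact. It cannot show the register-layout effect that makes one ordering faster; the helper names are illustrative.

```python
from typing import List

Matrix = List[List[float]]


def matmul(p: Matrix, q: Matrix) -> Matrix:
    """Plain triple-loop matmul: (m x k) @ (k x n)."""
    return [[sum(p[i][k] * q[k][j] for k in range(len(q))) for j in range(len(q[0]))] for i in range(len(p))]


def scale_x_cols(x: Matrix, a: List[float]) -> Matrix:
    """X * A[None, :]: scale column k of X by A[k]."""
    return [[v * a[k] for k, v in enumerate(row)] for row in x]


def scale_b_rows(a: List[float], b: Matrix) -> Matrix:
    """A[:, None] * B: scale row k of B by A[k]."""
    return [[a[k] * v for v in row] for k, row in enumerate(b)]


X = [[1, 2, 3], [4, 5, 6]]
A = [1, 2, 3]
B = [[1, 0], [0, 1], [1, 1]]
print(matmul(scale_x_cols(X, A), B))  # [[10, 13], [22, 28]]
print(matmul(X, scale_b_rows(A, B)))  # [[10, 13], [22, 28]]
```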
Appendix B: Summary of Stall Reasons
If we look at the source in NVIDIA Nsight Compute, we can see the warp stalls for each line of code and assembly instruction in the fused kernel on an H100. Assuming that the kernel and block sizes are optimal, warp stalls can reveal potential areas for optimization.
- In order to ensure correctness, we use an atomic add to get threadblock ids in increasing order. This accounts for about 3% of the total warp stalls.
- Both the Chunk Cumsum and BMM parts of the fused kernel are very fast, so each accounts for less than 2% of warp stalls.
- Atomically checking that the Chunk Cumsum and BMM threadblocks have prepared data for this Chunk State threadblock accounts for about 1.5% of warp stalls.
- Chunk State has about 12% of total warp stalls in loading dA, X, and especially B. It also has about 7% stalls in barriers related to scaling and using Tensor Cores.
- Despite being serialized along chunks, State Passing has less than 3% stalls on synchronization (including awaiting the previous chunk). Loading the previous states does not cause significant stalling, but updating the state and storing cause about 6% stalls awaiting shared memory or a barrier.
- For the previous state’s contribution in Chunk Scan, loading C accounts for about 5% of stalls (load stalls), prev_states for about 3% (barrier stalls), and the computation for about 8% (barrier, load (for the scale), and instruction dependency stalls).
- The current chunk’s contribution in Chunk Scan has about 13% stalls in loading data and 18% stalls in computation (including scaling).
- The residual (scaled by D) accounts for about 6% of total stalls for loading, shared memory, and computation.
Overall, these stalls are for legitimate reasons and are not easy to optimize away.
Some Matrix Multiplication Engines Are Not As Accurate As We Thought
6 Feb 2026, 10:15 pm
What is an accumulator in an accelerator’s GEMM engine and why does it matter?
GPUs and custom accelerators include specialized compute engines for matrix multiplication (also known as matmul or GEMM), such as NVIDIA’s Tensor Cores. These engines efficiently perform matmul on small tensor blocks; therefore, compilers or libraries typically divide large matmul problems into many smaller ones and feed them to these engines. Usually, the output of an FP8 (e4m3) Tensor Core matmul with input shapes (block_size_m, block_size_k) and (block_size_k, block_size_n) is a (block_size_m, block_size_n) tensor in FP32 (e8m23). However, one detail users rarely notice is that, for hardware efficiency reasons, this FP32 output may have fewer than 23 effective mantissa bits. In other words, the precision of this Tensor Core operation is lower than the FP32 output format suggests. This hardware design choice has been reported to impact model accuracy under certain circumstances [1, 2]. Therefore, from a GPU user’s perspective, we would like to verify the hardware design in use: even though the existing hardware cannot be changed, custom kernels can still be written to preserve the highest achievable accuracy when needed. For hardware designers, it is equally important to have a convenient and efficient way to quantify the impact of this design choice.
Before we dive into details, we need to understand the role of an “accumulator” and the reason for employing reduced precision. Let’s first consider a hypothetical compute engine that can handle an FP8 matmul of block sizes (3, 4) and (4, 3), as illustrated in Fig. 1a. Zooming into the compute engine, the most basic operation would be a row-column inner product, i.e.
cᵢⱼ = ∑ₖ aᵢₖ · bₖⱼ. One can imagine that an efficient hardware design will simply implement 4 multipliers to compute each product aᵢₖ · bₖⱼ, followed by 3 adders to sum up the intermediate results, as shown in Fig. 1b. In this simple example, the multiplication part can be done in a single parallelized “compute step”, assuming enough multipliers are available. But the addition part requires 2 compute steps to complete, because the partial sums must be combined hierarchically, level by level. If we scale up this unit design to N elements, multiplication still takes only one step, while addition takes ⌈log₂ N⌉ steps.
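The step count can be reproduced by modeling the addition stage as a balanced binary (pairwise) tree. This is an assumption consistent with the description of Fig. 1b, not a model of any specific hardware. `adder_tree` returns the sum and the number of addition levels; it uses N − 1 adders arranged in ⌈log₂ N⌉ levels.

```python
from typing import List, Sequence, Tuple


def adder_tree(values: Sequence[float]) -> Tuple[float, int]:
    """Pairwise (tree) reduction; returns (sum, number of addition steps)."""
    level: List[float] = list(values)
    steps = 0
    while len(level) > 1:
        nxt = [level[i] + level[i + 1] for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:  # an odd element passes through to the next level unchanged
            nxt.append(level[-1])
        level, steps = nxt, steps + 1
    return level[0], steps


def inner_product(a_row: Sequence[float], b_col: Sequence[float]) -> Tuple[float, int]:
    """One parallel multiply step followed by a tree of additions."""
    products = [x * y for x, y in zip(a_row, b_col)]
    return adder_tree(products)


print(inner_product([1, 2, 3, 4], [5, 6, 7, 8]))  # (70, 2): 4 products, 2 addition steps
```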

Furthermore, each multiplier only needs to compute FP8 × FP8 (e4m3), which involves a 4-bit + 4-bit addition (for the exponents) and a 4-bit × 4-bit multiplication (for the mantissas, including the implicit leading 1). However, since each partial product must be aligned to a common exponent before it is added, the subsequent adders must use significantly more bits than the multipliers. As illustrated by Fig. 2 (just an example, not a real FP8 case), adding two limited-precision FP numbers with only 4 mantissa bits can produce an FP number that requires many more mantissa bits. This loosely explains why the circuit complexity and cost (silicon area and power) of a floating-point multiply-accumulate (MAC) operation depend strongly on the accumulation precision. Therefore, even though FP32 is the safer choice of accumulation precision (Fig. 2b), it is worthwhile to explore opportunities to use reduced accumulation precision.
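The alignment effect can be reproduced with an example of our own (not the numbers in Fig. 2). `significant_bits` counts the significand bits, including the leading 1, needed to represent a float exactly. Two values with 4 mantissa bits each, whose exponents differ by 8, produce a sum that needs 12 mantissa bits.

```python
def significant_bits(x: float) -> int:
    """Significand bits (including the leading 1) needed to represent x exactly."""
    if x == 0:
        return 0
    num, _den = abs(x).as_integer_ratio()  # _den is a power of two for any finite float
    while num % 2 == 0:  # strip trailing zero bits (only possible when _den == 1)
        num //= 2
    return num.bit_length()


a = 1.0625 * 2**4   # 1.0001b x 2^4: 4 mantissa bits
b = 1.0625 * 2**-4  # 1.0001b x 2^-4: 4 mantissa bits
s = a + b           # exact in float64
print(significant_bits(a) - 1, significant_bits(b) - 1, significant_bits(s) - 1)  # 4 4 12
```

An accumulator with fewer than 12 mantissa bits would have to drop the low-order bits of `b` in this sum. That dropped information is exactly what a reduced-precision accumulator trades away.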

With these examples in mind, the benefits of using reduced‑precision adders in matmul engines become clear.
How to Verify Accumulator Precision? (Using TensorCore as an Example)
Given that a matmul accumulator could be designed with fewer than 23 mantissa bits, the actual output is effectively e8mNacc (where Nacc < 23), padded with trailing 0s up to e8m23. In other words, the output of an FP8 TensorCore may look like FP32, but the mantissa bits beyond the first Nacc were never computed. In this blog, we demonstrate a simple approach to investigating the accumulator precision using a Triton kernel.
Assuming the TensorCore output has only Nacc effective mantissa bits (as in e8mNacc), i.e., the last 23 − Nacc bits are 0 already, if we apply a mask to truncate the last Ntrun bits of the TensorCore output, as long as Ntrun ≤ 23 − Nacc, the final matmul results should remain unchanged. Furthermore, by sweeping Ntrun and comparing the matmul output to a reference (i.e., Ntrun = 0), we can infer the accumulator precision of the FP matmul unit under investigation. Here, “truncation of Ntrun bits” refers to zeroing out the last Ntrun bits of a floating point number, which are the least-significant bits (LSBs) of the mantissa.
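The sweep logic can be tried on the host without a GPU. In the sketch below, `emulated_dot` stands in for a matmul engine whose accumulator keeps only `N_ACC` mantissa bits; that engine is simulated, and `N_ACC = 13` is an arbitrary choice for the demo. `truncate` zeroes the Ntrun least-significant bits of the FP32 encoding (no rounding, matching round_bit = 0). `infer_mantissa_bits` finds the largest Ntrun that leaves every output unchanged and reports Nacc = 23 − Ntrun. All names are hypothetical.

```python
import random
import struct
from typing import Iterable, Sequence


def f32_to_bits(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(u: int) -> float:
    return struct.unpack("<f", struct.pack("<I", u))[0]


def truncate(x: float, n_trun: int) -> float:
    """Zero the n_trun least-significant mantissa bits of x's FP32 encoding."""
    mask = (0xFFFFFFFF << n_trun) & 0xFFFFFFFF
    return bits_to_f32(f32_to_bits(x) & mask)


def infer_mantissa_bits(outputs: Iterable[float]) -> int:
    """Largest N_trun that changes no output gives N_acc = 23 - N_trun."""
    outputs = list(outputs)
    for n_trun in range(1, 24):
        if any(truncate(v, n_trun) != v for v in outputs):
            return 23 - (n_trun - 1)
    return 0


N_ACC = 13  # simulated accumulator width (assumption for this demo)


def emulated_dot(a: Sequence[float], b: Sequence[float]) -> float:
    """FP32 result whose mantissa bits beyond N_ACC were never computed."""
    return truncate(sum(x * y for x, y in zip(a, b)), 23 - N_ACC)


random.seed(0)
outs = [
    emulated_dot([random.uniform(-1, 1) for _ in range(32)], [random.uniform(-1, 1) for _ in range(32)])
    for _ in range(256)
]
print(infer_mantissa_bits(outs))  # 13
```

Many random outputs are needed because any single value can have zeros in its low kept bits by chance. With 256 samples, the chance that bit 10 is zero in all of them is negligible.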
Why Triton?
We use the Triton language because it allows the proposed method to generalize to other accelerators that support Triton. It also greatly speeds up development for this experiment, thanks to its simplicity and the right level of control over the accelerator. Although Triton is expected to evolve over time, our implementation is based on Triton’s matmul tutorial, so we anticipate that only minimal rewrites will be needed in the future.
Experiments
Runnable code is provided at the end of this notebook. Here, we adopted the matmul kernel from the Triton tutorial and added a simple truncation function. Since many details can be found in the original tutorial, we will only highlight the truncation-related modifications we made. Roughly speaking, matmul(A, B) is decomposed into smaller blocks that are processed in parallel. Blocks of A and B have shapes (BLOCK_SIZE_M, BLOCK_SIZE_K) and (BLOCK_SIZE_K, BLOCK_SIZE_N), respectively. The block-level matmul is computed by Triton’s tl.dot() function, producing a temporary tensor accumulator_inner of shape (BLOCK_SIZE_M, BLOCK_SIZE_N), which is assumed to have only Nacc effective mantissa bits.
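1. Truncation of accumulator_inner: We truncate the last Ntrun bits of accumulator_inner using a bit operation with a predefined mask. For simplicity, we ignore rounding by setting round_bit = 0.
```python
import triton
import triton.language as tl
# The libdevice import path differs across Triton versions; this is the Triton 3.x location.
from triton.language.extra import libdevice


@triton.jit
def prep_round_and_trun_mask(trun_bits):
    # trun_bits is expected to be a compile-time constant (tl.constexpr) in the calling kernel.
    # round_bit is half of the last kept bit; adding it before masking rounds instead of truncating.
    round_bit = 1 << (trun_bits - 1) if trun_bits > 0 else 0
    # Mask that keeps every bit except the trun_bits least-significant (mantissa) bits.
    trun_mask = ~tl.cast((1 << trun_bits) - 1, tl.uint32)
    return round_bit, trun_mask


@triton.jit
def round_and_trun(x, round_bit, trun_mask):
    """Round and truncate (usually for accumulator)."""
    # Reinterpret the FP32 bits as uint32, add the round bit (0 = pure truncation), zero the low bits.
    return libdevice.uint_as_float(
        (libdevice.float_as_uint(x) + round_bit) & trun_mask
    )
```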
2. Accumulation across the K-dimension: Each truncated accumulator_inner is further accumulated into a pre-allocated FP32 tensor, accumulator, while stepping through the K-dimension. accumulator has the same shape as accumulator_inner.
3. Writing the results back: After iterating through the K-dimension, the final accumulator values are written back to the corresponding block in target output tensor C, whose shape is (M, N).
Results and discussions
From both Table 1 and Fig. 3 below, we observed that truncating up to 10 least significant mantissa bits of the output (using H100 FP8 TensorCore) produces exactly the same results as the case with no truncation. This indicates that those bits were already 0 in the original output. The experiment therefore suggests that the accumulator is implemented using a special FP22 format (e8m13) for compute efficiency reasons. We repeated this same experiment on an RTX4000-series GPU (Ada Lovelace architecture) and observed the same behavior.

One important consideration to keep in mind is that this experiment relies on the Triton compiler to translate Triton code into GPU instructions. We must ensure that the TensorCore performing the task is indeed the one we intended to inspect, i.e., the FP8 TensorCore. In rare situations, the Triton compiler may choose FP16 TensorCore instructions for certain FP8 computations. The most reliable way to confirm which hardware instructions are actually executed is to use the NVIDIA profiler ncu [3], which is included in the CUDA Toolkit, to inspect the underlying instructions associated with the Triton tl.dot call.
Readers can save this notebook as a Python file and then launch ncu with the following command:
```shell
/usr/local/cuda-13.0/bin/ncu --target-processes all --set full --import-source yes -f --kernel-name matmul_kernel --launch-skip 3 --launch-count 1 -o ./tl_fp8mm_backend_H100 python accumulator_precision_test.py
```
From the ncu profiler readout shown below, we found that the FP8×FP8 tl.dot() for the chosen block size (M×N×K = 64×64×32) was translated into a QGMMA instruction, which is specific to the FP8 TensorCore. This confirms that the FP8 TensorCore was indeed used.

As mentioned earlier, the Triton compiler can sometimes choose a different implementation for tl.dot. For example, if we set num_warps = 2 in the kernel_config dictionary and repeat the experiment, Triton will pack FP8 into FP16 and use HMMA, an FP16-TensorCore-specific instruction, to perform the computation. In this case, the corresponding results show that the accumulator of the FP16 TensorCore is only 1 bit shorter than FP32.

Furthermore, since a specialized matmul unit is designed to handle inputs of certain fixed sizes, if the BLOCK_SIZE we choose exceeds what the TensorCore can handle, the compiler or CUDA library will automatically decompose the operation into several smaller operations. In our Triton code, we can increase BLOCK_SIZE_K to 128 and verify with ncu again. We will see that each WGMMA instruction only handles K = 32, which means an additional summation is needed to combine the partial results from multiple TensorCore calls. A natural question is: what precision is used for this intermediate summation? This is the same FP alignment and precision loss problem we have been discussing. In the K = 128 experiment, we still observe 13 effective mantissa bits. This provides an important insight: if we choose block sizes for the Triton kernel that exceed the TensorCore’s native size, whether for performance reasons or as a result of autotuning, the reduced-precision summation of partial results can cause additional precision loss. Therefore, if matmul precision is a critical concern (especially when training and backpropagation are involved), before falling back to FP16, we should first try an intermediate FP32 accumulation, as we did in the Triton code. We demonstrated the effect of BLOCK_SIZE_K on accuracy here, but readers should keep in mind that smaller blocks will impact kernel performance. Readers may want to start from a larger block size (e.g., 256 or 512 if suggested by autotuning), then gradually reduce it to 128, as used in [1], and weigh the trade-off between using FP16 and decreasing the block size. Alternatively, when using cuBLAS in a custom kernel, keeping the CUBLASLT_MATMUL_DESC_FAST_ACCUM flag disabled achieves the same effect of promoting the accumulation precision [4].
Finally, the concept of a reduced-precision accumulator also applies to INT8×INT8 engines. The main difference between FP8 and INT8 matmul is that INT8 accumulator truncation occurs on the most significant bits (MSBs) rather than the least significant bits (LSBs). In other words, we need to consider overflow rather than the loss of low-order bits seen in FP8. Simple modifications to the provided Triton kernel can be made to investigate INT8 behavior. We leave this exercise to interested readers.
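As a starting point for that exercise, a host-side emulation (not the Triton kernel) shows what MSB truncation looks like. The `acc_bits`-wide two's-complement accumulator is hypothetical, and we make no claim about any real engine's width. Sweeping `acc_bits` and comparing against the exact sum mirrors the FP8 sweep, except that it probes the top of the range instead of the bottom.

```python
from typing import Sequence


def wrap_signed(v: int, bits: int) -> int:
    """Keep only the low `bits` bits of v as a two's-complement value (the MSBs are lost)."""
    v &= (1 << bits) - 1
    return v - (1 << bits) if v >= 1 << (bits - 1) else v


def int8_dot(a: Sequence[int], b: Sequence[int], acc_bits: int) -> int:
    """INT8 x INT8 inner product with a hypothetical acc_bits-wide wrapping accumulator."""
    acc = 0
    for x, y in zip(a, b):
        acc = wrap_signed(acc + x * y, acc_bits)
    return acc


a = b = [127] * 64
exact = sum(x * y for x, y in zip(a, b))  # 64 * 127 * 127 = 1032256
print(int8_dot(a, b, 32) == exact)  # True
print(int8_dot(a, b, 16))           # -16320: the high bits wrapped away
```

For this worst-case input, the smallest accumulator width that reproduces the exact result is 21 bits.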
Conclusion
We explained the importance of reduced precision in the accumulator of a matmul engine and demonstrated a simple method to verify the design of an existing accelerator. Understanding accumulator precision is crucial for users who write custom kernels for accuracy-sensitive applications, as well as for hardware designers who need to emulate this behavior in their next-generation designs. More importantly, this Triton-kernel-based approach can be seamlessly combined with the PyTorch ecosystem, which means the same technique can be extended to other existing and future accelerators that support the Triton language, significantly reducing development time.
Reference
1. DeepSeek-V3 Technical Report, Section 3.3.2, Increasing Accumulation Precision. https://arxiv.org/html/2412.19437v1
2. SageAttention2, Introduction / Challenge / C2. https://arxiv.org/html/2411.10958v7
3. NVIDIA Nsight Compute (ncu) documentation. https://docs.nvidia.com/nsight-compute/index.html
4. NVIDIA cuBLAS documentation. https://docs.nvidia.com/cuda/cublas/
Runnable code can be found here
https://gist.github.com/chichun-charlie-liu/88a99949fcbe589aa5f71e48616ac944
Building Highly Efficient Inference System for Recommenders Using PyTorch
5 Feb 2026, 6:00 pm
Why Choose PyTorch for Recommendation Systems
PyTorch has emerged as the de facto framework in the AI community, with the majority of cutting-edge research, especially in areas like recommendation systems, retrieval, and ranking, being conducted with PyTorch. Developers are eager to bring the latest model advancements into production as quickly as possible. A PyTorch-based recommendation inference system is well-suited to this need, enabling both (1) high efficiency and (2) rapid model adoption in production environments.
In this blog, we will discuss the design of a high-performance recommendation inference system built with PyTorch. Approaches based on these design principles have been thoroughly validated and have successfully served extremely high volumes of traffic, demonstrating strong efficiency and reliability. Our PyTorch-based recommendation inference system serves as the backbone for Meta’s most critical machine learning workloads. Powering global surfaces, including Feed, Ads, Instagram, Reels, Stories, and Marketplace, the system manages a diverse array of ML architectures. These range from sophisticated extensions of the foundational Deep Learning Recommendation Model (DLRM) to cutting-edge, novel modeling techniques such as DHEN (Deep Hierarchical Ensemble Network), HSTU (Hierarchical Sequential Transducer Unit), Wukong, and more.

A Typical Recommendation Research to Production Inference Workflow
The Overall Workflow
After training, a model definition and its trained weights are delivered for inference, establishing a clear contract between the training and inference stages. However, running a training model directly in a production inference environment is highly inefficient and does not meet the performance requirements of real-world applications.
To address this, we need to rapidly and reliably ship trained models to production, while also supporting frequent updates as models are improved or retrained. This dynamic environment—with many models and many versions—demands a robust transformation pipeline that converts trained models into optimized inference models. Such a pipeline ensures that the resulting inference model files are tailored for efficient hardware utilization, enabling high throughput (QPS, i.e., queries per second) and meeting strict latency requirements. In summary, a dedicated system for transforming training models into production-ready inference models is essential for maintaining agility, scalability, and performance in our model deployment process.

Trained Model to Production Inference Transformation Flow
Defining the Inference Model and Weights Mapping
The trained model often includes components that are only necessary during training, such as loss functions and certain regularization techniques. It is best practice to define a dedicated inference model that mirrors the forward logic of the training model, while also allowing for inference-only optimizations. Additionally, a mapping between the inference model’s parameters and the trained model weights (checkpoint) must be established, especially if fully qualified parameter names differ between training and inference. This mapping should be maintained and updated throughout the inference model preparation process.
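One way to keep such a mapping explicit is a small, ordered set of rename rules plus a strict check against the inference model’s parameter names. Everything in this sketch is hypothetical: the rule patterns, the `loss_fn.` prefix, and the parameter names are illustrative. In PyTorch, the resulting dict would then be passed to the inference model’s `load_state_dict`.

```python
import re
from typing import Any, Dict, Iterable, List, Tuple

# Hypothetical rules: (regex on the training name, replacement), applied in order.
RENAME_RULES: List[Tuple[str, str]] = [
    (r"^model\.", ""),                     # strip a training-only wrapper prefix
    (r"(^|\.)dense_arch\.", r"\1dense."),  # submodule renamed in the inference model
]
TRAINING_ONLY: List[str] = [r"^loss_fn\."]  # parameters with no inference counterpart


def map_checkpoint(train_state: Dict[str, Any], inference_names: Iterable[str]) -> Dict[str, Any]:
    """Rename checkpoint entries to inference names and fail loudly on any mismatch."""
    mapped: Dict[str, Any] = {}
    for name, value in train_state.items():
        if any(re.match(p, name) for p in TRAINING_ONLY):
            continue  # drop training-only parameters
        new_name = name
        for pattern, repl in RENAME_RULES:
            new_name = re.sub(pattern, repl, new_name)
        mapped[new_name] = value
    expected, got = set(inference_names), set(mapped)
    missing, unexpected = expected - got, got - expected
    if missing or unexpected:
        raise KeyError(f"missing={sorted(missing)} unexpected={sorted(unexpected)}")
    return mapped
```

Failing on both missing and unexpected names catches a stale mapping as soon as either the training or the inference model changes.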
Capturing the Computation Graph from Python Models
To enable efficient inference, a series of model transformations must be applied to the inference model. Applying these optimizations requires converting PyTorch models defined in Python into a graph representation, and capturing a PyTorch model’s computation graph is a challenging task. A common practice is to use torch.fx to extract an FX graph. This method assumes that the model architecture does not contain cyclic structures. Submodules with complex control flow can be marked as leaf modules, so that the tracer records each one as a single call instead of tracing through it, which simplifies graph extraction.
Recently, torch.export has become a more mature tool for capturing computation graphs, with improved support for models with control flow. However, the resulting PT2IR (a specialized FX graph) can be quite low-level and decomposed, which may complicate certain model transformations.
Model Transformation and Optimization
After capturing the FX graph, a variety of optimizations can be applied through model transformation passes. Below are some common transformation passes:
- Model Splitting: For distributed inference scenarios, it is often necessary to split the full “forward” graph into smaller subgraphs. Each subgraph represents the forward pass of a submodule, enabling distributed execution across multiple devices or hosts. Additionally, these transformations can group similar computations together, further enhancing overall efficiency.
- Operator Fusion: Multiple operations can be replaced with a single, fused implementation to improve efficiency. This can be achieved by swapping submodules or applying graph-level transformations.
- Quantization: Similar to operator fusion, certain layers (e.g., linear layers) can be replaced with quantized versions to reduce memory usage and improve inference speed. TorchAO provides linear-layer quantization that works with PT2.
- Compilation (a.k.a. Lowering): Model compilation techniques are typically applied ahead of time as part of the transformation process. This step converts model code into lower-level representations that are better suited for the target inference devices. (See the AI Compiler section below for more details.)

Graph Transformation Example: Full Forward Graph to Split Graph
Model Serialization
Standard PyTorch models use the pickle format for storage, but this approach is insufficient for production due to weak backward compatibility and Python dependency issues. To address these challenges, several serialization solutions are available:
| Solution | Description | Pros | Cons |
| --- | --- | --- | --- |
| TorchScript | Capture TorchScript IR through scripting or tracing, and save in TorchScript format. | 1) Mature, with strong backward compatibility support; 2) Solid control flow support | 1) Some constraints on model definition (e.g., no complex data structures); 2) Deprecated and no longer supported |
| torch.export | Export the PyTorch model as PT2IR. | 1) The official way to serialize models in PT2; 2) Active development | 1) Control flow may need additional handling |
| torch.package | Directly export related Python modules as source code and pickle objects. | 1) Great flexibility | 1) May require manual effort to define module boundaries; 2) Requires Python dependency |
Regardless of the serialization format, the resulting artifact should be a zip file. This allows for easy inspection and debugging by unzipping the contents. Processed weights can also be packaged within the zip file. We are prioritizing torch.export for new model development over older tools like TorchScript and torch.package. With TorchScript now deprecated, torch.export provides a more robust path forward with active feature development, and it also delivers better performance than torch.package because it allows a Python-independent runtime.
Model Loading and Execution
Once the inference models are prepared, you will have a set of inference model files. For extremely large models, it may be necessary to load the model structure and weights separately, which could require custom logic for saving and loading.
After loading the model files, the runtime begins processing inference requests. Since PyTorch does not natively provide serving capabilities beyond model execution, an additional server layer is required to manage inference serving. Below, we outline the key features of an efficient and scalable PyTorch inference server for recommendation systems:
Lightweight PyTorch Executor Wrapper
- The server converts requests to PyTorch model inputs. This wrapper should be minimal to ensure efficiency.
Efficient and Flexible API
- In a distributed environment, different components of the model communicate via APIs, which necessitates precise semantic definitions, such as specifying the batch dimension and other relevant parameters.
- Tensor-based APIs align well with the PyTorch model’s forward method.
- Zero-copy (in-place) APIs allow us to update models in-place, efficiently and seamlessly transitioning from serving one version of a model to the next without requiring significant additional capacity to load both model versions during the transition.
DAG Representation and Executor
- Modules with similar characteristics (e.g., all embedding bags) can be grouped into dedicated submodules for batch execution.
- After model splitting, the original forward function is represented as a Directed Acyclic Graph (DAG), with each node corresponding to a submodule. An executor is required to manage the execution of this DAG.
- DAG nodes may be deployed across multiple hosts, which necessitates support for remote execution. In such cases, an efficient communication library is essential to ensure seamless and performant interactions between distributed components.
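The DAG executor described above can be sketched with Python's standard `graphlib.TopologicalSorter`. This is a minimal, single-process illustration under stated assumptions. The node names and toy computations are hypothetical stand-ins for real submodules, and each local call marks the place where a production executor might issue a remote call to another host.

```python
from graphlib import TopologicalSorter
from typing import Callable


def run_dag(
    nodes: dict[str, Callable[..., object]],
    deps: dict[str, list[str]],
    inputs: dict[str, object],
) -> dict[str, object]:
    """Run each submodule once all of its inputs are available.

    deps maps a node to the nodes (or graph inputs) it consumes, in argument order.
    """
    results: dict[str, object] = dict(inputs)
    for name in TopologicalSorter(deps).static_order():
        if name in results:  # graph inputs are already materialized
            continue
        args = [results[d] for d in deps.get(name, [])]
        results[name] = nodes[name](*args)  # local call; could be an RPC to another host
    return results


# Hypothetical split of a recommendation model into three submodules.
nodes = {
    "embeddings": lambda ids: [i * 10 for i in ids],   # grouped embedding bags
    "dense": lambda feats: [x * 2 for x in feats],     # dense arch
    "over": lambda emb, dense: sum(emb) + sum(dense),  # interaction / over arch
}
deps = {
    "embeddings": ["sparse_ids"],
    "dense": ["dense_features"],
    "over": ["embeddings", "dense"],
}
out = run_dag(nodes, deps, {"sparse_ids": [1, 2, 3], "dense_features": [1, 2]})
print(out["over"])  # 66
```

A real executor would also run independent nodes (here `embeddings` and `dense`) concurrently. One way to do this is to drive `TopologicalSorter`'s `get_ready()`/`done()` loop instead of `static_order()`.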
Optimizations
In the previous section, we outlined the core principles for building a robust, efficient, and scalable recommendation inference system with PyTorch, one that can handle high traffic volumes and meet stringent production requirements. To further enhance system performance, we now discuss several key optimization strategies.
GPU (Accelerator) Inference
With the emergence of new model architectures, computational demands have increased significantly. CPUs often struggle to meet the latency requirements for running such models online, making accelerators like GPUs a natural choice. However, running the entire model on a single GPU can be inefficient, and models may not fit within the memory constraints of a single device. Therefore, splitting models into multiple segments and executing the most compute-intensive layers on GPUs is a practical approach.
Additionally, GPU kernel launch overhead can be substantial. To mitigate this, batching requests together reduces the number of kernel launches and improves overall throughput.
C++ Runtime
While the most straightforward way to run PyTorch models is via Python, the Python runtime introduces noticeable overhead, especially as QPS (queries per second) increases. Typically, Python overhead becomes significant at QPS ≥ 100, and can become a severe bottleneck at QPS ≥ 1000.
For high-QPS scenarios (≥ 100 per host), we recommend using a C++ (or Rust) runtime. Both TorchScript (for TorchScript models) and ExecuTorch (for models saved with torch.export) provide C++ runtimes. Recently, development has focused on a new runtime, torch.nativert, designed for executing torch.export models across servers, as an alternative to the TorchScript runtime, which has been deprecated as of the last PyTorch Conference.
Distributed Inference (DI)
Running the entire inference model as a monolith can be inefficient or even infeasible. Instead, splitting the model into multiple components and distributing them across different workers can both improve efficiency and enable scaling to larger model sizes. Common DI patterns include:
- CPU-GPU DI: Assign input processing and lightweight computations to CPUs, while offloading compute-heavy layers of the model to GPUs.
- Embedding-Dense DI: Group embedding tables into dedicated submodules that can be served on separate hosts (similar to traditional parameter servers). Dense layers, which are smaller but compute-intensive, can be grouped and executed together for improved efficiency.
- Dense Model Parallelism: Split a single dense network into multiple sub-networks that can be executed in parallel, either on different CUDA streams within the same device or across multiple devices, enabling selective lowering and parallel execution.
AI Compiler and High-Performance Kernel Libraries
To achieve maximum performance, developers may be tempted to rewrite model definitions in C++/CUDA and run them directly. However, this approach does not scale well. Instead, AI compilers can automate this process, generating highly optimized artifacts. Options include:
- AOTInductor (torch.compile)
- AITemplate
- TensorRT
These compilers generate new, compiled artifacts that are packaged alongside the serialized model. For production RecSys deployments, C++ runtimes are preferred for performance reasons. This precludes the use of Python-dependent JIT workflows like torch.compile; instead, Ahead-of-Time (AOT) Inductor is used to precompile models into static runtime artifacts deployable in C++.
AI compilers utilize high-performance kernel libraries to maximize computational efficiency on various hardware platforms, including:
- CUTLASS/CuTeDSL for NVIDIA GPUs
- Composable Kernel / AITER for AMD GPUs
- ZenDNN for AMD CPUs
- oneDNN for Intel CPUs
Request Coalescing
To maximize efficiency, requests should be coalesced (batched) together. This requires understanding the semantics of each input, particularly which dimension represents the dynamic batch size, so that requests can be concatenated appropriately. The model’s forward method should be tagged with batch information to facilitate coalescing, and the runtime must support this feature.
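Here is a minimal sketch of request coalescing, assuming that dimension 0 is the batch dimension and that the model treats rows independently along it, as typical per-example inference does. Plain Python lists stand in for tensors, and `toy_model` is a hypothetical placeholder. Several requests are concatenated, the model runs once, and the outputs are split back per request.

```python
from typing import Callable

Batch = list[list[float]]  # rows along the dynamic batch dimension (dim 0)


def coalesce_and_run(requests: list[Batch], model: Callable[[Batch], Batch]) -> list[Batch]:
    """Concatenate requests along the batch dimension, run once, split results back."""
    sizes = [len(r) for r in requests]
    merged = [row for r in requests for row in r]  # concat along dim 0
    outputs = model(merged)  # one model invocation instead of len(requests)
    if len(outputs) != len(merged):
        raise ValueError("model must preserve the batch dimension")
    results, start = [], 0
    for n in sizes:
        results.append(outputs[start:start + n])  # slice this request's rows back out
        start += n
    return results


calls = []


def toy_model(batch: Batch) -> Batch:
    calls.append(len(batch))  # record how many rows each invocation sees
    return [[sum(row)] for row in batch]


out = coalesce_and_run([[[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]]], toy_model)
print(out, calls)  # [[[3.0]], [[7.0], [11.0]]] [3]
```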
Table Batched Embedding
Querying embedding tables in PyTorch can incur significant operator kernel launch overhead, especially when dealing with tens, hundreds, or even thousands of tables. Since embedding lookups are data-transfer-heavy (akin to hash map queries), batching embedding bags and querying all tables in a single call can greatly reduce overhead and improve data transfer efficiency.
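The table-batching idea can be illustrated in pure Python. All tables are concatenated into one buffer with per-table row offsets, and every bag of every table is then resolved in a single call with sum pooling. This is a conceptual sketch only. It is not the FBGEMM table-batched embedding API, and real implementations perform this in one fused GPU kernel rather than in Python loops.

```python
def build_tbe(tables: list[list[list[float]]]) -> tuple[list[list[float]], list[int]]:
    """Concatenate embedding tables into one buffer and record each table's row offset."""
    weights: list[list[float]] = []
    table_offsets: list[int] = []
    for table in tables:
        table_offsets.append(len(weights))
        weights.extend(table)
    return weights, table_offsets


def tbe_lookup(weights: list[list[float]], table_offsets: list[int],
               bags: list[list[list[int]]]) -> list[list[list[float]]]:
    """One call for all tables: bags[t][b] holds the local row ids of bag b in table t.

    Returns pooled[t][b], the sum-pooled embedding of that bag.
    """
    dim = len(weights[0])
    pooled = []
    for t, table_bags in enumerate(bags):
        per_table = []
        for bag in table_bags:
            acc = [0.0] * dim
            for row in bag:
                vec = weights[table_offsets[t] + row]  # global row = table offset + local row
                acc = [a + v for a, v in zip(acc, vec)]
            per_table.append(acc)
        pooled.append(per_table)
    return pooled


weights, offsets = build_tbe([
    [[1.0, 0.0], [0.0, 1.0]],              # table 0: 2 rows
    [[2.0, 2.0], [3.0, 3.0], [4.0, 4.0]],  # table 1: 3 rows
])
print(offsets)  # [0, 2]
print(tbe_lookup(weights, offsets, [[[0, 1], [1]], [[2], [0, 2]]]))
# [[[1.0, 1.0], [0.0, 1.0]], [[4.0, 4.0], [6.0, 6.0]]]
```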
Quantization
Both embedding and dense layers of the model can benefit significantly from quantization:
- Embeddings: Data types like bf16 and int8 are generally safe, and int4 is often acceptable. Different tables and rows may have varying numerical sensitivities. PyTorch supports per-table quantization, even for table-batched embeddings, allowing developers to customize quantization strategies. Some tables may even use int1 or int2 configurations.
- Dense Layers: Dense layers are more sensitive to quantization. Typically, fp16 and bf16 are acceptable for entire dense submodules, but there are exceptions: fp16 may lack sufficient range, and bf16 may lack sufficient precision. For further efficiency, fp8 and fp4 can be applied at the layer level, though this often requires manual tuning.
All quantization strategies should be validated through accuracy evaluation. TorchAO provides quantization support for Linear and Conv layers and is a good place to start.
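As an illustration of embedding quantization plus the validation step, here is a pure-Python sketch of one common scheme: asymmetric row-wise 8-bit quantization with a per-row scale and bias. We are assuming this scheme for illustration only. TorchAO and FBGEMM provide production implementations, and their exact formats may differ. The check at the end confirms that the reconstruction error stays within half a quantization step.

```python
def quantize_row_8bit(row: list[float]) -> tuple[list[int], float, float]:
    """Asymmetric row-wise 8-bit quantization: codes in [0, 255] plus a scale and bias."""
    lo, hi = min(row), max(row)
    scale = (hi - lo) / 255.0 or 1.0  # constant rows: any non-zero scale works
    codes = [round((x - lo) / scale) for x in row]
    return codes, scale, lo


def dequantize_row(codes: list[int], scale: float, bias: float) -> list[float]:
    return [c * scale + bias for c in codes]


def max_abs_error(row: list[float]) -> float:
    """Validation step: worst-case reconstruction error for one row."""
    codes, scale, bias = quantize_row_8bit(row)
    return max(abs(a - b) for a, b in zip(row, dequantize_row(codes, scale, bias)))


row = [-1.0, 0.0, 0.5, 2.0]
codes, scale, bias = quantize_row_8bit(row)
print(codes[0], codes[-1])  # 0 255
print(max_abs_error(row) <= scale / 2 + 1e-9)  # True: error is at most half a step
```

Because sensitivity varies per table and per row, you would run the same kind of check per table (and, in practice, an end-to-end accuracy evaluation) before choosing int8, int4, or lower bit widths.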
Delta Update
Model freshness is critical for serving recommendation models. As models grow larger, loading the entire model becomes increasingly expensive. A balanced approach is to apply partial weight updates (delta updates). While implementing a protocol for data transfer is straightforward, tuning the weight loading pace is crucial to avoid disrupting serving. Embedding tables are generally more tolerant of partial updates, while dense modules are more sensitive. For dense modules, we recommend using a buffer module to support full module swaps, rather than updating weights individually.
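Here is a minimal sketch of the two update paths described above: paced, row-level delta updates for embeddings, and a whole-module swap for dense weights. The `DeltaUpdatableModel` class and its methods are hypothetical and are not a PyTorch API. Plain dicts and lists stand in for tensors.

```python
import threading


class DeltaUpdatableModel:
    """Hypothetical serving-side model holder (not a PyTorch API).

    Embedding rows are patched in place in small, paced chunks; the dense
    module is replaced as a whole by swapping in a fully built new buffer.
    """

    def __init__(self, embeddings: dict[int, list[float]], dense: dict[str, list[float]]):
        self.embeddings = embeddings
        self._dense = dense
        self._lock = threading.Lock()

    def apply_embedding_delta(self, delta: dict[int, list[float]], rows_per_step: int) -> int:
        """Apply row updates in chunks; returns the number of chunks applied."""
        items = list(delta.items())
        steps = 0
        for i in range(0, len(items), rows_per_step):
            for row_id, vec in items[i:i + rows_per_step]:
                self.embeddings[row_id] = vec
            steps += 1
            # A real system would throttle here (e.g. sleep or yield) to pace loading.
        return steps

    def swap_dense(self, new_dense: dict[str, list[float]]) -> None:
        """Atomically replace the whole dense module with a fully loaded new one."""
        with self._lock:
            self._dense = new_dense

    def dense(self) -> dict[str, list[float]]:
        with self._lock:
            return self._dense


model = DeltaUpdatableModel({0: [0.0], 1: [0.0], 2: [0.0]}, {"w": [1.0]})
steps = model.apply_embedding_delta({0: [1.0], 2: [2.0]}, rows_per_step=1)
model.swap_dense({"w": [5.0]})
print(steps, model.embeddings[2], model.dense()["w"])  # 2 [2.0] [5.0]
```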
Developer Experience
Python Runtime
To streamline the development and debugging of the inference flow, we recommend providing a lightweight Python runtime environment (versus using the C++ runtime). This approach allows developers to efficiently determine whether issues originate from the runtime or the model itself. Additionally, it simplifies the process of adding instrumentation for debugging purposes.
With the introduction of free-threaded Python, both runtime and communication overhead can be further minimized within the Python ecosystem. This advancement also makes deploying Python runtimes in production environments increasingly practical.
Module Swap-Based Transformations
Historically, graph-based transformations have been challenging for model authors to understand and debug, largely due to the complexity of graph manipulations and the loss of original stack trace information. To address these issues, we recommend shifting such optimizations earlier in the inference module authoring process. By adopting a holistic, native PyTorch module-based workflow, and leveraging eager mode transformations, we have found that the inference development experience is significantly improved.
Eval Flow
To ensure both model and runtime quality, we recommend implementing the following two evaluation flows:
- Accuracy Verification: Compare the inference model’s quality against training evaluation results.
- Performance Benchmarking: Replay production-like traffic to assess throughput and latency.
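Here is a minimal sketch of both evaluation flows using only the standard library. The tolerance values are assumptions, not recommended settings. Both helpers are deliberately simple: a real traffic replay would be concurrent and driven by recorded production requests, not a sequential loop.

```python
import math
import statistics
import time
from typing import Callable, Sequence


def verify_accuracy(train_preds: Sequence[float], infer_preds: Sequence[float],
                    rel_tol: float = 1e-3, abs_tol: float = 1e-5) -> list[int]:
    """Return the indices where inference predictions diverge from training eval."""
    return [i for i, (a, b) in enumerate(zip(train_preds, infer_preds))
            if not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)]


def benchmark(fn: Callable[[object], object], requests: Sequence[object]) -> dict[str, float]:
    """Replay requests sequentially; report p50/p99 latency (ms) and throughput (QPS)."""
    latencies = []
    start = time.perf_counter()
    for req in requests:
        t0 = time.perf_counter()
        fn(req)
        latencies.append((time.perf_counter() - t0) * 1e3)
    elapsed = max(time.perf_counter() - start, 1e-9)
    cuts = statistics.quantiles(latencies, n=100)  # 99 percentile cut points
    return {"p50_ms": cuts[49], "p99_ms": cuts[98], "qps": len(requests) / elapsed}


mismatches = verify_accuracy([0.1, 0.5, 0.9], [0.1, 0.5004, 0.7])
print(mismatches)  # [2]
stats = benchmark(lambda req: sum(req), [[1, 2, 3]] * 200)
print(stats)
```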
Conclusion
At Meta, we developed a highly efficient recommendation inference system built on PyTorch that is critical for translating cutting-edge research into production-grade services. This blog detailed a robust workflow, starting from a trained model definition and its weights, progressing through essential inference transformation steps, including graph capture, model splitting, optimizations (fusion, quantization, compilation, etc.), and finally serialization. We then outlined the requirements for a high-performance inference server, emphasizing a lightweight executor, flexible tensor-based APIs, and a DAG-based execution model. Finally, we explored advanced optimization techniques crucial for high-QPS, low-latency performance, such as leveraging GPU/Accelerator inference, adopting a C++ runtime, implementing Distributed Inference patterns, utilizing AI compilers, and applying sophisticated methods like request coalescing, Table Batched Embeddings, and quantization. By adhering to these principles and utilizing the featured open-source libraries, developers can build scalable, performant, and agile PyTorch-based systems capable of serving the world’s most demanding ML recommendation workloads.
Related Libraries
TorchRec: A PyTorch domain library that powers Meta’s production recommender systems by providing the sparsity and parallelism primitives necessary to train and deploy models with massive embedding tables sharded across multiple GPUs.
TorchAO: TorchAO is an easy-to-use quantization library for native PyTorch. TorchAO works out of the box with torch.compile() and FSDP2 across most Hugging Face PyTorch models.
AITemplate: An open-source Python framework that transforms deep neural networks into highly optimized C++ code for NVIDIA and AMD GPUs, delivering near-roofline inference performance through unified hardware support and comprehensive operator fusion.
TensorRT: NVIDIA TensorRT is a developer ecosystem comprising inference compilers, runtimes, and model optimizations designed to deliver high-performance, low-latency deep learning inference for production applications.
Generative Recommenders / HSTU: A library that reformulates classical recommendation systems as generative models and introduces algorithms like HSTU and M-FALCON to drastically accelerate training and inference, while establishing scaling laws for billion-user-scale environments.
FBGEMM: Highly-optimized kernels used across deep learning applications, including recommendation systems.
Triton and Low-Level Extension (TLX): Triton is a Python-based language and compiler designed for writing highly efficient GPU kernels. TLX (Triton Low-Level Extensions) is an experimental add-on that provides fine-grained, hardware-specific control within Triton, enabling developers to further optimize performance on modern GPUs.
oneDNN: oneAPI Deep Neural Network Library is an open-source, cross-platform performance library of basic building blocks for deep learning applications, specifically optimized for Intel processors.
ZenDNN: ZenDNN (Zen Deep Neural Network) Library accelerates deep learning inference applications on AMD CPUs.
CUTLASS / CuTeDSL: CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. CuTeDSL is a Python-based embedded domain-specific language for CUTLASS.
AITER: AITER is AMD’s centralized repository of high-performance AI operators for accelerating AI workloads. It serves as a unified place to address customers’ operator-level requests and their differing needs.
CK: The Composable Kernel (CK) library provides a programming model for writing performance-critical kernels for machine learning workloads across multiple architectures (GPUs, CPUs, etc.). The CK library uses general purpose kernel languages, such as HIP C++.
Portable Paged Attention in Helion
3 Feb 2026, 5:32 pm
Recently, the PyTorch team released Helion, a new domain-specific, PyTorch-based language that makes it easier to develop high-performing yet portable kernels. With extensive autotuning built in, Helion promises to push performance portability further than Triton.
To test this promise (and learn Helion), we embarked on the challenge to write one of AI’s most performance-critical kernels in Helion: Paged Attention, the core of vLLM.
Over the past year, we contributed a performance- and platform-portable attention backend for vLLM written entirely in Triton, which has no external dependencies and runs on NVIDIA, AMD, and Intel GPUs (see our PyTorch conference talk). Building on this, we implemented one of its kernels (unified_attention_2d) in Helion as a new experimental backend in vLLM (PR#27293).
Brief Background to vLLM, Triton, and Helion
vLLM is widely used for LLM inference and is part of the PyTorch Foundation. It is increasingly adopted in production and runs on NVIDIA, AMD, and Intel GPUs, as well as on custom accelerators like Google’s TPU, Huawei’s Ascend NPU, AWS Inferentia, or IBM Spyre. vLLM delivers efficient, high-performance inference for nearly all LLMs, thanks to its well-designed software architecture and deep integration with torch.compile.
Triton is a domain-specific language (DSL) that is written in Python and offers just-in-time (JIT) compilation for AMD, Intel, and NVIDIA GPUs. Triton kernels have been shown to achieve state-of-the-art performance and can be portable. For example, we have written paged attention in Triton, and the very same kernel code achieves state-of-the-art performance on NVIDIA H100 and AMD MI300 (you can read our extensive paper or the related blog post). For this, we also leveraged Triton’s autotuner in a limited way. However, despite its positive impact on performance portability, autotuning in Triton has severe limitations that prohibit its use in production. Hence, for our Triton attention backend, we use simple if-else statements as heuristics for now.
Besides this, Triton is also the output language of PyTorch Inductor, the compile component of torch.compile.
Helion is yet another DSL, which became beta at the end of October. Helion describes itself as “tiled PyTorch” and has two broad aims: first, to bring tiling to PyTorch so that tiled programs can be written using PyTorch APIs; and second, to enhance portability through extensive autotuning. In contrast to Triton, Helion’s autotuner not only has a usable caching mechanism but also has many more degrees of freedom. This comes from the fact that Helion’s autotuner can also change algorithmic aspects of an implementation, in addition to lower-level compile flags like the number of warps or pipeline depth. It also features advanced search algorithms, which is something we previously investigated in the context of Triton.
Implementation Details: How to write Paged Attention in Tiled PyTorch
Launch Grid and approach of parallelization
As a starting point, we wanted to re-implement the simpler “2D” version of our unified attention Triton kernels. It is called “2D” because this kernel has a two-dimensional launch grid (see details here). We selected this version because we expected a parallel tiled softmax implementation to be too complex to start with.
However, since launch grids are handled differently in Helion than in Triton, we did not follow the 2D approach 1:1. Instead, we built a Helion kernel around the core concept of “Q blocks”, which is illustrated in the following Figure:

Figure 1: Concept of “Q blocks” in our Helion kernel.
In this Figure, we see the three dimensions of one request that need to be computed. An attention kernel needs to iterate over all query tokens up to the query_length (bottom axis). In our kernel, we fetch multiple query tokens at the same time, and this tile size, TILE_Q, is tunable. Next, each token has multiple query heads and KV heads (left axis). We re-implemented our GQA optimization so that all query heads belonging to one KV head are fetched at once. The number of query heads per KV head (QpKV) is the tile size in this direction and is called TILE_M. Finally, we have to iterate over the KV cache for this query up to the current context length, in tunable blocks of size TILE_N (diagonal axis). This inner loop is where the actual attention computation happens, including the matrix multiplications (hl.dot), using an online softmax implementation. In the kernel, an additional loop around all of this iterates over all requests in a batch (not shown in the Figure).
However, in the input tensor as vLLM handles it, the first dimension holds the query tokens of all sequences concatenated (often called the “flattened varlen” layout). Consequently, vLLM provides an extra tensor that serves as an index to determine which token belongs to which sequence.
Hence, after experimenting with some implementations, we settled on the four-loop approach described above:
```python
# registering tunable block sizes
q_block_size = hl.register_block_size(1, q_block_padded_size)
num_pages_at_once = hl.register_block_size(1, 32)

# outer loop -> becomes the launch grid
for seq_tile, tile_m in hl.tile(
    [num_seqs, num_query_heads],
    block_size=[1, num_queries_per_kv],
):
    seq_len = t_seq_lens[seq_tile]
    query_start = t_query_start_lens[seq_tile]
    query_end = t_query_start_lens[seq_tile + 1]
    query_len = query_end - query_start

    # loop over the query of one request
    for tile_q in hl.tile(query_len, block_size=q_block_size):
        ...
        # loop over KV cache
        for tile_n in hl.tile(num_blocks, block_size=num_pages_at_once):
            ...
```
As can be seen, the outer loop is a fused loop with two dimensions: the sequences in a batch and the QpKV. This outer loop becomes the launch grid in Triton (Helion recommends using hl.tile over hl.grid, also for the outer loop). Since we also need the tuned block sizes before the loops, e.g., for boundary computations, we register them explicitly beforehand. Additionally, to make the launch grid simpler, we changed the order of the loops in the implementation compared to the description above and let the outer loop iterate over the query heads.
The second loop then iterates over the query length of the selected sequence, with a tunable tile size. Note that we pad the upper bound of this tile size (see q_block_padded_size), so that neither the JIT compiler nor the autotuner is triggered for every possible query length. Instead, we only provide lengths padded to a power of two, which reduces the JIT and autotuning overhead at runtime. The innermost loop iterates over the KV cache pages of the selected sequence. Hence, the upper bound of the corresponding registered block size (32) refers to 32 pages of KV cache memory (each holding, e.g., 16 tokens).
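To see why padding bounds the specialization overhead, here is a small sketch. It assumes that "padded to a power of two" means rounding the upper bound up to the next power of two. The helper name is ours and is not taken from vLLM or Helion. For query lengths 1 to 512, only 10 distinct padded bounds, and therefore at most 10 compile/autotune keys, remain.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


# Each distinct padded bound is a distinct specialization (JIT compile / autotune key).
padded = {next_power_of_2(q) for q in range(1, 513)}
print(sorted(padded))  # [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
print(len(padded), "variants instead of", 512)
```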
The Triton code that Helion generates from the loops above could look like this:
```python
# src[helion_unified_attention.py:129]: for seq_tile, tile_m in hl.tile(
# src[helion_unified_attention.py:130]:     [num_seqs, num_query_heads],
# src[helion_unified_attention.py:131]:     block_size=[1, num_queries_per_kv],
# src[helion_unified_attention.py:129-132]: ...
num_pid_m = num_seqs
num_pid_n = tl.cdiv(32, _BLOCK_SIZE_3)
inner_2d_pid = tl.program_id(0)
num_pid_in_group = 4 * num_pid_n
group_id = inner_2d_pid // num_pid_in_group
first_pid_m = group_id * 4
group_size_m = min(num_pid_m - first_pid_m, 4)
pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m
pid_1 = inner_2d_pid % num_pid_in_group // group_size_m
offset_2 = pid_0
offset_3 = pid_1 * _BLOCK_SIZE_3
...
# src[helion_unified_attention.py:141]: for tile_q in hl.tile(query_len, block_size=q_block_size):
# src[helion_unified_attention.py:141-252]: ...
for offset_9 in tl.range(0, v_0.to(tl.int64), _BLOCK_SIZE_0, loop_unroll_factor=2,
                         num_stages=2, disallow_acc_multi_buffer=False, flatten=False):
    indices_9 = offset_9 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int64)
    # src[helion_unified_attention.py:174]: for tile_n in hl.tile(num_blocks, block_size=num_pages_at_once):
    # src[helion_unified_attention.py:174-244]: ...
    for offset_10 in tl.range(0, v_19.to(tl.int64), _BLOCK_SIZE_1, loop_unroll_factor=1,
                              num_stages=1, disallow_acc_multi_buffer=True):
        indices_10 = offset_10 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int64)
        mask_1 = indices_10 < v_19
```
As can be seen, this uses the pid_type="flat" variant of program launch, as selected by the Helion autotuner. With this program type, the kernel has only one “real” program ID (tl.program_id(0)) and derives all other “local” IDs from it.
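To check how a single flat program ID maps back to a (sequence, head tile) pair, the index arithmetic from the generated code can be mirrored in plain Python. This sketch makes a few assumptions. `_BLOCK_SIZE_3` is set to 4 (num_queries_per_kv for Llama3.1-8B: 32 query heads over 8 KV heads). The grouping factor 4 is hard-coded as in the snippet. The grid is num_seqs * cdiv(32, 4) programs. Interpreting the grouping as an L2-locality ordering is our reading.

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def decompose_pid(pid: int, num_seqs: int, num_query_heads: int = 32,
                  heads_per_tile: int = 4, group: int = 4) -> tuple[int, int]:
    """Python mirror of the index arithmetic in the generated kernel.

    Returns (pid_0, pid_1): the sequence index and the query-head tile index.
    """
    num_pid_m = num_seqs
    num_pid_n = cdiv(num_query_heads, heads_per_tile)
    num_pid_in_group = group * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group
    group_size_m = min(num_pid_m - first_pid_m, group)
    pid_0 = first_pid_m + pid % num_pid_in_group % group_size_m
    pid_1 = pid % num_pid_in_group // group_size_m
    return pid_0, pid_1


num_seqs = 6
grid = num_seqs * cdiv(32, 4)  # one program per (sequence, head tile): 48
mapping = [decompose_pid(p, num_seqs) for p in range(grid)]
print(mapping[:5])  # [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1)]
```

Consecutive program IDs first walk over up to four sequences for the same head tile. The last, partial group (sequences 4 and 5 here) is still covered exactly once.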
Tiling in Helion
In general, Helion requires tiles to be generic, i.e., we cannot assume that they always correspond to the block size of vLLM’s KV cache, and we have to write our program accordingly.
Tiling in Helion is quite powerful: the tiles generated by hl.tile automatically handle incomplete tiles. For example, with a tile size of 32 and a tensor dimension of 63, the second tile has only 31 elements, and all masks are generated automatically.
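The following pure-Python emulation shows what automatic handling of incomplete tiles amounts to. It is a conceptual sketch, not Helion's implementation. Each tile carries a mask that disables lanes past the end of the dimension.

```python
def tiles(extent: int, tile_size: int):
    """Yield (begin, indices, mask) per tile; the last tile may be partial and masked."""
    for begin in range(0, extent, tile_size):
        indices = [begin + i for i in range(tile_size)]
        mask = [idx < extent for idx in indices]  # False for out-of-bounds lanes
        yield begin, indices, mask


result = list(tiles(63, 32))
print([(begin, sum(mask)) for begin, _, mask in result])  # [(0, 32), (32, 31)]
```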
However, in paged attention there are additional constraints: We always have to load full pages from the KV cache, since we cannot determine at compile-time if we need the complete page or not. This is not a problem for Helion, but it meant that we needed to create our own masks.
The dynamism of tiles also created another problem for us: Imagine a batch that has queries of length 7, 2, 1 as shown in the figure below:

Figure 2: Accessing the flattened “varlen” tensor with fixed tile sizes requires manual masking.
So, if we looped through this batch with a fixed tile size (4 in this example), the second tile would contain the first token of the second request together with the last three tokens of the first request! Likewise, the next tile would mix tokens from the second and third requests. Yet we cannot change the tile size for every sequence in a batch, since each compiled kernel can have only one tile size (remember, it is a compile-time constant). One solution would have been to use a block size of 1 here, but as we know from developing the Triton attention backend in vLLM, this generally performs very badly.
Hence, the only remaining option was to adjust the tile index and apply manual masking with hl.load:
```python
adjusted_tile_q_index = query_start + tile_q.begin + hl.arange(q_block_size)
query_head_offset = tile_m.begin + hl.arange(num_queries_per_kv)
q_load_mask = adjusted_tile_q_index[:, None, None] < query_end
# (tile_q, tile_m, HEAD_SIZE)
q = hl.load(
    t_query,
    [adjusted_tile_q_index, query_head_offset, hl.arange(head_size)],
    extra_mask=q_load_mask,
)
```
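To see why this indexing avoids mixing requests, the next sketch emulates the adjusted tile index and load mask in plain Python for the batch from Figure 2 (query lengths 7, 2, 1; tile size 4). It mirrors the logic of the Helion snippet but is not Helion code. The `starts` list plays the role of t_query_start_lens.

```python
from itertools import accumulate


def varlen_q_tiles(query_lens: list[int], tile_q: int):
    """Emulate the adjusted tile index and load mask for a flattened varlen batch.

    Tiles restart at each sequence's query_start, so no tile mixes two requests;
    rows at or beyond query_end are masked off (never loaded).
    """
    starts = [0, *accumulate(query_lens)]  # like t_query_start_lens
    out = []
    for seq, query_len in enumerate(query_lens):
        query_start, query_end = starts[seq], starts[seq + 1]
        for tile_begin in range(0, query_len, tile_q):
            idx = [query_start + tile_begin + i for i in range(tile_q)]
            mask = [i < query_end for i in idx]
            out.append((seq, idx, mask))
    return out


for seq, idx, mask in varlen_q_tiles([7, 2, 1], tile_q=4):
    print(seq, idx, mask)
# 0 [0, 1, 2, 3] [True, True, True, True]
# 0 [4, 5, 6, 7] [True, True, True, False]
# 1 [7, 8, 9, 10] [True, True, False, False]
# 2 [9, 10, 11, 12] [True, False, False, False]
```

Some masked indices (10, 11, 12) even lie beyond the 10 flattened tokens. This is exactly why the extra mask is required on the load.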
Overall, this kernel implementation in Helion requires 133 lines of code (vLLM formatting with comments) vs. 295 in Triton. Check it out here.
For us, writing the Helion kernel was very straightforward compared to writing the Triton version, despite the algorithmic changes we had to make due to Helion’s programming model. In Triton, a lot of time and effort (and lines of code) are spent to make sure all tiles have the right masks and boundaries, which requires manually taking care of tensor strides and offsets in all dimensions. Helion does this automatically and also correctly, which saves a lot of effort for the developer (esp. debugging!).
In addition, Helion handles other low-level details like actual launch grid implementation, tile size discovery, or tensor memory allocation via autotuning.
Autotuning
One of Helion’s core strengths is autotuning. Not only can it search across a variety of tuning knobs, it also discovers all valid values of these knobs by itself. The user only needs to define the lower and upper bounds of tiles or block sizes. Alternatively, the block size can be left entirely to the autotuner:
```python
for tile_n in hl.tile(seqlen // page_size, block_size=None):
    ...
```
This is in strong contrast to Triton, where a user needs to list all valid configurations explicitly (usually with a lot of nested expressions like num_stages=s for s in [1, 2, 3, 4, 5, 6, 8, 42] …). Besides being a less comfortable API, this also makes it easy for a user to miss a combination that would perform very well on one platform. Helion avoids this problem by requiring the user to define the shapes only at a symbolic level and then deriving all possible combinations itself.
Except for the outermost hl.tile loop, which becomes the launch grid, the tile boundaries can be derived from the tensor shapes.
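To illustrate the difference between deriving a search space and writing one by hand, here is a small sketch that expands symbolic block-size bounds into candidate configurations. The warp and stage choices and the assumed q_block_padded_size of 16 are hypothetical. Helion's real search space contains many more knobs (indexing, loop orders, and so on), and Helion explores it with search algorithms rather than exhaustive enumeration.

```python
from itertools import product


def pow2_range(lo: int, hi: int) -> list[int]:
    """All powers of two within [lo, hi]."""
    vals, v = [], 1
    while v <= hi:
        if v >= lo:
            vals.append(v)
        v *= 2
    return vals


def search_space(block_bounds: dict[str, tuple[int, int]],
                 num_warps=(4, 8), num_stages=(1, 2, 3)):
    """Derive every candidate config from symbolic bounds instead of a hand-written list."""
    names = list(block_bounds)
    block_choices = [pow2_range(lo, hi) for lo, hi in block_bounds.values()]
    for *blocks, warps, stages in product(*block_choices, num_warps, num_stages):
        yield {**dict(zip(names, blocks)), "num_warps": warps, "num_stages": stages}


# Bounds mirroring the register_block_size calls above; 16 is an assumed q_block_padded_size.
configs = list(search_space({"q_block_size": (1, 16), "num_pages_at_once": (1, 32)}))
print(len(configs))  # 5 * 6 * 2 * 3 = 180
print(configs[0])  # {'q_block_size': 1, 'num_pages_at_once': 1, 'num_warps': 4, 'num_stages': 1}
```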
Helion also checks the correctness of each kernel variant created during autotuning against a default configuration, and discards all variants whose numerical errors are too high. At the beginning of our experiments, this baseline (called the “default configuration” in Helion) was derived automatically from somewhere “in the middle” of the discovered search space. However, this auto-discovery created problems for us, because it produced a configuration that was invalid for our kernel and would not work with vLLM’s page size of 16. As a result, autotuning did not work at all, and we had to patch Helion and define a default configuration manually.
Yes, Helion is still in beta and under active development, and this issue was later fixed by allowing the user to define an external function as the autotuning baseline: autotune_baseline_fn=callable(). This feature solved our issue. We could then use our existing Triton implementation as the baseline, which we know gives very good performance and correct results. This greatly simplified and improved the autotuning process for us.
Another feature we really appreciate is the autotuning “effort level”, which Helion added as a user-defined setting: autotune_effort=, which can be ‘none’, ‘quick’, or ‘full’. From our experience developing the Triton attention backend in vLLM, we know that paged memory, and the resulting limited number of valid configurations, constrain the autotuning process, so days of autotuning usually do not pay off. Hence, once this feature was available, we set the effort to quick. But as usual, terms like “quick” and “slow” are relative: the quick mode of Helion’s autotuner still required 10 hours to tune our kernel for 72 different scenarios (varying batch size, sequence lengths, head size, etc.). This was faster than the 25 hours required with the “full” (default) setting, but perhaps not as “quick” as we had wished.
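The idea behind the correctness check can be sketched as follows. Each candidate configuration's output is compared against a trusted baseline, and mismatching configurations are discarded. This illustrates the concept only and is not Helion's implementation. The tolerances, the toy `run`/`baseline` functions, and the "noise" knob are hypothetical. In our setup, the baseline role was played by the existing Triton kernel via autotune_baseline_fn.

```python
import math
from typing import Callable, Sequence


def filter_configs(candidates: Sequence[dict],
                   run: Callable[..., list[float]],
                   baseline_fn: Callable[..., list[float]],
                   inputs: tuple,
                   rtol: float = 1e-2, atol: float = 1e-2) -> list[dict]:
    """Keep only candidate configs whose outputs match a trusted baseline."""
    reference = baseline_fn(*inputs)
    kept = []
    for cfg in candidates:
        out = run(cfg, *inputs)
        ok = len(out) == len(reference) and all(
            math.isclose(a, b, rel_tol=rtol, abs_tol=atol) for a, b in zip(out, reference))
        if ok:
            kept.append(cfg)
    return kept


def baseline(xs: list[float]) -> list[float]:
    return [2 * x for x in xs]  # stands in for the trusted reference kernel


def run(cfg: dict, xs: list[float]) -> list[float]:
    return [2 * x + cfg["noise"] for x in xs]  # "noise" models a numerically wrong config


candidates = [{"noise": 0.0}, {"noise": 0.5}, {"noise": 1e-4}]
kept = filter_configs(candidates, run, baseline, ([1.0, 2.0, 3.0],))
print(kept)  # [{'noise': 0.0}, {'noise': 0.0001}]
```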
Once the autotuner finishes, it also prints out the recommended best configuration:
```text
[2752s] Autotuning complete in 2752.1s after searching 5353 configs.
One can hardcode the best config and skip autotuning with:
    @helion.kernel(config=helion.Config(
        block_sizes=[32, 4],
        indexing=['pointer', 'pointer', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor'],
        l2_groupings=[2],
        load_eviction_policies=['', '', '', '', '', 'last', 'last', ''],
        loop_orders=[[1, 2, 0], [1,0]],
        num_stages=6,
        num_warps=8,
        pid_type='flat',
        range_flattens=[None, True, True, True],
        range_multi_buffers=[None, None, None, False],
        range_num_stages=[],
        range_unroll_factors=[0, 1, 2, 1],
        range_warp_specializes=[]),
    static_shapes=False)
```
This gives an impression of how many configuration knobs Helion can explore. It is therefore important, and only logical, that Helion invests heavily in different autotuning algorithms. Currently, genetic algorithms are used by default.
In the example shown above, Helion finds tile sizes of 32 tokens (TILE_Q) and 4 pages (TILE_N, equivalent to 64 tokens at a page size of 16) to be the best. It also determines how to address each of the tensors involved, and whether loops should be flattened and reordered.
A detailed discussion of all the knobs is beyond the scope of this blog post.
Our experience with autotuning the Triton kernels taught us that it is a good trade-off to tune for a broad range of scenarios (in a microbenchmark setting) and then select only a handful of configurations to be used in vLLM and use decision trees or other heuristics to select between them.
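As an illustration of this kind of selection heuristic, here is a tiny decision tree that picks one of a few pre-tuned configurations per batch. The configuration names, values, and the 4096-token threshold are hypothetical placeholders, not our tuned results, and this is not the actual if-else logic of the vLLM Triton backend.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BatchStats:
    max_query_len: int  # 1 means the batch is decode-only
    max_seq_len: int    # longest context in the batch


# Hypothetical pre-tuned configurations; the values are placeholders, not tuned results.
CONFIGS = {
    "decode_short": {"tile_q": 1, "pages_at_once": 2},
    "decode_long": {"tile_q": 1, "pages_at_once": 8},
    "prefill": {"tile_q": 32, "pages_at_once": 4},
}


def select_config(stats: BatchStats) -> dict:
    """A tiny decision tree choosing one pre-tuned config per batch."""
    if stats.max_query_len == 1:  # decode-only batch
        if stats.max_seq_len > 4096:  # hypothetical threshold
            return CONFIGS["decode_long"]
        return CONFIGS["decode_short"]
    return CONFIGS["prefill"]


print(select_config(BatchStats(max_query_len=1, max_seq_len=512)))
# {'tile_q': 1, 'pages_at_once': 2}
```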
However, one disadvantage of the current version of Helion is that it expects one configuration for the complete kernel. We therefore cannot differentiate between configurations that would benefit a prefill batch and those that would benefit a decode batch, as we can in Triton. Hence, for the experiments in this blog, we settled on 6 configurations to be tuned “live” at vLLM deployment when running on NVIDIA GPUs, and 7 for AMD GPUs.
Performance Evaluation
Having a good developer experience is of course nice, but we still might not use Helion if its performance were worse than what we could achieve with Triton. Hence, we benchmarked our new paged attention in Helion against our established Triton kernel on NVIDIA H100 and AMD MI300X. For use cases like inference servers, we always have to look at two aspects: how the kernel performs in isolation and how the new kernel affects the complete system, i.e., vLLM. Therefore, we first analyze the kernel performance alone in micro-benchmarks and then, in a second step, perform an end-to-end analysis.
Micro-Benchmarks
To analyze the kernel performance, we used our microbenchmark suite, which we also used for development of the vLLM Triton attention backend. Here, we base our test parameters on the Llama3.1-8B architecture (128 head size, 32 query heads, and 8 KV heads) and vary sequence lengths and batch sizes based on real-world samples. The sequences contained within a batch have variable lengths, i.e. the “varlen” mode, as is the default in vLLM. We used Helion 0.2.4, Triton 3.5.0 and PyTorch 2.9 for all experiments.
For each GPU platform, we made two plots of the same data: one sorted by the share of decode requests within a batch, and the other by the maximum sequence length within a mixed prefill-decode varlen batch. Note that all measurements were taken with CUDA/HIP graphs enabled, so we do not evaluate any software overhead like compile time or launch time, just pure kernel performance. Since we focus on the comparison between Triton and Helion kernels, we normalized all data to the leftmost Triton result in each plot.
Each plot shows three results: the Triton 2D kernel (as it is in vLLM today) as the baseline; the Helion kernel with dynamic shapes (static_shapes=False), using the configurations selected for small-scale “live” autotuning in vLLM on each platform; and the Helion kernel with static shapes (static_shapes=True) and full autotuning. Combining dynamic shapes with full autotuning, or static shapes with quick autotuning, would also be valid, but we do not plot these combinations for clarity. Instead, we selected the two “extremes”: first, the fewest optimizations per request, i.e., dynamic shapes and quick autotuning over the selected configurations; and second, the most optimizations possible, i.e., static shapes and full tuning for each scenario.

Figure 3: Microbenchmark on H100. The latencies are sorted by share of decode requests within a batch. The sequences contained within a batch have variable lengths with a median of 40% of the maximum sequence length, as could occur in real-world online inference scenarios. The x-axis denotes the combined number of tokens in a batch.

Figure 4: Microbenchmark on H100. Same setup and results as the Figure above, but here the latencies are sorted by maximum sequence length within a batch, with prefill, partial prefill, and decodes combined.
As can be seen for the H100, our Helion paged attention implementation already outperforms the Triton 2D attention kernel for decodes and is on par for large prefill batches. Our data shows that the performance of the Helion kernel ranges between 29% and 137% of Triton's for prefill requests, and between 132% and 153% for decode-only requests. The plots also show barely any difference between the Helion variant with static shapes and the one without. This fact is important for the end-to-end measurements and will be discussed in the next section.
The gap in prefill can be explained by the smaller launch grid of the Helion kernel compared to the Triton kernel. As discussed above, the Helion kernel parallelizes only over the batch and query-head dimensions, not along the query tokens themselves. In contrast, the two grid dimensions of the Triton 2D kernel are the batch dimension (as in Helion) and a mix of query heads and query tokens. Hence, optimizing the launch grid of the Helion kernel is another optimization avenue.
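A rough count of launched programs illustrates the prefill gap. The sketch assumes a simplified model of both grids: the Helion-style grid has one program per (sequence, query head), while the token-tiled grid packs `block_m` rows, each a (query token, query head) pair from one KV-head group, into one program. The function names and `block_m=64` are illustrative assumptions, not the tuned values of either kernel.

```python
import math


def helion_style_grid(query_lens, num_q_heads=32):
    """Programs when parallelizing only over sequences and query heads."""
    return len(query_lens) * num_q_heads


def token_tiled_grid(query_lens, num_q_heads=32, num_kv_heads=8, block_m=64):
    """Programs when each program also covers a tile of query tokens.

    Simplified model: a tile of block_m rows holds block_m // group_size query
    tokens for all group_size query heads that share one KV head.
    """
    group_size = num_q_heads // num_kv_heads
    tokens_per_tile = block_m // group_size
    tiles = sum(math.ceil(q / tokens_per_tile) for q in query_lens)
    return tiles * num_kv_heads


print(helion_style_grid([2048]), token_tiled_grid([2048]))        # 32 1024
print(helion_style_grid([1] * 64), token_tiled_grid([1] * 64))    # 2048 512
```

Under this model, a single 2048-token prefill launches only 32 Helion-style programs, fewer than the 132 SMs of an H100 SXM, while a batch of 64 decodes launches plenty of programs in both schemes.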

Figure 5: Microbenchmark on MI300X. The latencies are sorted by the share of decode requests within a batch. The sequences within a batch have variable lengths, with a median of 40% of the maximum sequence length, as could occur in real-world online inference scenarios. The x-axis shows the combined number of tokens in a batch.

Figure 6: Microbenchmark on MI300X. Same setup and results as Figure 5, but here the latencies are sorted by the maximum sequence length within a batch, with prefills, partial prefills, and decodes combined.
For MI300X, the results look somewhat different, as the gap for prefills is larger. Here, the performance of the Helion kernel relative to the Triton kernel varies between 13% and 75% for prefill requests and between 58% and 107% for decode-only batches.
However, here too the Helion kernel is on par with or outperforms the Triton kernel for some decode-only batches, and therefore we consider our kernel implementation to be platform- and performance-portable.
End-to-End in vLLM
As big fans of vLLM, we naturally wanted to evaluate how our Helion paged attention implementation performs in realistic and relevant end-to-end scenarios.
Therefore, we wrote a “helion_attn” backend, similar to our Triton attention backend, and evaluated its performance. We also submitted a draft PR for this.
For the evaluation, we used vLLM's built-in serving benchmark with the popular ShareGPT dataset, and we disabled prefix caching.
VLLM_ATTENTION_BACKEND=EXPERIMENTAL_HELION_ATTN vllm serve \
    meta-llama/llama3.1-8b-instruct/ \
    --disable-log-requests --no-enable-prefix-caching

vllm bench serve --model meta-llama/llama3.1-8b-instruct/ \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --ignore_eos
This setup evaluates the performance of the vLLM inference server “end-to-end”, because the performance is measured by the client, which sends requests to the server as a real user would. This means that the client does not assume any knowledge of the state of the inference server, i.e. how many requests are currently running or which “shapes” the requests should have for good kernel performance. In the particular benchmark used in these experiments, the client sends 1000 requests at once and the vLLM server then has to process them as fast as possible, which includes scheduling very large batches and, in particular, decode-only batches.
We compared our experimental Helion attention backend with our Triton backend as it is in vLLM today. Both backends use full CUDA/HIP graphs for mixed and decode-only batches. One difference is that the Triton attention backend in vLLM today does not do “live” tuning and instead selects between four different configurations with if-else statements. In contrast, our proof-of-concept Helion attention backend uses the autotuner at runtime to choose from 6 or 7 configurations, depending on the platform. To be fair to both implementations, we always do two warmup benchmark runs, which allow the autotuner to run and the Helion and Triton JIT compilers to compile most of the relevant kernel versions. Each plot shows three results: the performance of the triton_attn backend as it is in vLLM today as the baseline, the performance of our Helion kernel with static shapes, and the performance of our Helion kernel with dynamic shapes. As in the microbenchmarks, we normalized the results to the Triton results.
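For comparison with live autotuning, heuristic selection between a fixed set of configurations can be as simple as a few comparisons. The four configurations, their parameters, and the thresholds below are hypothetical placeholders, not the values used in vLLM's Triton backend.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttnConfig:
    block_m: int    # rows of the query tile
    block_n: int    # KV positions processed per iteration
    num_warps: int


# Hypothetical configurations and thresholds, for illustration only.
CONFIGS = {
    "decode_short": AttnConfig(block_m=16, block_n=64, num_warps=2),
    "decode_long": AttnConfig(block_m=16, block_n=128, num_warps=4),
    "prefill_small": AttnConfig(block_m=64, block_n=64, num_warps=4),
    "prefill_large": AttnConfig(block_m=128, block_n=64, num_warps=8),
}


def select_config(max_query_len, max_seq_len):
    """Pick one of four fixed configs from simple batch statistics, with no tuning at runtime."""
    if max_query_len == 1:  # decode-only batch
        if max_seq_len <= 4096:
            return CONFIGS["decode_short"]
        return CONFIGS["decode_long"]
    if max_query_len <= 512:
        return CONFIGS["prefill_small"]
    return CONFIGS["prefill_large"]
```

Such a dispatcher costs nothing at runtime, but its thresholds have to be re-derived by hand for every platform, which is the work that live autotuning automates.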

Figure 7: End-to-end performance measurements using vllm bench with Llama3.1 8B and the ShareGPT dataset on H100.
As Figure 7 shows, Helion with static shapes achieves only roughly 26% of Triton's total token throughput, while Helion with dynamic shapes achieves 96% and is also on par in TTFT (time-to-first-token, i.e. the prefill time) and very close in ITL (inter-token latency, i.e. the time to decode one more token).
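The client-side metrics can be derived from per-request timestamps roughly as follows. This sketch assumes that total token throughput counts both prompt and generated tokens over the benchmark's wall-clock duration; `RequestTrace` and `summarize` are hypothetical names, not `vllm bench` internals.

```python
from dataclasses import dataclass
from statistics import mean


@dataclass
class RequestTrace:
    send_time: float    # when the client sent the request (seconds)
    token_times: list   # arrival time of each generated token (seconds)
    prompt_tokens: int  # number of input tokens


def ttft(trace):
    """Time-to-first-token: from sending the request until the first generated token."""
    return trace.token_times[0] - trace.send_time


def inter_token_latencies(trace):
    """Gaps between consecutive generated tokens of one request."""
    times = trace.token_times
    return [later - earlier for earlier, later in zip(times, times[1:])]


def summarize(traces):
    start = min(t.send_time for t in traces)
    end = max(t.token_times[-1] for t in traces)
    total_tokens = sum(t.prompt_tokens + len(t.token_times) for t in traces)
    itls = [gap for t in traces for gap in inter_token_latencies(t)]
    return {
        "mean_ttft_s": mean(ttft(t) for t in traces),
        "mean_itl_s": mean(itls) if itls else 0.0,
        "total_token_throughput": total_tokens / (end - start),
    }


traces = [RequestTrace(0.0, [0.5, 0.6, 0.7], 10), RequestTrace(0.0, [1.0, 1.2], 20)]
print(summarize(traces)["mean_ttft_s"])  # 0.75
```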
This experiment highlighted an important reality of inference servers: request shapes are diverse, plentiful, and not known in advance. Furthermore, the scheduled batch shapes also depend on various other aspects, such as the order in which requests arrive. Hence, even though the same benchmark was run as warmup, using Helion with static shapes triggers recompilation for nearly every request, since the shapes of the query tensor are rarely exactly the same. Since this is an end-to-end experiment, these compilation times are reflected in the measured latencies and throughput. This way of evaluating our kernel implementation differs from looking at raw kernel performance in micro-benchmarks, but it reflects the real-world performance that users of vLLM would experience.
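The recompilation effect can be illustrated with a toy JIT cache that counts how often a new specialization would be compiled. This is a deliberate simplification with hypothetical names: the static key uses every exact dimension and the dynamic key only the tensor rank, whereas real Helion and Triton specialization rules are more nuanced. The batch sizes are synthetic.

```python
import random


class JitCache:
    """Toy JIT cache: counts how often a new kernel specialization would be compiled."""

    def __init__(self, key_fn):
        self.key_fn = key_fn
        self.compiled = set()
        self.compilations = 0

    def call(self, shape):
        key = self.key_fn(shape)
        if key not in self.compiled:
            self.compiled.add(key)
            self.compilations += 1  # a real JIT would spend seconds compiling here


def static_key(shape):
    return shape  # specialize on every exact dimension


def dynamic_key(shape):
    return len(shape)  # specialize only on tensor rank


rng = random.Random(0)  # fixed seed for reproducibility
static_cache, dynamic_cache = JitCache(static_key), JitCache(dynamic_key)
for _ in range(1000):
    # Synthetic packed query shape: (num_tokens, num_query_heads, head_size)
    shape = (rng.randint(1, 8192), 32, 128)
    static_cache.call(shape)
    dynamic_cache.call(shape)

print(static_cache.compilations, dynamic_cache.compilations)
```

With 1000 randomly sized batches, the static cache compiles once for nearly every distinct token count it sees, while the dynamic cache compiles exactly once.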
Consequently, Helion with static shapes performs poorly, because its huge JIT overhead outweighs the (small) gains in kernel runtime. Please note that, due to unrecoverable crashes of the generated Triton code during “live” autotuning, we had to disable autotuning for the end-to-end experiments with static shapes and instead used the configuration with the best decode performance, as determined in micro-benchmarks. This limitation could explain a small part of the TTFT gap between static and dynamic shapes, but neither the gap in ITL nor the big difference in throughput. Static shapes are enabled by default and allow Helion to optimize performance using hard-coded tensor shapes in the generated Triton code. Static shapes are usually presented as a performance optimization in Helion, but for highly dynamic usage scenarios like vLLM they do not pay off.
More surprising are the corresponding micro-benchmark results: even when looking at pure kernel performance, there is barely a difference between optimizing for static shapes or not. We suspect that the minimal difference between static and dynamic shape compilation in Helion is largely because the inputs to a paged attention kernel are actually quite shape-constrained. For example, the size of the matrix multiplications always needs to align with the KV cache page size (or block size) of vLLM, and compile-time optimizations like loop fusion cannot change this.
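The page-size constraint follows from how a paged KV cache is addressed. In this minimal sketch (hypothetical helper names, assuming one block table per sequence as in vLLM's paged attention), a logical token position maps to a physical block and an offset, so contiguous memory accesses along the KV axis never span more than one page.

```python
def kv_slot(block_table, position, block_size):
    """Map a logical token position to (physical_block, offset) in the paged KV cache."""
    return block_table[position // block_size], position % block_size


def kv_pages(block_table, seq_len, block_size):
    """Yield (physical_block, valid_tokens) for each page of one sequence's KV cache.

    Tokens are contiguous only within a page, so a kernel tile along the KV axis
    is naturally sized to match (or evenly divide) block_size.
    """
    num_pages = (seq_len + block_size - 1) // block_size  # ceil division
    for logical_block in range(num_pages):
        start = logical_block * block_size
        yield block_table[logical_block], min(block_size, seq_len - start)


block_table = [7, 2, 9]  # hypothetical physical block ids for one sequence
print(kv_slot(block_table, 17, block_size=16))                 # (2, 1)
print(list(kv_pages(block_table, seq_len=40, block_size=16)))  # [(7, 16), (2, 16), (9, 8)]
```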
One additional surprise in these end-to-end results was that the TTFT of Helion is actually on par with Triton in this specific benchmark, because here the prefill batches are larger and more uniform than in our microbenchmark setup.

Figure 8: End-to-end performance measurements using vllm bench with Llama3.1 8B and the ShareGPT dataset on MI300X.
On MI300X, our backend works as well, but achieves only 59% of the total token throughput with dynamic shapes. This result is not surprising, because our micro-benchmarks already showed that the gap between the Helion and Triton kernels for prefills is larger on MI300X.
Lessons Learned and Conclusion
We enjoyed this experiment and see Helion as both a very handy language and a powerful autotuner. Overall, we spent less than three weeks on the experiments described in this blog post and were quite surprised by the impressive results.
Of course, there are many options left to further optimize this implementation of paged attention in Helion, for example to balance long prefills, long decodes, or very large batches. This would require implementing different launch grids or parallelization along the queries in Helion, as well as further research to determine the optimal heuristics for choosing between different kernel versions, similar to how it is done in our Triton attention backend. This missing fine-grained optimization also explains the performance gap between the Triton and Helion implementations in the reported experiments. In the best case, we could teach the Helion autotuner to strike this balance automatically (see further discussion). Since these trade-offs are all platform dependent, we think Helion's autotuner is well suited to automate this reliably and quickly. However, here too we need to find a good balance between re-triggering Helion tuning and JIT compilation on the one hand, and low-latency execution of already compiled kernels selected by heuristics on the other (see further discussion).
Another possible optimization for our experimental vLLM backend would be to use static shapes during CUDA/HIP graph capture, since the additional JIT overhead does not matter there and the shapes of a recorded CUDA/HIP graph are static. Hence, it would be safe to let the compiler optimize more aggressively for these shapes. However, we would then have to switch to dynamic shapes afterwards at runtime.
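The idea of combining static shapes during graph capture with dynamic shapes at runtime could look like the dispatcher below. It is a hypothetical sketch: it assumes batches are padded up to the nearest captured size, similar to how vLLM pads batches for CUDA graph replay, and `CAPTURE_SIZES` is an illustrative list, not vLLM's default.

```python
import bisect

# Hypothetical batch sizes (in tokens) for which CUDA/HIP graphs were captured.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]


def pick_kernel(num_tokens, capture_sizes=CAPTURE_SIZES):
    """Choose between a statically specialized kernel (graph replay) and a dynamic one.

    Batches up to the largest captured size are padded to the next captured size,
    so only len(capture_sizes) static specializations ever need to be compiled,
    all of them during graph capture. Larger batches fall back to dynamic shapes.
    """
    i = bisect.bisect_left(capture_sizes, num_tokens)
    if i < len(capture_sizes):
        return "static", capture_sizes[i]
    return "dynamic", num_tokens


print(pick_kernel(3))    # ('static', 4)
print(pick_kernel(512))  # ('static', 512)
print(pick_kernel(700))  # ('dynamic', 700)
```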
During our experiments, we realized that for development with Helion kernels, we also need an extensive, automated, and reliable micro-benchmark suite to understand the detailed performance of the kernel in a large number of use cases. This is similar to how we learned to develop our Triton kernels, and therefore, we adapted our micro-benchmark suite that we initially built for our work with Triton.
The single most useful “Helion command” turned out to be tensor.view, which helped us find out early whether the Helion compiler considers a tensor's shape to be what we expect. This made debugging compiler errors that only print symbolic shapes a lot easier.
Finally, we would like to add some pre-trained heuristics or decision trees to Helion, as a middle ground between hour-long autotuning and a single configuration for all cases in low-latency scenarios such as vLLM.
In conclusion, we think Helion is an exciting and quite useful addition to the PyTorch ecosystem, and we are curious how it will impact vLLM.
Acknowledgments
This work was supported by the AI platform team at IBM Research; in particular, we would like to thank our colleagues Thomas Parnell, Jan van Lunteren, Mudhakar Srivatsa, and Raghu Ganti. We would also like to thank Jason Ansel and the Helion team at Meta for their feedback and support, especially for fixing the bugs we reported so quickly, sometimes even within 24 hours.
Unlock Reasoning in Llama 3.1-8B via Full Fine-Tuning on NVIDIA DGX Spark
2 Feb 2026, 6:16 pm
What is the unspoken joy of local LLMs?
The magic of downloading weights, running some experiments overnight, maybe your room gets a bit toasty, and voila, you create a small but performant model that runs on your desktop.
Often this involves a big GPU machine and lots of cables; in our case, it was a lovely little box that fit right into the space of our monitor stand and kept our hands warm. DGX Spark is truly fun to look at!

In this blog post, we share a recipe for full fine-tuning of Llama-3.1-8B-Instruct on synthetic data to unlock “reasoning” in an LLM using the DGX Spark box. Thanks to the unified memory, we are able to generate synthetic thinking traces and fine-tune the model on them entirely locally.

Text in red shows the behaviour added by the model fine-tuned on NVIDIA DGX Spark.
The entire recipe runs offline on DGX Spark in under a day. We are able to run full fine-tuning for Llama-3.1-8B-Instruct without any issues, with a context length of 16k tokens and a batch size of 16.
In a future post, we plan to share more experiments with Llama 70B, as well as FP4 experiments on even more exciting topics. Stay tuned!
Adding Reasoning Behaviour in Llama-3.1-8B
The ability of large language models to reason and think has shown large gains in practice, thanks to inference-time scaling.
We ask the question: Can we create this behaviour for a specific topic by fine-tuning on synthetic thinking traces?
We prompt Llama-3.3-70B-Instruct, running locally, to add chain of thought (CoT) to existing chats.

Data Generation
Note: We generate synthetic CoT over the entire ToolACE dataset, which consists of 11k conversation pairs.
This feature is supported out of the box by Synthetic-Data-Kit, which offers an intuitive CLI to prepare and enrich your datasets for fine-tuning LLMs.
We can run this locally using vLLM on DGX Spark and use the following approach to generate CoT responses:
We use a single command:
synthetic-data-kit create --type cot-enhance /path/to/dataset
Then we can create a custom prompt in a configuration file and use it like so:
# cot_tools_config.yaml
vllm:
  api_base: "http://localhost:8000/v1"
  model: "unsloth/Meta-Llama-3.3-70B-Instruct"
  max_retries: 3
  retry_delay: 1.0

generation:
  temperature: 0.2  # Lower temperature for more consistent reasoning
  top_p: 0.95
  max_tokens: 16384  # Allow for longer outputs to accommodate CoT reasoning

# The most important part - our custom Chain of Thought prompt
prompts:
  cot_enhancement: |
    You are a highly intelligent AI with an IQ of 170, and your job is to enhance existing conversation examples. Remember to return the entire conversation as is, BUT BUT, we will add Chain of Thought and planning to "Assistant" messages whenever they return a tool call.
    Remember, ONLY when an assistant message returns a tool call will we add thinking and reasoning traces before it to add logic. Otherwise, we don't touch the conversation history.
    Remember to return the entire message, but only enhance the assistant messages whenever a tool is called in the conversation by adding thoughts.
    Please keep in mind that we are not modifying anything in the example, nor are we changing what it does. We are only adding CoT every time a tool gets called in the conversation.
    Think out loud and maximize your tokens when adding CoT.
    For example, if you see:
    "from": "assistant", "value": "<tool>[Some API(param=\"value\")]</tool>"
    Change it to:
    "from": "assistant", "value": "Let me think about this request. I need to gather X information using Tool Y. To do this, I need to set the parameter to 'value' because of reason Z. <tool>[Some API(param=\"value\")]</tool>"
    BEGIN WORK NOW. Enhance the assistant's messages with detailed Chain of Thought reasoning before each tool call:
    {conversations}
synthetic-data-kit -c cot_tools_config.yaml create test_files/conversation_example.json \
    --type cot-enhance \
    --output enhanced_output/
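Because the prompt asks the model to change nothing except adding reasoning before tool calls, it is worth checking the generated data. The sketch below assumes the ShareGPT-style "from"/"value" turns and the `<tool>...</tool>` markup shown in the prompt; `check_enhancement` is a hypothetical helper, not part of Synthetic-Data-Kit.

```python
import re

TOOL_CALL = re.compile(r"<tool>.*?</tool>", re.DOTALL)


def check_enhancement(original, enhanced):
    """Return True if only reasoning text was added to assistant turns with tool calls."""
    if len(original) != len(enhanced):
        return False
    for before, after in zip(original, enhanced):
        if before["from"] != after["from"]:
            return False
        calls_before = TOOL_CALL.findall(before["value"])
        if before["from"] != "assistant" or not calls_before:
            # Turns without a tool call must be returned verbatim.
            if before["value"] != after["value"]:
                return False
        elif TOOL_CALL.findall(after["value"]) != calls_before:
            # Tool calls must stay identical; only reasoning may be added around them.
            return False
    return True
```

Conversations that fail this check can simply be dropped or regenerated before fine-tuning.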
Unlocking Local Full Fine Tuning with NVIDIA DGX Spark
Fine-tuning large language models is quite well understood, and there are a lot of knobs we can work with when performing supervised fine-tuning.
For our experiments, we use full fine-tuning to showcase the power of the 128GB of unified memory in NVIDIA DGX Spark.
128GB Unified Memory is Spacious!
The great thing about DGX Spark is that all of the available memory for training is exposed as a unified 128GB pool. So when performing supervised fine-tuning, we can assume a single 128GB memory device instead of spending time on offloading settings.
This enables us to run full fine-tuning for Llama-3.1-8B instead of experimenting with configurations to squeeze everything into device memory.
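A back-of-the-envelope estimate shows why full fine-tuning of an 8B model fits into 128GB. This sketch assumes bf16 weights and gradients and AdamW with two optimizer states per parameter, stored in either fp32 or bf16; activations, which grow with context length and batch size, are excluded, and this is not TorchTune's exact memory accounting.

```python
def full_finetune_memory_gb(num_params, weight_bytes=2, grad_bytes=2,
                            optim_states=2, optim_state_bytes=4):
    """Rough memory for weights + gradients + optimizer states, excluding activations.

    Defaults: bf16 weights and gradients, AdamW with two fp32 states per parameter.
    """
    bytes_per_param = weight_bytes + grad_bytes + optim_states * optim_state_bytes
    return num_params * bytes_per_param / 1e9  # decimal gigabytes


print(full_finetune_memory_gb(8e9))                       # 96.0 (fp32 AdamW states)
print(full_finetune_memory_gb(8e9, optim_state_bytes=2))  # 64.0 (bf16 AdamW states)
```

The remaining headroom is what activations at a 16k context and batch size 16 have to fit into.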
Context Length
More memory allows us to train on longer contexts, which makes it easier to teach tasks like tool calling.
For our use case, we fine-tune Llama on synthetic reasoning traces. These can get quite lengthy! In our case, our experiments run at 16k tokens.
With standard attention, memory requirements grow quadratically with sequence length, since the attention score matrix has one entry per pair of tokens.
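The quadratic term is easy to quantify: a materialized attention score matrix has seq_len × seq_len entries per head. This sketch assumes 2-byte (bf16) scores and counts a single layer; fused kernels such as FlashAttention avoid materializing this matrix, so actual activation memory can grow closer to linearly.

```python
def attention_scores_bytes(seq_len, num_heads=1, batch_size=1, bytes_per_element=2):
    """Memory of a materialized [batch, heads, seq_len, seq_len] attention score matrix."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


for n in (4096, 8192, 16384):
    mib = attention_scores_bytes(n) / 2**20
    print(f"{n:>6} tokens: {mib:.0f} MiB per head")
# Each doubling of the sequence length quadruples the memory: 32, 128, 512 MiB.
```

At 16k tokens, all 32 query heads of one Llama-3.1-8B layer would need 16 GiB for the scores of a single sequence.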
This is another area where DGX Spark shines: even with full fine-tuning at a 16k context length, peak memory usage is roughly 80% (see the device usage graphs in the Results section below).
Batch Size
Another good rule of thumb is to maximise the batch size, in powers of 2, for faster convergence. In our experiments, we have enough room to set batch_size to 16, which is really great!
Why Full Fine Tune?
Thanks to the 128GB of memory, we are able to run full fine-tuning (FFT) of the 8B model entirely on DGX Spark, so we chose this route for maximum performance. We report results from our experiments below.
Full-Fine-Tuning
We use TorchTune for the full fine-tuning experiments. TorchTune offers an intuitive CLI that works with config files, so each experiment runs with a single command.
The entire config and details are available here.
We explain the key configurations below:
tune run full_finetune_single_device --config fft-8b.yaml
Key items we set in the config:
- Seq_len: 16384
- Batch_size: 16
- Epochs: 3
- Dtype: bf16
The whole pipeline, from data generation to full fine-tuning, is only a one-day job, with each epoch taking around 8 hours on average. This is really impressive!
Note: There is room to squeeze out more performance, since peak memory usage stayed around 80% during the entire run.
Results
Below we show the training graphs of our hero run, followed by performance on function-calling benchmarks:


For our recipe, we use the ToolACE dataset and compare against baseline results on BFCLv3 as well as the recently released v4.
BFCL measures the following:
- Multi-Turn Tool Calling
- Parallel Tool Calling
- Ability to perform agentic calls (v4)
- Ability to perform web searches based on open-ended queries (v4)


Conclusion
We show an end-to-end recipe that runs on DGX Spark, using its unified memory to perform full fine-tuning of Llama-3.1-8B-Instruct.
We generate synthetic thinking traces and chain of thought, then fine-tune the model on them to improve its performance, as reported above.
In future posts, we will explore Llama-3.3-70B recipes as well as more recipes that showcase the FP4 power of the box. Please stay tuned!


