Enhancing AI Reasoning with Step-Wise Reinforcement Learning
Researchers from Stanford University and Google DeepMind have introduced Step-Wise Reinforcement Learning (SWiRL), a novel technique designed to improve the ability of large language models (LLMs) to handle complex tasks requiring multi-step reasoning and tool use. This development comes as interest in AI agents and LLM tool use continues to grow, potentially offering significant benefits for enterprises looking to integrate reasoning models into their applications and workflows.
The Challenge of Multi-Step Problems
Real-world enterprise applications often involve multi-step processes, such as planning a complex marketing campaign that requires market research, internal data analysis, budget calculation, and a review of customer support tickets. These tasks call for online searches, access to internal databases, and code execution. Traditional reinforcement learning (RL) methods used to fine-tune LLMs, such as Reinforcement Learning from Human Feedback (RLHF) or RL from AI Feedback (RLAIF), typically focus on optimizing models for single-step reasoning tasks.

The lead authors of the SWiRL paper, Anna Goldie and Azalia Mirhoseini, believe that current LLM training methods are not suited for the multi-step reasoning tasks required by real-world applications. They noted that LLMs trained via traditional methods struggle with multi-step planning and tool integration, making it difficult to perform tasks that require retrieving and synthesizing documents from multiple sources, or that involve multiple steps of reasoning and arithmetic calculation.
How SWiRL Works
SWiRL addresses the multi-step challenge through a combination of synthetic data generation and a specialized RL approach that trains models on entire sequences of actions. The technique employs a two-stage methodology. First, it generates and filters large amounts of multi-step reasoning and tool-use data. Second, it uses a step-wise RL algorithm to optimize a base LLM using these generated trajectories.
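To make the step-wise idea concrete, the sketch below shows how a recorded trajectory could be broken into one training example per step, each scored by its own reward, rather than rewarding only the final answer. This is a minimal illustration of the general approach, not the paper's implementation; the `Step` structure and `step_reward` function are assumptions introduced for the example.

```python
# Illustrative sketch only: decompose a multi-step trajectory into per-step
# training examples, the core idea behind step-wise optimization.
# The data structures and names here are assumptions, not the authors' code.
from dataclasses import dataclass


@dataclass
class Step:
    context: str   # everything seen so far: prompt, prior steps, tool results
    action: str    # the model's output at this step: reasoning, tool call, or final answer


def to_stepwise_examples(trajectory: list[Step], step_reward) -> list[dict]:
    """Turn one trajectory into one training example per step.

    Instead of rewarding only the final answer, each intermediate action is
    scored on its own, so the policy receives feedback at every step.
    """
    examples = []
    for step in trajectory:
        examples.append({
            "prompt": step.context,                            # state: accumulated context
            "completion": step.action,                         # action taken at this step
            "reward": step_reward(step.context, step.action),  # per-step score
        })
    return examples
```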

The data generation process involves giving an LLM access to relevant tools like search engines or calculators. The model is then prompted iteratively to generate a ‘trajectory,’ a sequence of steps to solve a given problem. At each step, the model can generate internal reasoning, call a tool, or produce the final answer. If it calls a tool, the query is extracted, executed, and the result is fed back into the model’s context for the next step.
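The loop below is a rough sketch of that rollout process, assuming a simple text protocol in which the model prefixes tool calls with `SEARCH:` or `CALC:` and final answers with `FINAL ANSWER:`. The `generate` and `run_tool` helpers are placeholders for a model call and a tool executor, not part of the paper.

```python
# Hypothetical sketch of the iterative trajectory-generation loop described above.
def generate_trajectory(question: str, generate, run_tool, max_steps: int = 10) -> list[dict]:
    """Roll out one trajectory: at each step the model reasons, calls a tool,
    or answers; tool results are appended to the context for the next step."""
    context = f"Question: {question}\n"
    trajectory = []
    for _ in range(max_steps):
        step = generate(context)              # model output for the current step
        trajectory.append({"context": context, "action": step})
        if step.startswith("FINAL ANSWER:"):  # the model chose to answer; stop
            break
        if step.startswith("SEARCH:") or step.startswith("CALC:"):
            tool, query = step.split(":", 1)
            result = run_tool(tool, query.strip())    # e.g. search engine or calculator
            context += f"{step}\nResult: {result}\n"  # feed the tool result back in
        else:
            context += f"{step}\n"                    # plain reasoning step
    return trajectory
```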

SWiRL achieved its best results with process-filtered data, which keeps trajectories in which every reasoning step or tool call was judged logical given the preceding context, even when the final answer turned out to be wrong. In other words, the researchers found that SWiRL can learn even from trajectories that end in an incorrect final answer.
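A process filter of this kind might look like the sketch below, where `judge` stands in for an LLM-based judge that rates whether a single step is reasonable given the context that preceded it; the helper names are hypothetical.

```python
# Sketch of process filtering under the assumptions above: keep a trajectory
# if every step is judged reasonable given its preceding context, regardless
# of whether the final answer is correct.
def process_filter(trajectories: list[list[dict]], judge) -> list[list[dict]]:
    """`judge(context, action) -> bool` is a stand-in for an LLM judge."""
    kept = []
    for traj in trajectories:
        if all(judge(step["context"], step["action"]) for step in traj):
            kept.append(traj)  # every step is plausible; answer correctness is not checked
    return kept
```

Filtering on step quality rather than final-answer correctness preserves far more usable training data, since a trajectory with sound intermediate reasoning still carries signal even when its conclusion is wrong.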
Evaluating SWiRL’s Performance
The Stanford and Google DeepMind team evaluated SWiRL across several challenging multi-step question-answering and mathematical reasoning tasks. Compared to baseline models, SWiRL demonstrated significant relative accuracy improvements, ranging from 11% to over 21% on datasets like GSM8K, HotPotQA, MuSiQue, and BeerQA. The experiments confirmed that training a Gemma 2-27B model with SWiRL on process-filtered data yielded the best results.
Moreover, SWiRL exhibited strong generalization capabilities. Training a model using SWiRL on text-based question-answering examples improved its performance on math reasoning tasks, even though the model wasn't explicitly trained on math problems. This transferability across tasks and tool types is especially valuable as agentic applications of language models proliferate.
The researchers noted that SWiRL’s generalization seems quite robust in the domains they explored, but they expressed interest in testing this in other areas such as coding. They believe that an enterprise AI model trained on one core task using SWiRL would likely exhibit significant performance improvements on other, seemingly unrelated tasks without task-specific fine-tuning.