
🤔 Simple LLM Reasoning from Scratch


Introduction

Recently, there’s been quite a bit of hype around teaching LLMs how to reason and learn. A number of libraries and methods purpose-built for teaching models to reason have emerged (Openr, Search and Learn, and TRL’s PRM Trainer, to name a few). However, I’ve found that many of these are hard to hack on quickly: they assume you can run LLMs on-device in a CUDA environment, and they aren’t extensively documented. Thus, I provide this tutorial for teaching LLMs to reason from scratch on a consumer-grade computer.

Getting Started

For starters, we need to install some necessary libraries and tools. We’ll be working with Ollama for LLM inference, which is straightforward to download on any desktop OS. Once you have Ollama installed, you can run the following to pull a lightweight Llama model; this pulls the 3B variant by default. For an extensive list of models that you can use, you can browse Ollama models.

ollama pull llama3.2

We’ll also go ahead and set up a Conda environment with all required packages. If you don’t use Conda, you can skip the first two steps and install the libraries in a virtual environment created with uv or venv instead. From there, we’ll install the Ollama, Datasets, and NumPy packages to interact with the Ollama client, load mathematical reasoning datasets, and perform vector operations, respectively. We’ll also install Jupyter to work with these tools in a Jupyter notebook.

conda create -n local-reasoning python=3.10
conda activate local-reasoning
pip install ollama datasets numpy jupyter

Creating an environment

A useful abstraction for working with step-by-step problem solving is treating it like a Markov Decision Process (MDP). Thus, we can create an environment that represents the state of learning how to solve a problem, and have our LLM act as an agent. For starters, we can create a utility function for calling our model, like below. Note that we forego setting more advanced sampling parameters such as min_p sampling for simplicity, but you can augment this code to support such sampling methods. We opt for a high temperature and top_p for a high diversity of sampled outputs.

Importantly, we also specify that we wish to use raw prompting. When you call an LLM, a template is applied to the inputs. You can find this template on a given model’s page on Ollama; for example, the Llama 3.2 template can be found here. However, if you use a template in this code, you’ll see that the model starts its reasoning from scratch at every time step once we implement the environment class. Thus, we pass raw=True to ensure that the template is not applied.

Finally, you’ll see that we’re applying a stop string when calling the model. This specifies that we wish to grow reasoning trees at the sentence level, as is done in this paper. There are follow-up papers that propose using special tokens to mark the ends of thoughts, though for simplicity and to use an out-of-the-box model, we will opt for sentence-level reasoning.

import ollama 

class LLMGen:
    def __init__(
        self,
        model: str,
        temperature: float = 1.0,
        top_p: float = 1.0,
        n: int = 128
    ):
        """
        A generic wrapper class for calling Ollama with some default arguments. Pulls the specified model on initialization; when called, generates a response given the sampling parameters without applying the chat template.

        Args:
            model (str): The model we wish to use. Make sure to pull this model using `ollama pull {YOUR MODEL NAME HERE}` before calling this model. 
            temperature (float): The temperature to use when sampling. A lower temperature means less variability in outputs, whereas a higher temperature means higher variability in outputs. 
            top_p (float): The nucleus sampling method for determining what tokens to consider for sampling. 
            n (int): The number of tokens to predict in a given step. 
        """
        self.model = model
        self.temperature = temperature
        self.top_p = top_p
        self.n = n

        ollama.pull(model)

    def __call__(self, content: str, image: str = None):
        return ollama.generate(
            self.model, 
            content,
            images = [image] if image else None,
            options={
                "temperature": self.temperature,
                "top_p": self.top_p,
                "num_predict": self.n,
                "stop": ["\n"]
            },
            raw=True,
        )

You can quickly test this wrapper using the code below.

LLMGen("llama3.2")("What is 2 + 2")

From here, we can go about creating an environment for solving math problems. Environments are inspired by the Gymnasium library, which abstracts MDPs as environment classes that an agent steps through and receives observations from. Openr and TSLLM take a similar approach. We can first define a math problem dataclass that holds a question string, an answer string, and an optional base64-encoded image.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Problem:
    """A dataclass representing a problem"""
    question: str
    answer: str
    image: Optional[str] = None
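For instance, a toy, hand-written problem (not drawn from any dataset) can be wrapped like this; the image field simply stays None for text-only problems.

# A toy problem for illustration; real problems come from a dataset later.
toy_problem = Problem(question="What is 7 * 8?", answer="56")
print(toy_problem)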

Since we’re using the environment to support Chain-of-Thought reasoning, we can name it ChainOfThoughtEnv. This class will be an abstract base class so that the reward can be implemented per dataset, which allows for arbitrary reward engineering. In the simplest case, the reward is 0 or 1 for a question being answered incorrectly or correctly, respectively. However, you may wish to engineer the reward to penalize longer chains of thought (we sketch such a variant after the GSM8K environment below). Another uniquely engineered reward may arise if you are training agentic reasoning and wish to evaluate a path holistically rather than solely based on its terminal state.

from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Tuple

class ChainOfThoughtEnv(ABC):
    """
    An environment for solving natural language problems using Chain-of-Thought (CoT) reasoning.
    
    This class implements a structured environment for solving problems (particularly math problems)
    using language models with Chain-of-Thought prompting. It maintains the state of the problem-solving
    process, manages interactions with the language model, and tracks the solution progress.

    Attributes:
        sep (str): Separator string used for joining text elements, and identifying completion of thoughts
        action_history (Optional[List[str]]): List of actions/steps taken in the current solution
        legal_actions (List[str]): List of valid next steps that can be taken
        image (Optional[str]): Image associated with the problem (if applicable) encoded as a base64 image

    Example:
        >>> llm = LLMGen("llama3.2")
        >>> env = GSM8KEnv(  # a concrete subclass, defined below
        ...     problem=problem,
        ...     llm_gen=llm,
        ...     task_description="Solve these math problems step by step",
        ...     cot_examples="Example1...",
        ...     problem_format="Problem: {question}\nSolution:",
        ...     stop_string="The answer is",
        ... )
        >>> state, reward, terminated, truncated, info = env.reset()
        >>> while not (terminated or truncated):
        ...     action = llm(state[0]).response
        ...     state, reward, terminated, truncated, info = env.step(action)
    """
    sep: str = "\n"
    action_history: Optional[List[str]] = None
    legal_actions: List[str] = [] 
    image: Optional[str] = None 

    def __init__(
        self,
        problem: Problem,
        llm_gen: LLMGen,
        task_description: str,
        cot_examples: str,
        problem_format: str,
        stop_string: str,
        max_actions: int = 2,
        max_steps: int = 10,
        is_few_shot: bool = True,
        reset: bool = False,
    ):
        """
        Initialize the CoT environment.

        Args:
            problem: Problem to be solved
            llm_gen: Language model generator function/class for generating steps
            task_description: Description of the task to be solved
            cot_examples: Example problems and solutions for few-shot learning
            problem_format: Format string for presenting problems (must contain {question})
            stop_string: String signaling that the LLM has finished its chain of thought
            max_actions: Maximum number of actions to generate per step
            max_steps: Maximum length of solution steps
            is_few_shot: Whether to use few-shot learning with examples
            reset: Whether to reset the environment upon initialization
        """
        self.problem = problem    
        self.llm_gen_fn = llm_gen
        self.is_few_shot = is_few_shot
        self.max_actions = max_actions
        self.max_steps = max_steps

        self.task_description = task_description
        self.cot_examples = cot_examples
        self.problem_format = problem_format
        self.stop_string = stop_string

        prefixes = []
        if self.task_description is not None:
            prefixes.append(self.task_description)
        if self.is_few_shot:
            prefixes.append(self.cot_examples)
        if len(prefixes) > 0:
            self.task_prefix = self.sep.join(prefixes)
        else:
            self.task_prefix = None

        if reset:
            self.reset()

    def build_query_str(
        self,
    ) -> str:
        """
        Builds a formatted query string for the problem.

        Combines task description, examples (if few-shot), and the problem input
        into a formatted string ready for the language model.

        Returns:
            str: Formatted query string

        Example:
            >>> query = env.build_query_str()
        """
        ret = ""
        if self.task_description:
            ret += self.task_description + "\n"
        if self.is_few_shot:
            ret += self.cot_examples + "\n"
        ret += self.problem_format.format(question=self.problem.question)
        return ret

    def reset(self) -> Tuple[Tuple[str, Optional[str]], float, bool, bool, dict]:
        """
        Resets the environment to its initial state.

        Rebuilds the query prompt for the problem and clears the action history.

        Returns:
           Tuple containing:
            - Current state (Tuple[str, Optional[str]])
            - Reward (float)
            - Whether episode is terminated (bool)
            - Whether episode is truncated (bool)
            - Additional info dictionary

        """
        self.image = self.problem.image
        self.action_history = [
            self.build_query_str()
        ]
        state = self.get_state()
        return state, 0., False, False, {}
        
    def step(self, action: str) -> Tuple[Tuple[str, Optional[str]], float, bool, bool, dict]:
        """
        Takes a step in the environment by applying an action.

        Args:
            action: The action to take (next solution step)

        Returns:
            Tuple containing:
            - Current state (Tuple[str, Optional[str]])
            - Reward (float)
            - Whether episode is terminated (bool)
            - Whether episode is truncated (bool)
            - Additional info dictionary
        """
        self.action_history.append(action)
        state = self.get_state()
        reward = self.get_reward()
        terminated, truncated, info = self.get_done_and_info()
        return state, reward, terminated, truncated, info

    def get_state(self) -> Tuple[str, Optional[str]]:
        """
        Returns the current state of the environment.

        Returns:
            Tuple containing:
            - String representation of current state (action history)
            - Image associated with the problem (if applicable) encoded as a base64 image
        """
        return "\n".join(self.action_history) + "\n", self.image

    def get_done_and_info(self) -> Tuple[bool, bool, dict]:
        """
        Determines if the current episode is complete and provides additional information.

        Returns:
            Tuple containing:
            - Whether the episode is terminated (reached stop condition)
            - Whether the episode is truncated (reached max length)
            - Info dictionary with additional details (including winner)

        Note:
            winner codes:
            0: ongoing
            1: successful completion
            2: unsuccessful completion
        """
        info = {"winner": 0}
        terminated = self.stop_string in self.action_history[-1]
        max_steps = self.max_steps + (2 if self.task_prefix is not None else 1)
        
        truncated = len(self.action_history) >= max_steps
        assert len(self.action_history) <= max_steps, (
            f"action history length: {len(self.action_history)}, "
            f"max length: {max_steps}"
        )

        if terminated or truncated:
            info["winner"] = self.is_valid()
        
        return terminated, truncated, info

    @abstractmethod
    def is_valid(self) -> Literal[1, 2]:
        """Return 1 if the final answer is correct, 2 otherwise."""
        pass

    @abstractmethod
    def get_reward(self) -> float:
        """Return the reward for the current (typically terminal) state."""
        pass

An example environment that we can implement is for the Grade School Math (GSM8K) dataset.

import re

class GSM8KEnv(ChainOfThoughtEnv):
    def is_valid(self) -> Literal[1, 2]:
        """Extract the answer using regex, and then check it against the ground truth"""
        answer_regex = re.compile(r"The answer is (\-?[0-9\.\,]+)")
        match = answer_regex.search(self.action_history[-1])

        if match:
            match_str = match.group(1).strip().replace(",", "") or "inf"
            ground_truth = float(self.problem.answer.replace(",", ""))
            is_correct = abs(float(match_str) - ground_truth) < 1e-8
            return 1 if is_correct else 2
        else:
            return 2

    def get_reward(self) -> float:
        """Extract the answer using regex, and then check it against the ground truth"""
        answer_regex = re.compile(r"The answer is (\-?[0-9\.\,]+)")
        match = answer_regex.search(self.action_history[-1])

        if match:
            match_str = match.group(1).strip().replace(",", "") or "inf"
            ground_truth = float(self.problem.answer.replace(",", ""))
            is_correct = abs(float(match_str) - ground_truth) < 1e-8
            return 1. if is_correct else 0.
        else:
            return 0.
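As mentioned earlier, the reward doesn’t have to be purely binary. As one possible sketch (not part of the environment above), a subclass could subtract a small per-step penalty from the correctness reward to discourage unnecessarily long chains; the penalty coefficient here is an arbitrary choice.

class LengthPenalizedGSM8KEnv(GSM8KEnv):
    """A sketch of a shaped reward: correctness minus a small per-step penalty."""
    step_penalty: float = 0.01  # arbitrary coefficient; tune for your setup

    def get_reward(self) -> float:
        base_reward = super().get_reward()
        # action_history[0] holds the prompt, so the number of reasoning steps is len - 1.
        num_steps = max(len(self.action_history) - 1, 0)
        return base_reward - self.step_penalty * num_steps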

To work with the above environment in a typical RL loop, we can use the following example code. Note that we have to extract the numerical answer from the dataset, since Grade School Math already includes reasoning in its answers, with the final answer appearing after the string "#### ".

from datasets import load_dataset

llm_gen = LLMGen("llama3.2")
dataset = load_dataset("openai/gsm8k", "main")

question = dataset["test"]["question"][42]
answer = dataset["test"]["answer"][42].split("#### ")[-1]

env = GSM8KEnv(
    problem=Problem(question=question, answer=answer),
    llm_gen=llm_gen,
    task_description="Answer the following math questions",
    problem_format="Question: {question}\nAnswer: Let's think step by step",
    cot_examples=(
        "Question: There are 15 trees in the grove. Grove workers will plant trees "
        "in the grove today. After they are done, there will be 21 trees. How many "
        "trees did the grove workers plant today?\nAnswer: Let's think step by step"
        "\nThere are 15 trees originally.\nThen there were 21 trees after some more "
        "were planted.\nSo there must have been 21 - 15 = 6.\nThe answer is 6\n\n"
        "Question: If there are 3 cars in the parking lot and 2 more cars arrive, "
        "how many cars are in the parking lot?\nAnswer: Let's think step by step\n"
        "There are originally 3 cars.\n2 more cars arrive.\n3 + 2 = 5.\nThe answer "
        "is 5\n\nQuestion: Leah had 32 chocolates and her sister had 42. If they ate "
        "35, how many pieces do they have left in total?\nAnswer: Let's think step "
        "by step\nOriginally, Leah had 32 chocolates.\nHer sister had 42.\nSo in "
        "total they had 32 + 42 = 74.\nAfter eating 35, they had 74 - 35 = 39.\nThe "
        "answer is 39\n\nQuestion: Jason had 20 lollipops. He gave Denny some "
        "lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to "
        "Denny?\nAnswer: Let's think step by step\nJason started with 20 lollipops."
        "\nThen he had 12 after giving some to Denny.\nSo he gave Denny 20 - 12 = 8."
        "\nThe answer is 8\n\nQuestion: Shawn has five toys. For Christmas, he got "
        "two toys each from his mom and dad. How many toys does he have now?\nAnswer: "
        "Let's think step by step\nShawn started with 5 toys.\nIf he got 2 toys each "
        "from his mom and dad, then that is 4 more toys.\n5 + 4 = 9.\nThe answer is 9"
        "\n\nQuestion: There were nine computers in the server room. Five more "
        "computers were installed each day, from monday to thursday. How many computers"
        " are now in the server room?\nAnswer: Let's think step by step\nThere were "
        "originally 9 computers.\nFor each of 4 days, 5 more computers were added.\n"
        "So 5 * 4 = 20 computers were added.\n9 + 20 is 29.\nThe answer is 29"
    ),
    stop_string="The answer is",
)

To actually execute the environment, we can run the following code. If the LLM gets the correct answer, we should see a reward of one. Otherwise, the reward will be zero.

state, reward, terminated, truncated, info = env.reset()

while True:
    # Generate the next reasoning step from the current textual state.
    response = llm_gen(state[0]).response
    state, reward, terminated, truncated, info = env.step(response)
    if terminated or truncated:
        print(state[0], reward)
        break
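From here, you can wrap this rollout in a small helper and, for instance, count how many of the first few test problems the model solves. The snippet below is a rough sketch along those lines: reusing the env object by reassigning its problem attribute before each reset is just a convenience here, and the choice of 10 problems is arbitrary.

def rollout(problem: Problem) -> float:
    """Run one rollout for the given problem and return the final reward."""
    env.problem = problem  # reuse the environment built above
    state, reward, terminated, truncated, info = env.reset()
    while not (terminated or truncated):
        state, reward, terminated, truncated, info = env.step(llm_gen(state[0]).response)
    return reward

# Count how many of the first 10 test problems are solved (purely illustrative).
rewards = []
for i in range(10):
    q = dataset["test"]["question"][i]
    a = dataset["test"]["answer"][i].split("#### ")[-1]
    rewards.append(rollout(Problem(question=q, answer=a)))
print(f"Solved {sum(rewards):.0f} / {len(rewards)} problems")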