ayushh0110/toolforge


🔧 ToolForge: Fine-Tuning Small LLMs for Autonomous Tool Routing

Teaching a model to become the router — replacing hand-crafted heuristics with learned tool-selection behavior via QLoRA distillation.

Python 3.12 · PyTorch · HuggingFace · W&B

📖 Read the blog post: From Heuristics to Fine-Tuning


🎯 Problem

Autonomous AI agents need to decide which tool to call for every user query. Most implementations rely on:

  • ❌ Regex/keyword matching (brittle, unmaintainable)
  • ❌ Zero-shot LLM prompting (expensive, slow, inconsistent)
  • ❌ Embedding similarity (loses argument extraction)

ToolForge solves this by fine-tuning a small LLM (7-8B) via QLoRA on synthetic tool-call traces — replacing a heuristic router with a learned one. On a clean, hand-written, non-circular test set with format-agnostic grading, fine-tuning improves tool-routing accuracy from 75.0% (base Qwen2.5-7B) to 83.3% (+8.3 pp) — and the gain comes from better routing decisions, not output formatting (see Honest Evaluation).


📊 Results

Ablation Study (4 runs, W&B tracked)

| Run | Base Model | LoRA r | LR | Test Accuracy | Eval Loss |
|---|---|---|---|---|---|
| 🥇 qwen7b-r64 | Qwen2.5-7B-Instruct | 64 | 2e-4 | 86.2% | 0.141 |
| 🥈 mistral-r64 | Mistral-7B-Instruct-v0.3 | 64 | 2e-4 | 82.8% | 0.670 |
| 🥉 mistral-r16 | Mistral-7B-Instruct-v0.3 | 16 | 2e-4 | 81.9% | 0.648 |
| ❌ mistral-lr5e4 | Mistral-7B-Instruct-v0.3 | 64 | 5e-4 | 60.3% | 0.730 |

Note: These accuracies are measured on a held-out split of the same synthetic distribution used for training (labels partly from the Gemini teacher). Because the teacher labels both train and test, this number is partly circular and should be read as an internal ablation comparing hyperparameters — not as an unbiased estimate of routing quality. For the unbiased, non-circular measurement, see Honest Evaluation below.


🔬 Honest Evaluation: Base vs Fine-Tuned

The ablation above answers "which hyperparameters are best?" but not "did fine-tuning actually add anything over the base model, or is Qwen2.5-7B already good at this?" — and it's measured on teacher-labeled data, which is circular.

To answer honestly, I built a separate, hand-written, non-circular test set (36 realistic, indirectly-phrased queries, labeled by hand — no teacher involved) and compared the base model against the fine-tuned adapter on identical inputs.

Fair grading. The fine-tuned model is trained to emit a specific <tool_calls>[...]</tool_calls> format. To avoid penalizing the base model purely for using a different format, grading is format-agnostic: a prediction counts if the correct tool is identified in any recognizable format (the trained format, Qwen-native <tool_call>, raw JSON, or function-call style).
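
A minimal sketch of what format-agnostic grading could look like, not the repository's actual evaluate.py. The JSON keys `name` and `arguments`, and the helper names `extract_tool_names` and `is_correct`, are assumptions made for illustration; only the four accepted formats and the tool names come from this README.

```python
import json
import re

# The nine tool names listed in this README.
KNOWN_TOOLS = {
    "web_search", "calculator", "weather", "wikipedia", "datetime",
    "dictionary", "translate", "unit_converter", "web_reader",
}

# Matches both <tool_calls>...</tool_calls> and Qwen-native <tool_call>...</tool_call>.
TAGGED = re.compile(r"<tool_calls?>(.*?)</tool_calls?>", re.DOTALL)
# Matches function-call style such as calculator(expression="2+2").
FUNC_CALL = re.compile(r"\b([a-z_]+)\s*\(")


def _names_from_json(payload: str) -> set[str]:
    """Parse a JSON object or list of objects and collect their "name" fields."""
    try:
        data = json.loads(payload)
    except json.JSONDecodeError:
        return set()
    calls = data if isinstance(data, list) else [data]
    return {
        c["name"] for c in calls
        if isinstance(c, dict) and isinstance(c.get("name"), str)
    }


def extract_tool_names(text: str) -> set[str]:
    """Return the known tools called in `text`, whatever format was used."""
    names: set[str] = set()
    tagged = TAGGED.findall(text)
    for payload in tagged:
        names |= _names_from_json(payload.strip())
    if not tagged:
        names |= _names_from_json(text.strip())  # raw JSON
        names |= set(FUNC_CALL.findall(text))    # function-call style
    return names & KNOWN_TOOLS


def is_correct(prediction: str, expected_tools: list[str]) -> bool:
    """A prediction counts if it names exactly the expected tools (none for no_tool)."""
    return extract_tool_names(prediction) == set(expected_tools)
```

Under this sketch a plain conversational reply yields an empty set, so it is graded correct only when the expected label is no_tool. The function-call regex is deliberately loose and could match prose such as "weather (today)"; a stricter grader would need a tighter pattern.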

| Model | Routing accuracy (format-agnostic) | Strict-format accuracy | Avg latency* |
|---|---|---|---|
| Base Qwen2.5-7B-Instruct | 75.0% | 75.0% | 3,457 ms |
| Fine-tuned (QLoRA r=64) | 83.3% | 83.3% | 5,322 ms |
| Gain | +8.3 pp | +8.3 pp | |

* Latency was measured with unbatched HuggingFace generate() on a single T4. This is not a production serving setup, so the two figures are not a fair latency comparison.
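
To put the percentages in terms of queries, the short calculation below converts the reported accuracies back into counts on the 36-query set. It assumes every query is scored once as simply right or wrong; `correct_count` is a hypothetical helper name.

```python
TOTAL_QUERIES = 36  # size of the hand-written test set


def correct_count(accuracy_pct: float, total: int = TOTAL_QUERIES) -> int:
    """Convert a reported accuracy percentage into a number of correct queries."""
    return round(accuracy_pct / 100 * total)


base_correct = correct_count(75.0)    # 27 of 36
tuned_correct = correct_count(83.3)   # 30 of 36
gain_queries = tuned_correct - base_correct
gain_pp = round(100 * gain_queries / TOTAL_QUERIES, 1)

print(base_correct, tuned_correct, gain_queries, gain_pp)
```

On a test set of this size, the +8.3 pp gain corresponds to 3 additional correctly routed queries.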

What this shows

  • The gain is real routing, not formatting. Strict and lenient scores are identical for both models — base Qwen already emits parseable formats. So fine-tuning improved which tool the model picks, not how it writes the call.
  • Gains concentrate on disambiguation: web_search vs wikipedia (33% → 100%), unit_converter vs calculator (67% → 100%), and multi-tool selection (67% → 100%).
  • Honest tradeoff: fine-tuning slightly increases over-triggering on no-tool conversational queries (e.g. "what is 2 plus 2", "I'm bored") — a precision/recall cost of teaching the model to reach for tools. This is a known side effect of tool-routing fine-tuning, reported here rather than hidden.

Known limitations

  • Fixed tool set. ToolForge learns 9 specific tools baked into the prompt; adding a tool requires retraining. It is a specialist router, trading the generality of frontier function-calling for a small, cheap, self-hostable model — the right tradeoff when the tool set is known and fixed.
  • Latency above is not a serving benchmark (unbatched HF generate). A vLLM/batched setup is the correct way to measure production latency; that comparison is future work.

Per-Tool Accuracy (Best Model — Qwen2.5-7B)

| Tool | Accuracy |
|---|---|
| datetime | 100% |
| unit_converter | 100% |
| web_reader | 100% |
| calculator | 94.1% |
| dictionary | 93.8% |
| weather | 92.3% |
| web_search | 91.7% |
| wikipedia | 86.7% |
| translate | 80.0% |
| multi_tool | 50.0% |
| no_tool | 41.7% |

Per-tool numbers are from the internal (synthetic) test split — see the circularity note above. The no_tool / multi_tool figures in particular are affected by teacher-label noise; the Honest Evaluation section is the unbiased measurement.
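
A per-category breakdown like the one above can be computed from graded predictions in a few lines. This is a sketch under assumptions: the `(category, is_correct)` record shape and the function name are hypothetical, and the counts in the usage example are illustrative values chosen to reproduce two of the reported percentages, not the real per-category test counts.

```python
from collections import defaultdict


def per_category_accuracy(records):
    """Compute accuracy (in percent, one decimal) for each category.

    `records` is an iterable of (category, is_correct) pairs.
    """
    totals = defaultdict(int)
    hits = defaultdict(int)
    for category, ok in records:
        totals[category] += 1
        hits[category] += int(ok)
    return {c: round(100 * hits[c] / totals[c], 1) for c in totals}


# Illustrative counts only: 16/17 correct for calculator, 5/12 for no_tool.
records = (
    [("calculator", True)] * 16 + [("calculator", False)]
    + [("no_tool", True)] * 5 + [("no_tool", False)] * 7
)
breakdown = per_category_accuracy(records)
print(breakdown)
```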

Key Findings

  • 7/9 tools above 90% on the internal test split: single-tool routing is close to production quality there, while no_tool and multi_tool lag behind
  • Adapter size has minimal impact: r=16 (81.9%) vs r=64 (82.8%) on Mistral-7B, so the smaller adapter is a reasonable choice when efficiency matters
  • Learning rate is critical: on the same Mistral r=64 setup, raising it from 2e-4 to 5e-4 drops test accuracy from 82.8% to 60.3%; 2e-4 is the sweet spot
  • Student surpasses teacher: on the hand-written test set, the fine-tuned model correctly routed several queries that the Gemini teacher would mislabel (e.g. current-events → web_search rather than wikipedia), confirmed by manual review of disagreements

🏗️ Architecture

┌─────────────────────────────────────────────────────────────┐
│                    ToolForge Pipeline                        │
│                                                             │
│  Phase 1: Data Generation                                   │
│  ┌──────────────┐   ┌──────────────┐   ┌────────────────┐  │
│  │   Template    │ + │   Gemini     │ → │  1,173 labeled │  │
│  │  Generator    │   │  Teacher     │   │   examples     │  │
│  │  (498 seed)   │   │  (679 dist.) │   │  (train/val/   │  │
│  │              │   │  flash+lite  │   │   test/hard)   │  │
│  └──────────────┘   └──────────────┘   └────────────────┘  │
│                                                             │
│  Phase 2: QLoRA Training                                    │
│  ┌──────────────┐   ┌──────────────┐   ┌────────────────┐  │
│  │  Base Model   │ + │  LoRA r=64   │ → │   Fine-tuned   │  │
│  │  (4-bit NF4)  │   │  Adapter     │   │   Router       │  │
│  │  Qwen/Mistral │   │  ~335-646 MB │   │  +8.3pp vs base│  │
│  └──────────────┘   └──────────────┘   └────────────────┘  │
│                                                             │
│  Phase 3: Evaluation                                        │
│  ┌──────────────┐   ┌──────────────┐   ┌────────────────┐  │
│  │  Tool Acc.    │   │  Per-Category│   │   W&B          │  │
│  │  Arg Match    │   │  Breakdown   │   │   Dashboard    │  │
│  │  Multi-Tool   │   │  Error       │   │   4 ablation   │  │
│  │  Latency      │   │  Analysis    │   │   runs         │  │
│  └──────────────┘   └──────────────┘   └────────────────┘  │
└─────────────────────────────────────────────────────────────┘

🛠️ The 9 Tools

The model learns to route queries to these tools (or respond directly):

| Tool | Description | Input Schema |
|---|---|---|
| web_search | Search the internet | {query: str} |
| calculator | Mathematical expressions | {expression: str} |
| weather | Current weather data | {location: str} |
| wikipedia | Encyclopedia lookup | {query: str} |
| datetime | Date/time operations | {action: str, ...} |
| dictionary | Word definitions | {word: str} |
| translate | Language translation | {text: str, to_lang: str} |
| unit_converter | Unit conversion | {value: float, from: str, to: str} |
| web_reader | Extract webpage content | {url: str} |

Plus no_tool (direct response) and multi_tool (chained calls).
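
The schemas in the table can be turned into a small argument checker, which is one way to score argument extraction alongside tool choice. This is a sketch, not the repository's code: the call shape `{"name": ..., "arguments": {...}}` and the names `TOOL_SCHEMAS` and `validate_call` are assumptions. For datetime, only the `action` field is checked because the remaining fields are elided in the table.

```python
# Required arguments per tool, taken from the schema table above.
TOOL_SCHEMAS = {
    "web_search": {"query": str},
    "calculator": {"expression": str},
    "weather": {"location": str},
    "wikipedia": {"query": str},
    "datetime": {"action": str},  # other fields are elided in the README
    "dictionary": {"word": str},
    "translate": {"text": str, "to_lang": str},
    "unit_converter": {"value": float, "from": str, "to": str},
    "web_reader": {"url": str},
}


def _type_ok(value, expected) -> bool:
    """Accept ints where a float is expected, but never booleans."""
    if expected is float:
        return isinstance(value, (int, float)) and not isinstance(value, bool)
    return isinstance(value, expected)


def validate_call(call: dict) -> list[str]:
    """Return a list of problems with a tool call; an empty list means valid."""
    name = call.get("name")
    schema = TOOL_SCHEMAS.get(name)
    if schema is None:
        return [f"unknown tool: {name!r}"]
    args = call.get("arguments", {})
    errors = []
    for key, expected in schema.items():
        if key not in args:
            errors.append(f"missing argument: {key}")
        elif not _type_ok(args[key], expected):
            errors.append(f"wrong type for {key}: expected {expected.__name__}")
    return errors
```

For example, `validate_call({"name": "translate", "arguments": {"text": "hi"}})` reports the missing `to_lang` argument, while a complete unit_converter call returns an empty list.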


📁 Project Structure

toolforge/
├── README.md
├── requirements.txt
├── configs/
│   ├── mistral_r64.yaml            # Default training config
│   ├── mistral_r16.yaml            # Small adapter ablation
│   └── llama_r64.yaml              # Alternative base model
├── data/
│   ├── synthetic/
│   │   ├── queries.json            # 1,894 generated queries
│   │   └── teacher.jsonl           # 679 Gemini-labeled examples
│   ├── train.jsonl                 # 918 training examples
│   ├── val.jsonl                   # 114 validation examples
│   ├── test.jsonl                  # 116 test examples
│   └── hard_test.jsonl             # 25 multi-tool edge cases
├── src/
│   ├── data_gen/
│   │   ├── template_generator.py   # Deterministic seed data (498 examples)
│   │   ├── teacher_labeler.py      # Gemini distillation with multi-key rotation
│   │   └── build_dataset.py        # Merge, dedup, split into train/val/test
│   ├── training/
│   │   ├── train.py                # QLoRA fine-tuning with SFTTrainer
│   │   └── merge.py                # LoRA → base model merge for deployment
│   └── eval/
│       └── evaluate.py             # Tool accuracy, F1, per-category breakdown
├── kaggle_ablation.py              # Self-contained Kaggle notebook with W&B
└── kaggle_notebook.py              # Single-run training notebook

🚀 Quick Start

1. Generate Training Data

# Install dependencies
pip install -r requirements.txt

# Generate seed queries + label with Gemini
# (requires API keys in .env — get free keys at aistudio.google.com)
python -m src.data_gen.teacher_labeler --n 2500

# Build final dataset splits
python -m src.data_gen.build_dataset

2. Train on Kaggle (Free GPU)

  1. Upload data/*.jsonl as a Kaggle Dataset
  2. Create a new notebook with GPU T4 enabled
  3. Paste cells from kaggle_ablation.py and run

# Or train locally with a GPU
python -m src.training.train --config configs/mistral_r64.yaml

3. Evaluate

python -m src.eval.evaluate \
    --checkpoint checkpoints/qwen7b-r64-lr2e4/final \
    --test-set data/test.jsonl

🔬 Data Pipeline

Two-Source Strategy

| Source | Count | Method | Quality |
|---|---|---|---|
| Template Generator | 498 | Deterministic rules, 100% clean labels | ⭐⭐⭐ |
| Gemini Distillation | 679 | gemini-2.5-flash + flash-lite function calling | ⭐⭐ |
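
The merge, dedup, and split step performed by build_dataset.py could look roughly like the sketch below. This is an assumption-laden illustration rather than the actual script: the `query` field name, the whitespace and case normalization, the 80/10/10 ratios, and the seed are all hypothetical. Template examples are placed first so that, when a query appears in both sources, the deterministic template label is the one kept.

```python
import random


def build_splits(template_examples, teacher_examples, seed=42, ratios=(0.8, 0.1, 0.1)):
    """Merge two example lists, drop duplicate queries, and split train/val/test."""
    seen = set()
    unique = []
    for example in template_examples + teacher_examples:
        # Normalize case and whitespace so near-identical queries collide.
        key = " ".join(example["query"].lower().split())
        if key not in seen:
            seen.add(key)
            unique.append(example)

    # Shuffle with a fixed seed so the split is reproducible.
    rng = random.Random(seed)
    rng.shuffle(unique)

    n_train = int(len(unique) * ratios[0])
    n_val = int(len(unique) * ratios[1])
    return {
        "train": unique[:n_train],
        "val": unique[n_train:n_train + n_val],
        "test": unique[n_train + n_val:],
    }
```

With 1,148 unique examples, these assumed ratios give 918 / 114 / 116, which matches the published train, val, and test sizes, although the real script may arrive at them differently.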

Crash-Proof Distillation

The teacher labeler (teacher_labeler.py) is designed for zero-cost, zero-data-loss operation:

  • Multi-key round-robin: 6 API keys × 2 models = 12 independent quota slots
  • Incremental saves: Every label is flushed to disk immediately
  • Smart retry logic: Distinguishes daily quota (mark key dead) vs transient 503 (exponential backoff)
  • Resume support: --resume flag continues from exactly where you left off

# Resume after quota exhaustion — add fresh keys to .env and re-run
python -m src.data_gen.teacher_labeler --resume
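
The four behaviors listed above can be sketched in one loop. This is an illustration under stated assumptions, not the code in teacher_labeler.py: the exception classes, the `call_teacher(query, key, model)` callable, the JSONL record shape, and the retry limit are all hypothetical, and the real Gemini client is deliberately left out so the sketch runs on its own.

```python
import itertools
import json
import os
import time


class DailyQuotaExceeded(Exception):
    """Hypothetical error: this key/model slot is exhausted for the day."""


class TransientServerError(Exception):
    """Hypothetical error: a temporary failure such as an HTTP 503."""


def label_queries(queries, api_keys, models, call_teacher, out_path,
                  max_retries=3, sleep=time.sleep):
    """Label queries round-robin over key/model slots; return False if quota ran out."""
    # Resume support: skip queries already present in the output file.
    done = set()
    if os.path.exists(out_path):
        with open(out_path, encoding="utf-8") as f:
            done = {json.loads(line)["query"] for line in f if line.strip()}

    # Every (key, model) pair is an independent quota slot.
    live_slots = list(itertools.product(api_keys, models))
    turn = 0
    with open(out_path, "a", encoding="utf-8") as out:
        for query in queries:
            if query in done:
                continue
            attempt = 0
            while live_slots:
                slot = live_slots[turn % len(live_slots)]
                turn += 1
                try:
                    label = call_teacher(query, *slot)
                except DailyQuotaExceeded:
                    live_slots.remove(slot)  # mark the slot dead for this run
                    continue
                except TransientServerError:
                    attempt += 1
                    if attempt > max_retries:
                        break                # give up on this query for now
                    sleep(2 ** attempt)      # exponential backoff
                    continue
                # Incremental save: flush each label to disk immediately.
                out.write(json.dumps({"query": query, "label": label}) + "\n")
                out.flush()
                break
            if not live_slots:
                return False  # all quota exhausted; re-run later to resume
    return True
```

Because each label is appended and flushed as soon as it arrives, re-running the function with the same output path picks up where the previous run stopped, which is the behavior the --resume flag describes.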

⚙️ Training Details

QLoRA Configuration

| Parameter | Value |
|---|---|
| Quantization | 4-bit NF4, double quantization |
| LoRA rank | 64 (best), 16 (ablation) |
| LoRA alpha | 128 |
| Target modules | q, k, v, o, gate, up, down projections |
| Optimizer | AdamW |
| Learning rate | 2e-4 (cosine schedule) |
| Batch size | 4 × 4 gradient accumulation = 16 effective |
| Epochs | 3 |
| Trainable params | ~335M / 7.2B (4.6%) |
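
As a rough guide, the table maps onto keyword arguments for `transformers.BitsAndBytesConfig` and `peft.LoraConfig` as sketched below. The `*_proj` module names are the ones used by the HuggingFace Mistral and Qwen2 implementations; `bias` and `task_type` are assumptions not stated in the table, and the exact settings in train.py may differ. The library imports are deferred into `build_configs` so the values can be inspected without a GPU environment.

```python
# Quantization settings: 4-bit NF4 with double quantization.
QUANT_KWARGS = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
}

# LoRA settings for the best run (r=64); the ablation uses r=16.
LORA_KWARGS = {
    "r": 64,
    "lora_alpha": 128,
    "target_modules": [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    "bias": "none",            # assumption: not stated in the table
    "task_type": "CAUSAL_LM",  # assumption: not stated in the table
}

# Standard LoRA scales the adapter update by alpha / r.
LORA_SCALING = LORA_KWARGS["lora_alpha"] / LORA_KWARGS["r"]


def build_configs():
    """Instantiate the library config objects (requires transformers and peft)."""
    from transformers import BitsAndBytesConfig
    from peft import LoraConfig
    return BitsAndBytesConfig(**QUANT_KWARGS), LoraConfig(**LORA_KWARGS)
```

With alpha fixed at 128, the r=64 run uses a scaling factor of 2.0; whether alpha was also kept at 128 for the r=16 ablation is not stated here.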

Training Curves (Mistral-7B, r=64)

Step   Train Loss   Eval Loss
  50     0.724        0.698
 100     0.581        0.687
 150     0.495        0.672
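
The logged steps (50, 100, 150) line up with the batch settings in the configuration table. The arithmetic below is a sketch assuming a single GPU and that a partial final batch still counts as a step; the exact total depends on how the trainer rounds the last accumulation window.

```python
import math


def training_schedule(num_examples, per_device_batch, grad_accum, epochs):
    """Return (effective batch size, optimizer steps per epoch, total steps)."""
    effective_batch = per_device_batch * grad_accum
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return effective_batch, steps_per_epoch, steps_per_epoch * epochs


# 918 training examples, batch 4 x 4 gradient accumulation, 3 epochs.
effective_batch, steps_per_epoch, total_steps = training_schedule(918, 4, 4, 3)
print(effective_batch, steps_per_epoch, total_steps)
```

Under these assumptions a full run is roughly 170 optimizer steps (174 with this rounding), so step 150 is already in the final epoch.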

📈 Experiment Tracking

All runs are logged to Weights & Biases under the toolforge project:

  • Training loss curves (per-step)
  • Validation loss at each checkpoint
  • Test accuracy and per-category breakdown
  • Hyperparameter comparison across ablation runs
  • System metrics (GPU utilization, memory)

🧠 Key Technical Decisions

Why QLoRA over full fine-tuning?

With 918 training examples and a 7B model, full fine-tuning would be highly prone to overfitting, and it would not fit in the memory of a 16GB T4. QLoRA keeps the quantized base weights frozen and trains only the ~335M adapter parameters (about 4.6% of the total), which is enough capacity for tool routing without destroying the base model's knowledge.

Why Gemini as teacher instead of GPT-4?

Cost. Gemini's free tier provides 20+ requests/day per model per key. With 6 keys × 2 models = 12 quota slots, I labeled 679 examples at zero cost. The multi-key rotation system makes this fully automated.

Why the student outperforms the teacher's labels

The model sees 27/30 correct labels for patterns like "define X → dictionary" and learns the dominant signal. The 3 inconsistent Gemini labels are too few to override it, so they act as noise that training largely ignores. This robustness to a minority of noisy labels is a well-known property of neural network training on noisy supervision.


📋 Requirements

  • Python 3.12+
  • PyTorch 2.x with CUDA
  • transformers, peft, trl, bitsandbytes
  • Google API keys (free tier) for data generation
  • GPU: T4 (16GB) minimum for training

📝 License

MIT
