Machine Learning

Training Models With Imbalanced Datasets — Machine Learning Classification

In this article we will review what approach to take when facing the issue of imbalanced classes.

Dataset

This is the most important part, because it is where the problem lives. When one or more classes are under-represented, the model sees few examples of them during training and becomes worse at predicting them. Models try to generalize, and the rare classes are where their predictions are most likely to miss.

Why would anyone care about an issue that affects less than 1% of the sample? Because those few cases can carry a lot of value, and fraud is a very good example. Imbalance is not limited to such extreme cases, though. It appears whenever one category is poorly represented compared to the others.

The first approach is to reduce the imbalance by rebalancing the classes. Two well-known techniques are under-sampling (dropping samples from the majority class) and over-sampling (adding samples to the minority class). One key call-out: these techniques should only be applied to the training set. Never validate your results on a modified dataset, because doing so can produce misleading metrics.
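Here is a minimal sketch of both ideas with plain pandas, run on a small made-up DataFrame. The column name `Class` mirrors the dataset used below. The toy data and the helper names `undersample` and `oversample` are hypothetical. The data is split first, and only the training part is resampled.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 950 rows of class 0, 50 rows of class 1
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "feature": rng.normal(size=1000),
    "Class": np.array([0] * 950 + [1] * 50),
})

# Split first; the test part is never resampled
train_toy, test_toy = train_test_split(toy, test_size=0.2, random_state=0, stratify=toy["Class"])


def undersample(df, ratio, seed=0):
    """Hypothetical helper: keep every class 1 row and `ratio` class 0 rows per class 1 row."""
    minority = df[df["Class"] == 1]
    majority = df[df["Class"] == 0].sample(n=len(minority) * ratio, random_state=seed)
    return pd.concat([minority, majority])


def oversample(df, seed=0):
    """Hypothetical helper: duplicate class 1 rows (with replacement) until both classes match."""
    minority = df[df["Class"] == 1]
    majority = df[df["Class"] == 0]
    extra = minority.sample(n=len(majority) - len(minority), replace=True, random_state=seed)
    return pd.concat([majority, minority, extra])


train_under = undersample(train_toy, ratio=5)
train_over = oversample(train_toy)
print(train_under["Class"].value_counts().to_dict())
print(train_over["Class"].value_counts().to_dict())
```

The training split keeps 40 class 1 rows. Under-sampling with a ratio of 5 therefore leaves 200 class 0 rows. Over-sampling brings class 1 up to the 760 class 0 rows. `test_toy` keeps its original 190/10 split.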

The first step is to set up a benchmark. I have chosen this public credit card fraud dataset:

https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv

import pandas as pd

data = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')

Initially, the class proportions are:

Class
0    0.998273
1    0.001727
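Put differently, there are roughly 578 legitimate transactions for every fraudulent one. That is in the same ballpark as the upper bound of 577 that the Optuna search further down uses for the under-sampling ratio. The arithmetic, using the proportions above:

```python
# Class proportions reported above
p_majority = 0.998273
p_minority = 0.001727

# Majority rows per minority row
imbalance_ratio = p_majority / p_minority
print(round(imbalance_ratio))  # 578
```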

We set up a benchmark model to test our results against, with accuracy as the metric. The model is scikit-learn's GradientBoostingClassifier. It is trained on a stratified 80% training split and evaluated on the remaining 20%, so both splits keep the original class proportions.

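from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score

# Stratified 80/20 split: both parts keep the original class proportions
train, test = train_test_split(data, test_size=0.2, random_state=42, stratify=data['Class'])
model = GradientBoostingClassifier()
model.fit(train.drop('Class', axis=1), train['Class'])
accuracy_score(test['Class'], model.predict(test.drop('Class', axis=1)))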
0.9983146659176293

The first thing that stands out is that the model does seem to learn something. Its accuracy is slightly higher than the share of the majority class (0.998273), which is the accuracy a model that always predicts class 0 would get. The margin is thin, but it is a promising first output.
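One way to make that comparison explicit is scikit-learn's DummyClassifier with strategy="most_frequent", which always predicts the majority class. This sketch uses synthetic data from make_classification (about 1% positives) as a stand-in for the credit card dataset, so its numbers will not match the ones above.

```python
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in: roughly 1% positives, no label noise
X, y = make_classification(n_samples=5000, weights=[0.99], flip_y=0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Always predicts the most frequent class seen during fit (class 0 here)
baseline = DummyClassifier(strategy="most_frequent").fit(X_train, y_train)
baseline_accuracy = accuracy_score(y_test, baseline.predict(X_test))

# The baseline accuracy equals the share of class 0 in the test set
print(baseline_accuracy, (y_test == 0).mean())
```

Any real model should beat `baseline_accuracy` before its accuracy means much on data this skewed.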

Now let's dig into how we can improve these results, starting with under-sampling. The under-sampling ratio behaves like a hyper-parameter, though not one of the model itself. There is a sweet spot: remove too many or too few majority samples and the model will not improve. In our case, the split below is applied to the training sample only, and the test set remains unchanged.

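n_fraud = len(train[train['Class']==1])
# Keep every class 1 row plus a random sample of 200 class 0 rows per class 1 row.
# The result goes into a new DataFrame so that `train` keeps the full training split.
train_under = pd.concat([
    train[train['Class']==1],
    train[train['Class']==0].sample(n=n_fraud*200, random_state=42),
])
train_under['Class'].value_counts(normalize=True)
Class
0    0.995025
1    0.004975

Note how the class proportions have changed. We kept all class 1 rows and a random sample of class 0 rows 200 times the size of class 1. Re-training the model on this sample shows an improvement over the benchmark.

model.fit(train_under.drop('Class', axis=1), train_under['Class'])
accuracy_score(test['Class'], model.predict(test.drop('Class', axis=1)))
0.9984902215512096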
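The proportions follow directly from the ratio: keeping r class 0 rows per class 1 row gives class 1 a share of 1/(r+1). Here is a quick check for r = 200. The helper name `class_shares` is just illustrative.

```python
def class_shares(ratio):
    """Return (class 0 share, class 1 share) after keeping `ratio` class 0 rows per class 1 row."""
    total = ratio + 1
    return ratio / total, 1 / total


share_0, share_1 = class_shares(200)
print(round(share_0, 6), round(share_1, 6))  # 0.995025 0.004975
```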

We can go further and search for the best class ratio. We use Optuna for this search, but any other search method would work.

import optuna

def objective(trial):
    # Upper bound: every class 0 row in the training split per class 1 row
    max_ratio = int(len(train[train['Class']==0]) / len(train[train['Class']==1]))
    undersampling = trial.suggest_int('undersampling', 1, max_ratio)
    # Keep all class 1 rows plus `undersampling` class 0 rows per class 1 row
    train_sample = pd.concat([
        train[train['Class']==1],
        train[train['Class']==0].sample(n=int(len(train[train['Class']==1])*undersampling), random_state=42),
    ])
    model = GradientBoostingClassifier()
    model.fit(train_sample.drop('Class', axis=1), train_sample['Class'])
    # Each trial is scored on the test split
    return accuracy_score(test['Class'], model.predict(test.drop('Class', axis=1)))

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=10)
study.best_trial
FrozenTrial(number=5, state=<TrialState.COMPLETE: 1>, values=[0.999420666409185], datetime_start=datetime.datetime(2025, 12, 7, 18, 52, 39, 10131), datetime_complete=datetime.datetime(2025, 12, 7, 18, 54, 58, 845295), params={'undersampling': 475}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'undersampling': IntDistribution(high=577, log=False, low=1, step=1)}, trial_id=5, value=None)

Optuna is a framework for automating hyperparameter tuning. Instead of relying on manual trial and error, it explores the parameter space systematically. It does this with samplers such as the Tree-structured Parzen Estimator (TPE, its default for single-objective studies), and it can use pruners to stop unpromising trials early. It integrates with PyTorch, TensorFlow, scikit-learn, XGBoost, LightGBM, and more.

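It is worth highlighting that the best trial reached an accuracy of 99.94%, with an under-sampling ratio of 475, up from the 99.83% benchmark. In error terms, the share of misclassified test cases drops from roughly 0.17% to 0.06%. Keep in mind that the ratio was selected using the same test set that reports the score, so this figure is likely somewhat optimistic.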
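A common way to avoid tuning on the test set is to split a validation set out of the training data, choose the ratio on it, and use the test set only once at the end. Below is a minimal sketch of that idea. It swaps Optuna for a plain loop over a few candidate ratios and runs on synthetic data from make_classification. The candidate ratios, the hypothetical `undersample` helper, and the data are all assumptions for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the fraud data: about 2% positives
X, y = make_classification(n_samples=4000, weights=[0.98], flip_y=0, random_state=0)

# train / validation / test; the test set is used only once, at the very end
X_rest, X_test, y_rest, y_test = train_test_split(X, y, test_size=0.2, random_state=0, stratify=y)
X_train, X_val, y_train, y_val = train_test_split(
    X_rest, y_rest, test_size=0.25, random_state=0, stratify=y_rest
)


def undersample(X, y, ratio, seed=0):
    """Hypothetical helper: keep all positives and `ratio` negatives per positive (capped)."""
    rng = np.random.default_rng(seed)
    pos = np.flatnonzero(y == 1)
    neg = np.flatnonzero(y == 0)
    keep_neg = rng.choice(neg, size=min(len(neg), len(pos) * ratio), replace=False)
    idx = np.concatenate([pos, keep_neg])
    return X[idx], y[idx]


# Score each candidate ratio on the validation split only
scores = {}
for ratio in [1, 5, 20, 50]:
    X_s, y_s = undersample(X_train, y_train, ratio)
    model = GradientBoostingClassifier(n_estimators=50, random_state=0).fit(X_s, y_s)
    scores[ratio] = accuracy_score(y_val, model.predict(X_val))

best_ratio = max(scores, key=scores.get)

# Single final evaluation on the untouched test set
X_s, y_s = undersample(X_train, y_train, best_ratio)
final_model = GradientBoostingClassifier(n_estimators=50, random_state=0).fit(X_s, y_s)
test_accuracy = accuracy_score(y_test, final_model.predict(X_test))
print(best_ratio, scores, test_accuracy)
```

The same pattern works with Optuna: the objective would return the validation score instead of the test score.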

Next, we can change the approach: instead of removing data points, we can create new ones. RandomOverSampler from imbalanced-learn randomly picks samples from the minority class and duplicates them until the classes are balanced, which is its default behaviour. It samples with replacement, so the same row can appear several times. No new information is added. Unlike synthetic methods such as SMOTE, random oversampling does not create new data points; it only replicates existing ones.
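To see what random oversampling does, here is a hand-rolled NumPy equivalent on made-up arrays. It is not imbalanced-learn's own code. Minority rows are drawn with replacement until the classes are the same size, so every added row is an exact copy of an existing one.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: 8 majority rows, 2 minority rows, one feature each
X = np.arange(10, dtype=float).reshape(-1, 1)
y = np.array([0] * 8 + [1] * 2)

minority_idx = np.flatnonzero(y == 1)
n_needed = (y == 0).sum() - len(minority_idx)

# Sample minority indices with replacement and append the copies
extra_idx = rng.choice(minority_idx, size=n_needed, replace=True)
X_res = np.vstack([X, X[extra_idx]])
y_res = np.concatenate([y, y[extra_idx]])

print(np.bincount(y_res))  # [8 8]
# Every minority row in the result already existed in the original data
print(set(X_res[y_res == 1].ravel()) <= set(X[y == 1].ravel()))  # True
```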

from imblearn.over_sampling import RandomOverSampler

# By default, minority rows are duplicated until both classes have the same count
ros = RandomOverSampler(random_state=0)
train_resampled_x, train_resampled_y = ros.fit_resample(train.drop('Class', axis=1), train['Class'])
model = GradientBoostingClassifier()
model.fit(train_resampled_x, train_resampled_y)
accuracy_score(test['Class'], model.predict(test.drop('Class', axis=1)))
0.9923984410659739

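Here, however, accuracy falls to about 99.24%, below the 99.83% benchmark. This is most likely because a model trained on a balanced sample predicts class 1 much more often, which adds false positives on a test set that is almost entirely class 0. Whether those extra alerts are worth it depends on how many more fraud cases are caught, and accuracy does not show that. The result still highlights how strongly the composition of the training data shapes the model's output.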
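The toy example below shows why accuracy alone cannot judge that trade. It uses hand-made labels, not results from this dataset, and compares two hypothetical models on 1,000 cases with 10 positives. The model with lower accuracy finds far more of the positives.

```python
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score

# 1,000 cases; the first 10 are positives
y_true = np.zeros(1000, dtype=int)
y_true[:10] = 1

# Model A: flags only 2 positives, no false alarms
pred_a = np.zeros(1000, dtype=int)
pred_a[:2] = 1

# Model B: flags 8 positives plus 20 false alarms
pred_b = np.zeros(1000, dtype=int)
pred_b[:8] = 1
pred_b[10:30] = 1

for name, pred in [("A", pred_a), ("B", pred_b)]:
    print(
        name,
        accuracy_score(y_true, pred),
        recall_score(y_true, pred),
        precision_score(y_true, pred),
    )
# A: accuracy 0.992, recall 0.2, precision 1.0
# B: accuracy 0.978, recall 0.8, precision 8/28 (about 0.286)
```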

Model

In this section, we focus on the second approach. Many machine learning algorithms assume that all classes are equally important, which can lead to poor performance on minority classes. To counter this, we can assign class weights during training. Class weights tell the model to "pay more attention" to underrepresented classes by increasing their contribution to the loss function.

Why class weights matter: Without weighting, the model tends to optimize for overall accuracy, which can be misleading in imbalanced scenarios. For example, if 99% of your data belongs to one class, a naive model could achieve 99% accuracy by predicting the majority class every time — while completely ignoring the minority class. Class weights help mitigate this by penalizing misclassification of minority samples more heavily.
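Concretely, scikit-learn's "balanced" option gives each class c the weight n_samples / (n_classes * n_c), where n_c is the number of samples in that class. Here is a small check of that formula against compute_class_weight, on made-up labels.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Made-up labels: 990 negatives, 10 positives
y = np.array([0] * 990 + [1] * 10)
classes = np.unique(y)

# Manual "balanced" formula: n_samples / (n_classes * count of each class)
manual = len(y) / (len(classes) * np.bincount(y))

sk_weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)
print(manual)
print(np.allclose(manual, sk_weights))  # True
```

The rare class ends up with a weight of 50, about 100 times that of the common class, which mirrors the 99:1 imbalance.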

GradientBoostingClassifier has no class_weight argument, so we pass per-row weights through the sample_weight parameter of .fit(). First, we compute the weights from the class distribution of our training sample.

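import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# "balanced" weights: n_samples / (n_classes * number of samples in each class)
classes = np.unique(train['Class'])
class_weights = compute_class_weight(class_weight="balanced", classes=classes, y=train['Class'])
# Give each training row the weight of its class
weight_by_class = dict(zip(classes, class_weights))
sample_weight = train['Class'].map(weight_by_class).to_numpy()
model = GradientBoostingClassifier()
model.fit(train.drop('Class', axis=1), train['Class'], sample_weight=sample_weight)
accuracy_score(test['Class'], model.predict(test.drop('Class', axis=1)))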
0.99252133000948

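Accuracy here is about 99.25%, again below the 99.83% benchmark, for the same reason as with oversampling. Weighting the minority class pushes the model toward predicting it more often, trading some false positives for potentially more caught fraud cases. Judging that trade requires a minority-class metric such as recall or precision.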
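scikit-learn also ships compute_sample_weight, which produces the per-row weights in one call. Here is a quick check, on made-up labels, that it matches mapping per-class weights onto each row.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight, compute_sample_weight

y = np.array([0, 0, 0, 0, 1])
classes = np.unique(y)

# Per-class weights, then mapped onto each row
class_weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)
weight_by_class = dict(zip(classes, class_weights))
via_mapping = np.array([weight_by_class[c] for c in y])

# One-call alternative
via_helper = compute_sample_weight(class_weight="balanced", y=y)

print(via_helper)
print(np.allclose(via_mapping, via_helper))  # True
```

Both give 0.625 to each class 0 row (5 / (2 * 4)) and 2.5 to the class 1 row (5 / (2 * 1)).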

Conclusions

Working with imbalanced datasets is not just a technical nuisance — it directly impacts the reliability of machine learning models in high-stakes domains such as fraud detection, medical diagnosis, or security. Through this exploration, several key insights emerge:

  • Benchmarking matters: Establishing a baseline model and metric is essential before applying any balancing technique. It ensures improvements are measured against a meaningful reference.
  • Resampling strategies: Both under-sampling and over-sampling can shift class distributions to help the model learn minority classes. However, they must be applied only to the training set to avoid misleading validation results.
  • Hyperparameter tuning: Automated search tools like Optuna can uncover better sampling ratios. Here, the search found a ratio that outperformed the manual choice of 200.
  • Model-side adjustments: Techniques such as class weights allow the algorithm itself to account for imbalance, complementing dataset-level interventions.
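  • Accuracy gains: Tuning the under-sampling ratio raised accuracy from ~99.83% to ~99.94%, while oversampling and class weights lowered it slightly (~99.24% and ~99.25%). At this level of imbalance, accuracy alone cannot show whether a model catches more rare but critical cases, so minority-class metrics such as recall and precision are needed for that.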

Ultimately, handling imbalance requires a two-pronged approach: reshaping the dataset and guiding the model to pay attention to minority classes. By combining resampling, hyperparameter optimization, and class weighting, practitioners can build models that are not only accurate but also fairer and more robust in real-world scenarios.

GitHub Repository

You can find the full code, experiments, and reproducible notebooks here:

https://github.com/gasparcartasso/imbalance_training