Text Classification Using BERT: A Detailed Guide for Intermediate Practitioners
- Sairam Penjarla
- Jun 27, 2024
In this blog, we'll walk through a text classification project using BERT (Bidirectional Encoder Representations from Transformers) with a focus on the German language. We will explain the theory behind the model and then delve into the code, block by block, to provide a comprehensive understanding of the entire process.
GitHub URL:
It is highly recommended to go through the repo below to get the full version of the code, along with the necessary files such as requirements.txt and the CSV file.
Theory Behind BERT
What is BERT?
BERT, developed by Google, is a transformer-based model designed for natural language understanding. Unlike earlier language models that read text in a single direction (left to right), BERT uses bidirectional context: it looks at the entire sentence, both to the left and to the right of a word, to work out that word's meaning. This allowed BERT to achieve state-of-the-art results on many NLP tasks when it was released, including text classification.
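To make "bidirectional context" concrete, the sketch below contrasts which positions each token may look at in a left-to-right model versus a bidirectional one. It is a toy illustration in plain Python, not BERT's actual implementation; the function names and the example sentence are made up for this post.

```python
def causal_mask(n):
    """Left-to-right: token i may look at positions 0..i only."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


def bidirectional_mask(n):
    """BERT-style: every token may look at every position."""
    return [[1] * n for _ in range(n)]


tokens = ["Ich", "gehe", "zur", "Bank"]  # hypothetical example sentence
n = len(tokens)


def visible(mask, i):
    """Return the tokens that position i is allowed to see under the given mask."""
    return [tokens[j] for j in range(n) if mask[i][j]]


for i, tok in enumerate(tokens):
    print(tok, "| left-to-right sees:", visible(causal_mask(n), i),
          "| bidirectional sees:", visible(bidirectional_mask(n), i))
```

In the left-to-right case the first token sees only itself, while in the bidirectional case it already sees the whole sentence, which is what lets the model use the words on the right to disambiguate it.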
Text Classification with BERT
Text classification involves assigning predefined categories to text documents. BERT can be fine-tuned for specific tasks like sentiment analysis, spam detection, or in our case, identifying different types of abusive language in German tweets.
Code Explanation
Let's dive into the code, which is designed to classify German tweets into different categories of abusive language using a pre-trained BERT model.
1. Importing Required Libraries
Here, we'll import the necessary libraries for our project. These libraries include tools for data manipulation, evaluation metrics, and handling transformer models. This step ensures we have all the resources needed for data processing, model training, and evaluation.
import os
import torch
import tarfile
import pandas as pd
from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import train_test_split
from simpletransformers.classification import ClassificationModel
os, tarfile: Standard-library modules for file operations and handling tar archives.
torch: PyTorch, used here to check whether a GPU is available.
pandas: For data manipulation.
sklearn: For evaluation metrics and splitting the dataset.
simpletransformers: A wrapper for the Hugging Face Transformers library, making it easy to use transformer models.
2. Loading and Preprocessing Data
In this step, we'll load the dataset containing German tweets and their labels. We'll clean and preprocess the data to ensure it's in the right format for training our model. This involves combining datasets, cleaning labels, and mapping text labels to numerical values, which is crucial for the model to understand the data.
# Label names; a label's position in this list becomes its numeric class ID.
class_list = ['INSULT', 'ABUSE', 'PROFANITY', 'OTHER', 'EXPLICIT', 'IMPLICIT']
# Both files are read as tab-separated text without a header row.
df1 = pd.read_csv('Shared-Task-2019_Data_germeval2019.training_subtask1_2.txt', sep='\t', lineterminator='\n', encoding='utf8', names=["tweet", "task1", "task2"])
df2 = pd.read_csv('Shared-Task-2019_Data_germeval2019.training_subtask3.txt', sep='\t', lineterminator='\n', encoding='utf8', names=["tweet", "task1", "task2"])
df = pd.concat([df1, df2])
# Strip stray carriage returns from the labels.
df['task2'] = df['task2'].str.replace('\r', "")
# Map each text label to its index in class_list.
df['pred_class'] = df.apply(lambda x: class_list.index(x['task2']), axis=1)
df = df[['tweet', 'pred_class']]
print(df.shape)
df.head()
class_list: Defines the categories of abusive language.
df1 and df2: Load datasets with tweets and their corresponding labels.
pd.concat([df1, df2]): Combine both datasets.
df['task2'].str.replace('\r', ""): Clean the labels by removing stray carriage returns.
df.apply(lambda x: class_list.index(x['task2']), axis=1): Map the labels to numerical values.
df = df[['tweet', 'pred_class']]: Keep only the tweet text and the numeric label. Despite its name, pred_class holds the true label at this point, not a model prediction.
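The label-mapping step can also be sketched with a dictionary lookup, which makes it easier to fail with a clear message when a label is not in class_list. This is a minimal illustration in plain Python; the helper name encode_label and the sample labels are made up for the example.

```python
class_list = ['INSULT', 'ABUSE', 'PROFANITY', 'OTHER', 'EXPLICIT', 'IMPLICIT']

# Build the lookup once instead of calling class_list.index for every row.
label_to_id = {label: idx for idx, label in enumerate(class_list)}


def encode_label(raw_label):
    """Clean a raw label string and return its numeric class ID."""
    cleaned = raw_label.strip()  # removes '\r', spaces, and newlines
    if cleaned not in label_to_id:
        raise ValueError(f"Unknown label: {cleaned!r}")
    return label_to_id[cleaned]


raw_labels = ["INSULT\r", "OTHER", "IMPLICIT\r"]  # made-up sample values
encoded = [encode_label(label) for label in raw_labels]
print(encoded)  # [0, 3, 5]
```

With pandas, the same dictionary could be applied to the whole column with df['task2'].str.strip().map(label_to_id).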
3. Splitting Data into Training and Testing Sets
We'll divide the dataset into training and testing sets. The training set will be used to train the BERT model, while the testing set will be used to evaluate its performance. This split is necessary to ensure that we can assess how well our model generalizes to unseen data.
train_df, test_df = train_test_split(df, test_size=0.10)
print('train shape: ', train_df.shape)
print('test shape: ', test_df.shape)
train_test_split: Splits the dataset into training and testing sets with a 90/10 ratio.
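Because the classes in abusive-language datasets are usually imbalanced, it can help to keep the class proportions the same in both sets and to fix the random seed so the split is reproducible. The sketch below shows the idea in plain Python; the function name stratified_split and the sample labels are invented for illustration.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_size=0.10, seed=42):
    """Return (train_indices, test_indices) with a similar class mix in both."""
    rng = random.Random(seed)  # fixed seed makes the split reproducible
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, test_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        # Take the same fraction from every class (at least one example).
        n_test = max(1, round(len(indices) * test_size))
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)


labels = [3] * 80 + [0] * 20  # made-up, imbalanced label column
train_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(test_idx))
```

scikit-learn's train_test_split supports the same behavior through its stratify and random_state parameters, for example train_test_split(df, test_size=0.10, stratify=df['pred_class'], random_state=42).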
4. Setting Up the BERT Model
Here, we'll configure and initialize the BERT model using the simpletransformers library. We'll specify training parameters such as the number of epochs and whether to use a GPU. Setting up the model correctly is essential for effective training and optimal performance.
train_args = {
    "reprocess_input_data": True,
    "fp16": False,
    "num_train_epochs": 4
}
model = ClassificationModel(
    "bert", "distilbert-base-german-cased",
    num_labels=len(train_df['pred_class'].unique()),
    args=train_args,
    use_cuda=torch.cuda.is_available()
)
model.train_model(train_df)
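train_args: Training configuration: re-tokenize the input data instead of reusing cached features, keep fp16 (mixed-precision) training disabled, and train for four epochs.
ClassificationModel: Initializes a classification model from the pre-trained distilbert-base-german-cased checkpoint, with one output per class. Note that this checkpoint is a DistilBERT model (a smaller, distilled variant of BERT), while the model type passed here is "bert". simpletransformers also provides a "distilbert" model type that matches the checkpoint's architecture, and bert-base-german-cased is a full German BERT checkpoint.
use_cuda=torch.cuda.is_available(): Utilizes the GPU if one is available and falls back to the CPU otherwise.
model.train_model(train_df): Trains the model on the training dataset. simpletransformers expects either columns named text and labels or, as here, the text in the first column and the integer label in the second.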
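As a rough sketch of how the configuration could be extended, the dictionary below adds a few other simpletransformers options to the three used in this post. The values are illustrative assumptions rather than tuned settings from this project, so treat them as a starting point.

```python
# Illustrative training configuration; values are assumptions, not tuned results.
train_args = {
    "reprocess_input_data": True,   # re-tokenize instead of reusing cached features
    "fp16": False,                  # keep mixed-precision training off
    "num_train_epochs": 4,
    "output_dir": "outputs/",       # where checkpoints and the final model are written
    "overwrite_output_dir": True,   # allow re-running without clearing the folder first
    "max_seq_length": 128,          # tweets are short, so a small limit is usually enough
    "train_batch_size": 8,
    "learning_rate": 4e-5,
    "manual_seed": 42,              # makes training runs more reproducible
}

for name, value in train_args.items():
    print(f"{name}: {value}")
```

The dictionary would be passed to ClassificationModel through the args parameter exactly as in the code above.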
5. Evaluating the Model
After training, we'll evaluate the model using the test set. We'll use metrics like F1 score and accuracy to assess the model's performance. This step is crucial to determine how well the model can classify new, unseen tweets into the correct categories.
def f1_multiclass(labels, preds):
    return f1_score(labels, preds, average='micro')
result, model_outputs, wrong_predictions = model.eval_model(test_df, f1=f1_multiclass, acc=accuracy_score)
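f1_multiclass: Wraps f1_score with average='micro' so that it works with multi-class labels. Note that for single-label multi-class problems, micro-averaged F1 equals accuracy, so the two metrics reported here will be identical; average='macro' averages the per-class scores and is more informative when the classes are imbalanced.
model.eval_model(test_df, f1=f1_multiclass, acc=accuracy_score): Evaluates the model on the test set using F1 score and accuracy. It returns a dictionary of metrics, the raw model outputs, and the misclassified examples.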
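To see how the averaging mode changes the picture, here is a small plain-Python sketch that computes accuracy, micro F1, and macro F1 by hand. The labels and predictions are made up: a model that always predicts class 3 on an imbalanced test set. The function name f1_scores is invented for this example and is not part of scikit-learn.

```python
def f1_scores(labels, preds):
    """Return (micro_f1, macro_f1) for single-label multi-class predictions."""
    classes = sorted(set(labels) | set(preds))
    tp = {c: 0 for c in classes}
    fp = {c: 0 for c in classes}
    fn = {c: 0 for c in classes}
    for y, p in zip(labels, preds):
        if y == p:
            tp[y] += 1
        else:
            fp[p] += 1  # predicted class gets a false positive
            fn[y] += 1  # true class gets a false negative

    # Micro: pool the counts over all classes, then compute one F1.
    total_tp, total_fp, total_fn = sum(tp.values()), sum(fp.values()), sum(fn.values())
    micro = 2 * total_tp / (2 * total_tp + total_fp + total_fn)

    # Macro: compute F1 per class, then take the unweighted mean.
    per_class = []
    for c in classes:
        denom = 2 * tp[c] + fp[c] + fn[c]
        per_class.append(2 * tp[c] / denom if denom else 0.0)
    macro = sum(per_class) / len(per_class)
    return micro, macro


labels = [3, 3, 3, 3, 3, 3, 3, 3, 0, 1]  # made-up, imbalanced test labels
preds = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]   # a model that always predicts class 3

accuracy = sum(y == p for y, p in zip(labels, preds)) / len(labels)
micro, macro = f1_scores(labels, preds)
print("accuracy:", accuracy)
print("micro F1:", micro)
print("macro F1:", round(macro, 3))
```

Micro F1 matches accuracy, while macro F1 is much lower because the two minority classes are never predicted correctly.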
6. Saving the Model
We'll save the trained model to a file for future use. This involves packing the model into a tar.gz file. Saving the model is important because it allows us to reuse it without retraining, saving time and computational resources.
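def pack_model(model_path='', file_name=''):
    # Collect the files at the top level of the model directory (subfolders are skipped).
    files = [files for root, dirs, files in os.walk(model_path)][0]
    # Write them into a gzip-compressed tar archive.
    with tarfile.open(file_name + '.tar.gz', 'w:gz') as f:
        for file in files:
            f.add(f'{model_path}/{file}')
pack_model('output_path', 'model_name')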
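pack_model: Packs the files at the top level of the trained model's directory into a tar.gz file for easy storage and transfer. 'output_path' and 'model_name' are placeholders; by default, simpletransformers writes the trained model to the outputs/ directory.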
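If the model directory contains subfolders, a recursive variant may be more convenient. The sketch below is one possible approach using only the standard library; the function name pack_dir and the demo files are hypothetical and are created in a temporary directory so the example runs on its own.

```python
import os
import tarfile
import tempfile


def pack_dir(model_dir, archive_path):
    """Pack model_dir recursively into a .tar.gz, storing paths relative to its name."""
    with tarfile.open(archive_path, "w:gz") as tar:
        # arcname keeps absolute or temporary path prefixes out of the archive.
        tar.add(model_dir, arcname=os.path.basename(os.path.normpath(model_dir)))
    return archive_path


with tempfile.TemporaryDirectory() as tmp:
    # Build a small stand-in for a model directory.
    model_dir = os.path.join(tmp, "demo_model")
    os.makedirs(os.path.join(model_dir, "nested"))
    for rel in ("config.json", os.path.join("nested", "vocab.txt")):
        with open(os.path.join(model_dir, rel), "w") as fh:
            fh.write("placeholder")

    archive = pack_dir(model_dir, os.path.join(tmp, "demo_model.tar.gz"))
    with tarfile.open(archive, "r:gz") as tar:
        packed_names = sorted(tar.getnames())

print(packed_names)
```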
7. Loading the Model
In this step, we'll demonstrate how to load the saved model from the tar.gz file. This ensures that the model can be easily restored and used for prediction tasks without needing to retrain it.
def unpack_model(model_name=''):
    # Extract the archive into the current working directory.
    with tarfile.open(f"{model_name}.tar.gz", "r:gz") as tar:
        tar.extractall()
unpack_model('model_name')
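unpack_model: Unpacks the tar.gz file into the current directory to retrieve the model. Only extract archives from sources you trust, since tar files can contain paths that point outside the target directory.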
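When an archive comes from somewhere else, it may be worth checking its contents before extracting. The following is a minimal sketch, assuming a simple path check is sufficient for your use case; it does not cover every edge case (symbolic links, for example), and the function name safe_extract and the demo files are made up for illustration.

```python
import os
import tarfile
import tempfile


def safe_extract(archive_path, target_dir):
    """Extract archive_path into target_dir, refusing entries that would land outside it."""
    target_root = os.path.realpath(target_dir)
    with tarfile.open(archive_path, "r:gz") as tar:
        for member in tar.getmembers():
            dest = os.path.realpath(os.path.join(target_root, member.name))
            # Reject names such as "../../etc/passwd" that escape the target folder.
            if os.path.commonpath([target_root, dest]) != target_root:
                raise ValueError(f"Unsafe path in archive: {member.name}")
        tar.extractall(target_root)


with tempfile.TemporaryDirectory() as tmp:
    # Create a tiny archive to extract.
    src = os.path.join(tmp, "weights.bin")
    with open(src, "w") as fh:
        fh.write("placeholder")
    archive = os.path.join(tmp, "demo.tar.gz")
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(src, arcname="demo_model/weights.bin")

    out_dir = os.path.join(tmp, "restored")
    os.makedirs(out_dir)
    safe_extract(archive, out_dir)
    restored = os.path.exists(os.path.join(out_dir, "demo_model", "weights.bin"))

print(restored)
```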
8. Using the Model for Prediction
Finally, we'll use the loaded model to make predictions on new tweets. This step shows how to apply the trained model to real-world data, enabling us to classify new instances and derive insights from them.
train_args = {
    "reprocess_input_data": True,
    "fp16": False,
    "num_train_epochs": 4
}
# The label list must match the one used for training (six classes in step 2).
class_list = ['INSULT', 'ABUSE', 'PROFANITY', 'OTHER', 'EXPLICIT', 'IMPLICIT']
model = ClassificationModel(
    "bert", 'path_to_model/',
    num_labels=len(class_list),
    args=train_args,
    # Without this, simpletransformers assumes a GPU and fails on CPU-only machines.
    use_cuda=torch.cuda.is_available()
)
test_tweet1 = "Meine Mutter hat mir erzählt, dass mein Vater einen Wahlkreiskandidaten nicht gewählt hat, weil der gegen die Homo-Ehe ist"
predictions, raw_outputs = model.predict([test_tweet1])
print(class_list[predictions[0]])
# OTHER
test_tweet2 = "Frau #Böttinger meine Meinung dazu ist sie sollten uns mit ihrem Pferdegebiss nicht weiter belästigen #WDR"
predictions, raw_outputs = model.predict([test_tweet2])
print(class_list[predictions[0]])
# INSULT
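model = ClassificationModel(...): Reloads the fine-tuned model from the unpacked directory for prediction ('path_to_model/' is a placeholder). num_labels must match the number of classes the model was trained with.
model.predict: Takes a list of texts and returns the predicted class indices together with the raw model outputs; class_list maps each index back to its label name.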
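The raw outputs can be turned into rough confidence values with a softmax. The sketch below assumes that raw_outputs holds one row of unnormalized scores (logits) per tweet; the numbers are invented for illustration and do not come from the trained model.

```python
import math


def softmax(logits):
    """Convert a list of raw scores into probabilities that sum to 1."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


class_list = ['INSULT', 'ABUSE', 'PROFANITY', 'OTHER', 'EXPLICIT', 'IMPLICIT']
example_logits = [2.1, 0.3, -0.5, 0.9, -1.2, -0.8]  # made-up scores for one tweet

probs = softmax(example_logits)
best = max(range(len(probs)), key=probs.__getitem__)

for label, p in zip(class_list, probs):
    print(f"{label}: {p:.3f}")
print("predicted:", class_list[best])
```

A low top probability can be a useful signal to route a tweet to manual review instead of trusting the predicted label.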
Conclusion
This blog post walked you through a complete text classification project using BERT. By understanding each code block, you can adapt and expand this workflow to other languages or tasks. The power of BERT, combined with the simplicity of the simpletransformers library, makes it accessible for machine learning practitioners to develop state-of-the-art NLP models. Happy coding!