Zero-shot learning, or ZSL, is a machine learning technique, commonly used in Natural Language Processing, that lets a pre-trained model make predictions for classes it was never explicitly trained on, so you don’t need to train a model of your own. Essentially, zero-shot learning gives you immensely powerful models that have been trained on enormous datasets and work out of the box. It’s an incredible technology.
Zero-shot learning models come in various forms, but in this example we’ll be looking at the use of zero-shot learning for text classification. The Bart model we’re using in this project came out in 2019 and was developed at Facebook.
It shares some similarities with other state-of-the-art models, such as BERT and GPT, and can also be fine-tuned on different datasets to allow it to perform different tasks. Bart can both comprehend text and generate new text of its own, and it excels at text classification.
In this project, we’ll use the Bart MNLI model for text classification, by configuring it to classify customer service emails sent to an ecommerce website. This is a really important, but sadly often overlooked task in online retailing. Customer service has a huge impact on the customer experience, and it’s closely linked to customer retention and churn.
What typically happens is that customer service staff get so bogged down in their workloads that they fail to take a step back, examine why customers are making contact, and collaborate with other departments to reduce those workloads, make the business more efficient and, crucially, make things better for customers.
First, open up a Jupyter notebook and import the pandas package and the pipeline function from transformers. This will let you run the pre-built models made available by the superb Hugging Face Transformers project, all with minimal code.
import pandas as pd
from transformers import pipeline
I’m using an actual dataset from a real ecommerce retailer here. The dataset comprises over 25K emails to the company’s customer service help desk. Each record contains a unique ID, a code defining how the customer classified their ticket, and the body of the email itself.
Although these tickets have been pre-classified by customers, customers often don’t classify them correctly. The tickets contain far more information than the small number of ticket categories available, so we can use Bart to reclassify them accordingly based on the content of the emails and extract more knowledge.
df = pd.read_csv('final.csv')
df = df[['uid', 'department_code', 'message']]
df = df.head(1000)
df.head(3)
| | uid | department_code | message |
|---|---|---|---|
| 0 | P8462 | oq1 | It looks like you've misspelled the word "Bril... |
| 1 | P8463 | other1 | Hi,\r\n\r\nI checked your website www... |
| 2 | P8464 | oq1 | Are you looking for effective online promotion... |
Next, we’ll use the pipeline() function to load the zero-shot-classification pipeline and set it to use the facebook/bart-large-mnli model. This has been pre-trained on a truly massive amount of text and then fine-tuned on the MultiNLI (MNLI) dataset, so it already understands a great deal about the structure of text.
The model weighs in at around 1.63GB, so is quite a whopper. Although it will run happily on a CPU, models like this are best run on GPUs, and will then generate results far more quickly. To get the model to run on an NVIDIA GPU using CUDA, we can pass in the device=0 argument when calling pipeline(). However, you will need a powerful GPU with enough memory to allow a model this large to run.
classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli')
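The device choice described above could be wrapped up roughly as follows. This is a sketch rather than the code used for this project: pick_device and build_classifier are hypothetical helper names, and the GPU check relies on torch.cuda.is_available() from PyTorch, which is an assumption about your setup. The imports sit inside the function so that nothing heavy is loaded until you actually build the classifier.

```python
def pick_device(cuda_available: bool) -> int:
    """Return the pipeline device index: 0 for the first GPU, -1 for the CPU."""
    return 0 if cuda_available else -1


def build_classifier(model_name: str = "facebook/bart-large-mnli"):
    """Load the zero-shot pipeline on the GPU if one is visible, else on the CPU."""
    # Imported lazily: loading these packages (and the model) is slow.
    import torch
    from transformers import pipeline

    device = pick_device(torch.cuda.is_available())
    return pipeline("zero-shot-classification", model=model_name, device=device)
```

Calling build_classifier() then gives you the same kind of classifier object as the line above, just with the device chosen automatically.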
We’ll be using Bart for multi-class text classification. To perform this task, we need to provide it with a list of classes, and it will figure out which emails should be assigned to which class. It does this by presenting each candidate label as a “hypothesis” to the model, with the sequence text representing the “premise”.
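To make the premise and hypothesis idea concrete, here is a rough sketch in plain Python of how candidate labels become hypotheses and how entailment scores turn into a ranking. The template "This example is {}." is the default hypothesis template used by the Hugging Face pipeline; the function names and the logits are made up for illustration, and in a real run the logits come from the model.

```python
import math


def build_pairs(premise, labels, template="This example is {}."):
    """Pair the email text (premise) with one hypothesis per candidate label."""
    return [(premise, template.format(label)) for label in labels]


def rank_labels(labels, entailment_logits):
    """Softmax the entailment logits across labels and sort, best first.

    This mirrors the single-label case, where the scores sum to 1.
    """
    top = max(entailment_logits)
    exps = [math.exp(x - top) for x in entailment_logits]  # subtract max for stability
    total = sum(exps)
    scored = [(label, e / total) for label, e in zip(labels, exps)]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


labels = ["Spam", "Orders: Order tracking"]
pairs = build_pairs("Where is my parcel?", labels)
# pairs[0] is ("Where is my parcel?", "This example is Spam.")

# Hypothetical logits; a real run would take these from the model's output.
ranking = rank_labels(labels, [-1.0, 2.0])
```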
To come up with the class labels, I did some Exploratory Data Analysis on the raw data. This revealed a number of commonly arising issues that aren’t covered in the dropdown menu customers select from when sending in their support email.
# Every entry needs a trailing comma: Python silently joins adjacent string literals.
classes = [
    'Orders: Order tracking',
    'Orders: Order delayed / Order not arrived',
    'Orders: Delivery price',
    'Orders: Damaged on arrival',
    'Orders: Order cancelation',
    'Orders: Can I change my order?',
    'Returns: Can I return an item?',
    'Returns: When will my return be completed?',
    'Returns: When will my refund be processed?',
    'Products: Do you stock this product?',
    'Products: How do I use this product?',
    'Products: Which product should I buy?',
    'Products: Can you install a product for me?',
    'Site: Can I get a password reset?',
    'Other',
    'Spam',
]
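A list this long is easy to get subtly wrong: Python silently joins adjacent string literals, so a missing comma merges two labels into one instead of raising an error. A small check along these lines can catch that before a long run. The helper name find_merged_labels and the group-prefix heuristic are my own assumptions, based on the "Group: description" naming scheme used for the labels here.

```python
import re

# Labels follow a "Group: description" pattern, so a second "Group: "
# appearing in the same string suggests two labels were merged.
GROUP_PREFIX = re.compile(r"\b(Orders|Returns|Products|Site): ")


def find_merged_labels(labels):
    """Return labels that contain more than one group prefix."""
    return [label for label in labels if len(GROUP_PREFIX.findall(label)) > 1]


good = ["Orders: Can I change my order?", "Returns: Can I return an item?"]
# Missing comma: the two literals below become a single string.
bad = ["Orders: Can I change my order?" "Returns: Can I return an item?"]

merged_in_good = find_merged_labels(good)
merged_in_bad = find_merged_labels(bad)
```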
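To see how the model works, we’ll do a quick dry run. We’ll fetch the message text from one of the emails in the dataframe and pass it to the classifier() function, along with the classes list, and an argument to set multi_class to False, which makes the scores across all labels sum to 1. (More recent releases of transformers have renamed this argument to multi_label.)
By examining the result returned by the model, we can see that classifier() has returned the original text, the list of labels, and their corresponding scores, with the labels ranked in order of probability of being a match for the email message. It correctly figured out that this message was in the Spam class, which is quite incredible. Two pairs of labels appear merged in this output because the classes list used for this run was missing a comma after 'Orders: Can I change my order?' and after 'Products: Can you install a product for me?', so Python joined each of them to the following string and the model ranked 14 labels rather than 16.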
text = df['message'][0]
result = classifier(text, classes, multi_class=False)
result
{'sequence': 'It looks like you\'ve misspelled the word "Brillant" on your website. I thought you would like to know :). Silly mistakes can ruin your site\'s credibility. I\'ve used a tool called SpellScan.com in the past to keep mistakes off of my website.\r\n\r\n-Kerri',
'labels': ['Spam',
'Other',
'Products: How do I use this product?',
'Returns: When will my return be completed?',
'Orders: Order cancelation',
'Orders: Can I change my order?Returns: Can I return an item?',
'Returns: When will my refund be processed?',
'Orders: Delivery price',
'Products: Do you stock this product?',
'Products: Which product should I buy?',
'Orders: Order tracking',
'Orders: Order delayed / Order not arrived',
'Orders: Damaged on arrival',
'Products: Can you install a product for me?Site: Can I get a password reset?'],
'scores': [0.4181201457977295,
0.0918767899274826,
0.05920625850558281,
0.05549466982483864,
0.05338582396507263,
0.05128355324268341,
0.04088611900806427,
0.040571365505456924,
0.03962726145982742,
0.039201993495225906,
0.03557770699262619,
0.026288021355867386,
0.02581968531012535,
0.022660594433546066]}
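The returned dictionary keeps labels and scores in matching order, best first, so the top prediction is simply the first element of each list. Here is a minimal sketch that uses the first three entries of the output above rather than a live model call; top_prediction and its min_score cut-off are hypothetical additions, not part of the pipeline.

```python
def top_prediction(result, min_score=0.0):
    """Return (label, score) for the best match, or (None, score) if it is too weak."""
    label, score = result["labels"][0], result["scores"][0]
    return (label, score) if score >= min_score else (None, score)


# Truncated copy of the dry-run output shown above.
result = {
    "labels": ["Spam", "Other", "Products: How do I use this product?"],
    "scores": [0.4181201457977295, 0.0918767899274826, 0.05920625850558281],
}

label, score = top_prediction(result)
uncertain, _ = top_prediction(result, min_score=0.5)
```

A cut-off like this is one way to route weak matches to manual review, though the right threshold would need checking against your own data.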
To repeat this process for every email in our dataset we can use a lambda function. We’ll use the dataframe’s apply() function to run classifier() via the lambda, passing in the message value for each row, the classes list, and the multi_class value we want to use. This time multi_class is set to True, so each label is scored independently and the scores for an email no longer have to sum to 1.
It will then classify each email and write the dictionary returned by the model to a new dataframe column called labels. This will take quite a while to run. You may want to test it on a smaller subset of your data first. After struggling to get Bart running on my 6GB NVIDIA RTX 2070 GPU, I eventually left it running on the CPU overnight instead.
df['labels'] = df.apply(lambda x: classifier(x.message, classes, multi_class=True), axis=1)
df.head()
df.to_csv('results.csv')
| | uid | department_code | message | labels | predicted_category | score |
|---|---|---|---|---|---|---|
| 0 | P8462 | oq1 | It looks like you've misspelled the word "Bril... | {'sequence': 'It looks like you've misspelled ... | Spam | 0.218138 |
| 1 | P8463 | other1 | Hi,\r\n\r\nI checked your website www... | {'sequence': 'Hi, I checked your website ww... | Products: How do I use this product? | 0.299080 |
| 2 | P8464 | oq1 | Are you looking for effective online promotion... | {'sequence': 'Are you looking for effective on... | Spam | 0.325764 |
| 3 | P8465 | oq1 | Hi, order XXXXXXX placed on the 13th and no... | {'sequence': 'Hi, order XXXXXXX placed on t... | Orders: Order delayed / Order not arrived | 0.995220 |
| 4 | P8466 | pq3 | Ref 526LD299814\r\nCan you please instruct the... | {'sequence': 'Ref XXXXXXX Can you please ... | Orders: Order delayed / Order not arrived | 0.920468 |
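Since the full run is slow, it can help to classify a small slice first and to send the texts in chunks. The sketch below assumes the classifier accepts a list of texts and returns one result dictionary per text, which is how the Hugging Face pipeline behaves when given a list. The names classify_in_batches and stub_classifier are hypothetical; the stub stands in for the real model so that the example runs instantly.

```python
def classify_in_batches(texts, classifier, labels, batch_size=8, **kwargs):
    """Classify texts chunk by chunk and return one result dict per text."""
    results = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        out = classifier(chunk, labels, **kwargs)
        # A pipeline may return a bare dict for a single input; normalise to a list.
        results.extend(out if isinstance(out, list) else [out])
    return results


def stub_classifier(texts, labels, **kwargs):
    """Stand-in for the real pipeline: always 'predicts' the first label."""
    scores = [1.0] + [0.0] * (len(labels) - 1)
    return [{"sequence": t, "labels": list(labels), "scores": scores} for t in texts]


sample = ["Where is my order?", "Do you stock this?", "Please cancel my order"]
results = classify_in_batches(sample, stub_classifier, ["Orders", "Products"], batch_size=2)
```

With the real model, the call would look something like classify_in_batches(df['message'].tolist(), classifier, classes, multi_class=True), with the returned list assigned back to df['labels'].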
I’ll cover how I analysed the model outputs, and some other customer experience related data, in a separate article. However, here’s a quick breakdown of what the text classification revealed.
As you’ll see below in this sample, a few operational issues (damages, late despatches, and a lack of order tracking) cause the bulk of customer service issues. Fix these, and the CS workload will quickly drop, and customers will be happier and more likely to be retained.
df['predicted_category'] = df.apply(lambda row: row['labels']['labels'][0], axis=1)
df['score'] = df.apply(lambda row: row['labels']['scores'][0], axis=1)
df['predicted_category'].value_counts()
Orders: Order delayed / Order not arrived 291
Orders: Order tracking 154
Products: Do you stock this product? 138
Orders: Damaged on arrival 128
Products: How do I use this product? 79
Other 65
Products: Which product should I buy? 33
Orders: Delivery price 29
Orders: Order cancelation 26
Orders: Can I change my order?Returns: Can I return an item? 22
Returns: When will my return be completed? 17
Returns: When will my refund be processed? 16
Spam 2
Name: predicted_category, dtype: int64
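To put those counts in proportion, the arithmetic can be done directly on the figures above. This sketch hard-codes the counts from the output rather than reading the dataframe, and the grouping of three labels as "operational" is my reading of the issues named earlier; in pandas, value_counts(normalize=True) would give the same per-label shares.

```python
counts = {
    "Orders: Order delayed / Order not arrived": 291,
    "Orders: Order tracking": 154,
    "Products: Do you stock this product?": 138,
    "Orders: Damaged on arrival": 128,
    "Products: How do I use this product?": 79,
    "Other": 65,
    "Products: Which product should I buy?": 33,
    "Orders: Delivery price": 29,
    "Orders: Order cancelation": 26,
    "Orders: Can I change my order?Returns: Can I return an item?": 22,
    "Returns: When will my return be completed?": 17,
    "Returns: When will my refund be processed?": 16,
    "Spam": 2,
}

total = sum(counts.values())  # 1000 emails in the sample

# Share of tickets taken up by the three operational issues named in the text.
operational = [
    "Orders: Order delayed / Order not arrived",
    "Orders: Order tracking",
    "Orders: Damaged on arrival",
]
operational_share = sum(counts[name] for name in operational) / total  # 573 / 1000

for name, n in counts.items():
    print(f"{name:<62} {n:>4} {n / total:6.1%}")
```

On this sample, those three labels account for 573 of the 1,000 emails.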
Lewis, M., Liu, Y., Goyal, N., Ghazvininejad, M., Mohamed, A., Levy, O., Stoyanov, V. and Zettlemoyer, L., 2019. Bart: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. arXiv preprint arXiv:1910.13461.
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L. and Polosukhin, I., 2017. Attention is all you need. arXiv preprint arXiv:1706.03762.
Matt Clarke, Sunday, March 14, 2021