
FashionWear: zero-shot classification and transformer models.


Here I will show how to do transfer learning with a pre-trained language model (PLM) to classify fashion items. Specifically, we use a PLM for zero-shot classification to infer fashion categories from product names.

Zero-, one- and few-shot classification appear to be emergent capabilities of large language models. In zero-shot classification, we give the model a piece of text together with a set of candidate labels, written in natural language, that the model was never explicitly trained to predict; the model then picks the label that best fits the text. For more information on zero-shot classification, visit: https://huggingface.co/tasks/zero-shot-classification
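The Hugging Face `zero-shot-classification` pipeline implements this with a natural language inference (NLI) model: each candidate label is inserted into a hypothesis (by default `"This example is {}."`), the model scores how strongly the text entails each hypothesis, and in single-label mode the entailment logits are normalised with a softmax across labels. The sketch below reproduces only that bookkeeping so it runs without a model; `build_hypotheses` and `rank_labels` are hypothetical helpers, and the entailment logits are made-up numbers standing in for real model outputs.

```python
import numpy as np

# Default hypothesis template of the Hugging Face zero-shot pipeline
HYPOTHESIS_TEMPLATE = "This example is {}."


def build_hypotheses(labels, template=HYPOTHESIS_TEMPLATE):
    """Turn each candidate label into an NLI hypothesis."""
    return [template.format(label) for label in labels]


def rank_labels(labels, entailment_logits):
    """Softmax the entailment logits across labels and sort by score, highest first."""
    logits = np.asarray(entailment_logits, dtype=float)
    exp = np.exp(logits - logits.max())  # subtract the max for numerical stability
    scores = exp / exp.sum()
    order = np.argsort(scores)[::-1]  # indices from highest to lowest score
    return {"labels": [labels[i] for i in order],
            "scores": [float(scores[i]) for i in order]}


premise = "Rocia Women Black & Brown Sandals"
labels = ["Apparel", "Accessories", "Footwear", "Personal Care"]
for hypothesis in build_hypotheses(labels):
    print(f"premise: {premise!r} | hypothesis: {hypothesis!r}")

# Made-up entailment logits, one per label, standing in for real model outputs
result = rank_labels(labels, [0.2, -1.0, 3.5, -2.0])
print(result["labels"][0], round(result["scores"][0], 3))
```

Normalising across labels forces the scores to sum to one, which suits this task because each product has exactly one master category.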

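The model selected for this task is BART, a large pre-trained model developed by Facebook. In short, BART is a transformer encoder-decoder (seq2seq) model with a bidirectional (BERT-like) encoder and an autoregressive (GPT-like) decoder. The checkpoint used here, `facebook/bart-large-mnli`, is BART-large fine-tuned on the MultiNLI natural language inference dataset, which is what allows it to score arbitrary labels without task-specific training. Because our dataset includes the original categories, we can evaluate how well the model performs.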

by jortega

# load libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from transformers import pipeline
from datasets import Dataset
# load data; on_bad_lines='skip' drops malformed rows that pandas cannot parse
df = pd.read_csv('./styles.csv', on_bad_lines = 'skip')
df.info()

Data pre-processing

Because the dataset is large and running BART on every record would be slow, we take a random sample of records and keep only the columns we are interested in. We also drop any records with missing data.

# take a random sample of 2,000 records (fixed seed for reproducibility)
dfs = df.sample(n = 2000, random_state = 253)
# check nulls
count_null = dfs.isna().sum()
count_null[count_null>0]
# drop records with missing values
dfs = dfs.dropna()

# grab the columns we need
cc = ["id", "masterCategory", "subCategory", "productDisplayName"]
dfs = dfs[cc]
dfs[["productDisplayName", "masterCategory"]].head()
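Calling `dropna()` before subsetting removes rows that are missing a value in any column, including columns the analysis never uses, and doing it after sampling leaves fewer than 2,000 rows. A possible alternative is to restrict the null check to the columns in `cc` with `dropna(subset=cc)`. The sketch below shows the difference on hypothetical toy rows; the `baseColour` column and all values are invented for illustration.

```python
import numpy as np
import pandas as pd

# Hypothetical toy rows: baseColour has a gap but is not one of the columns we keep
toy = pd.DataFrame({
    "id": [1, 2, 3, 4],
    "masterCategory": ["Footwear", "Apparel", "Accessories", None],
    "subCategory": ["Sandal", "Topwear", "Bags", "Topwear"],
    "baseColour": ["Black", np.nan, "Orange", "Blue"],
    "productDisplayName": [
        "Rocia Women Black & Brown Sandals",
        "Men Blue Shirt",
        "Lino Perros Women Orange Backpacks",
        "Women Blue Top",
    ],
})
cc = ["id", "masterCategory", "subCategory", "productDisplayName"]

# Null check on every column: drops id 2 (missing baseColour) and id 4 (missing category)
all_cols = toy.dropna()[cc]
# Null check only on the columns we use: drops id 4 only
needed_cols = toy.dropna(subset=cc)[cc]
print(len(all_cols), len(needed_cols))
```

On the real data, running the null check before `df.sample` would also keep the sample at exactly 2,000 rows.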

The objective is to infer the master category based ONLY on the product name. Below are a couple of examples to illustrate this.

Ex 1:

  • Original category (masterCategory): _Footwear_
  • Product display name: _Rocia Women Black & Brown Sandals_
  • Inferred category: ?

Ex 2:

  • Original category (masterCategory): _Accessories_
  • Product display name: _Lino Perros Women Orange Backpacks_
  • Inferred category: ?

In the end we want to find out how well the PLM performs. We do this by comparing each product's original category with the inferred one.

# check original categories
dfs.groupby("masterCategory")["id"].count().sort_values(ascending = False)
# remove categories with fewer than ten examples in the sample
mask = dfs["masterCategory"].isin(["Sporting Goods", "Home", "Free Items"])
dfs = dfs[~mask]
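The mask above hard-codes the three categories that were rare in this particular sample; with a different `random_state`, a different set of categories could fall below the threshold. A data-driven version could count the categories and keep only those that meet a minimum. In the sketch below, `drop_rare_categories` is a hypothetical helper and the toy counts are invented.

```python
import pandas as pd


def drop_rare_categories(df, column, min_count=10):
    """Keep only rows whose value in `column` occurs at least `min_count` times."""
    counts = df[column].value_counts()
    keep = counts[counts >= min_count].index
    return df[df[column].isin(keep)]


# Hypothetical toy data: 12 Apparel rows, 10 Footwear rows, 2 Home rows
toy = pd.DataFrame({"masterCategory": ["Apparel"] * 12 + ["Footwear"] * 10 + ["Home"] * 2})
filtered = drop_rare_categories(toy, "masterCategory", min_count=10)
print(filtered["masterCategory"].value_counts())
```

On the notebook's data this would be called as `drop_rare_categories(dfs, "masterCategory")`.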

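# convert the dataframe to a Hugging Face Dataset for optimized batch processing;
# preserve_index=False stops the sampled pandas index from becoming an extra "__index_level_0__" column
dt = Dataset.from_pandas(dfs, preserve_index = False)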
print(dt.column_names, "\n", f"Number of cases: {dt.num_rows}")

Run zero-shot classification with BART, a pre-trained language model (PLM)

Visit https://huggingface.co/ for more details on transformers and PLM models.

# from transformers, initialize a zero-shot classification pipeline
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# get candidate labels for zero-shot classification (as a plain Python list)
labels = dfs["masterCategory"].unique().tolist()
# define function to run zero-shot classification in batch mode
def get_class(batch):
    fashion_items = batch["productDisplayName"]
    output = classifier(fashion_items, candidate_labels = labels)
    return {"LLM_CATEGORY" : output}
# classify fashion wear
dt = dt.map(get_class, batched = True, batch_size = 10)
# get the inferred class: the pipeline sorts labels by score, so index 0 is the top prediction
def max_score(output):
    top_label = output["labels"][0]
    top_score = np.round(output["scores"][0], 3)
    return {"label": top_label, "score": top_score}
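With `batched=True`, `Dataset.map` hands `get_class` a dict that maps each column name to a list of up to `batch_size` values; when new columns are added alongside the existing ones, each returned list must have one entry per input row. The sketch below imitates that contract in plain Python so it runs without downloading BART. `stub_classifier` and `map_batched` are hypothetical stand-ins, not the `transformers` or `datasets` implementations, and the stub's uniform scores are invented.

```python
def stub_classifier(texts, candidate_labels):
    """Hypothetical stand-in for the zero-shot pipeline: one dict per input text,
    with uniform scores and labels kept in the order given."""
    n = len(candidate_labels)
    return [{"sequence": text, "labels": list(candidate_labels), "scores": [1.0 / n] * n}
            for text in texts]


def get_class(batch, classifier, labels):
    """Same logic as the notebook's get_class, with classifier and labels passed in."""
    fashion_items = batch["productDisplayName"]  # list of up to batch_size strings
    output = classifier(fashion_items, candidate_labels=labels)
    return {"LLM_CATEGORY": output}  # one output dict per input row


def map_batched(columns, fn, batch_size):
    """Simplified imitation of Dataset.map(batched=True) on a dict of column lists."""
    n_rows = len(next(iter(columns.values())))
    new_columns = {}
    for start in range(0, n_rows, batch_size):
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        batch_len = min(batch_size, n_rows - start)
        for name, values in fn(batch).items():
            if len(values) != batch_len:
                raise ValueError(f"{name!r}: expected {batch_len} values, got {len(values)}")
            new_columns.setdefault(name, []).extend(values)
    return {**columns, **new_columns}


data = {"productDisplayName": [
    "Rocia Women Black & Brown Sandals",
    "Lino Perros Women Orange Backpacks",
    "Men Blue Shirt",  # invented example
]}
labels = ["Apparel", "Accessories", "Footwear"]
out = map_batched(data, lambda b: get_class(b, stub_classifier, labels), batch_size=2)
print(len(out["LLM_CATEGORY"]), out["LLM_CATEGORY"][2]["sequence"])
```

Batching mainly reduces per-call overhead; the final batch may be shorter than `batch_size`, which is why the length check uses the actual batch length.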

# check a sample of results 
subset = dt.shuffle(seed = 42).select(range(4))

for idx in range(4):
    print(f'Product: {subset[idx]["productDisplayName"]}') 
    print(f'Actual Category: {subset[idx]["masterCategory"]}')
    print(f'Inferred Category: {max_score(subset[idx]["LLM_CATEGORY"])}')
    print("\n")
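To compare original and inferred categories over all rows rather than a sample of four, one option is to take the top label of every output and compute accuracy plus a confusion matrix with `pd.crosstab`. The `evaluate` helper below is hypothetical and runs on invented outputs in the same shape as the `LLM_CATEGORY` column; on the real data one could pass `dt["masterCategory"]` and `dt["LLM_CATEGORY"]`, assuming the column comes back as a sequence of dicts.

```python
import pandas as pd


def evaluate(actual, outputs):
    """Compare actual categories with the top-scoring label of each zero-shot output."""
    predicted = [out["labels"][0] for out in outputs]  # labels are sorted by score, highest first
    results = pd.DataFrame({"actual": list(actual), "predicted": predicted})
    accuracy = (results["actual"] == results["predicted"]).mean()
    confusion = pd.crosstab(results["actual"], results["predicted"])
    return accuracy, confusion


# Invented outputs in the same shape as the pipeline's results
actual = ["Footwear", "Accessories", "Apparel", "Apparel"]
outputs = [
    {"labels": ["Footwear", "Apparel"], "scores": [0.9, 0.1]},
    {"labels": ["Apparel", "Accessories"], "scores": [0.6, 0.4]},
    {"labels": ["Apparel", "Footwear"], "scores": [0.8, 0.2]},
    {"labels": ["Apparel", "Accessories"], "scores": [0.7, 0.3]},
]
accuracy, confusion = evaluate(actual, outputs)
print(f"accuracy: {accuracy:.2f}")
print(confusion)
```

The confusion matrix shows which categories the model mixes up (here, an accessory inferred as apparel), which a single accuracy number hides.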


