import os
import re

import nltk
import numpy as np
import polars as pl
from catboost import CatBoostClassifier
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split

import mlflow

# Fetch the NLTK resources this script relies on (the stop-word list and the
# WordNet data used by the lemmatizer, plus the tokenizer/tagger models).
# NOTE(review): the original line was garbled; the five resource names are
# verbatim, the nltk.download call is the presumed original - confirm.
for nltk_resource in (
    "averaged_perceptron_tagger",
    "punkt",
    "wordnet",
    "omw-1.4",
    "stopwords",
):
    nltk.download(nltk_resource)
# Default url for MLFlow is "http://localhost:5000"
# NOTE(review): this constant is never passed to mlflow in the visible code -
# confirm whether mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) is intended.
MLFLOW_TRACKING_URI = "http://localhost:5000"
DATASET_PATH = "../data/mlflow_data/toxic_comments.csv"
N_JOBS = -1
# Seed shared by the train/test split and the logistic regression; the value
# matches the literal seed given to the CatBoost model further down.
RANDOM_STATE = 42
# NOTE(review): credentials are hard-coded; presumably placeholders for a
# local S3-compatible artifact store - do not reuse for real deployments.
os.environ["AWS_ACCESS_KEY_ID"] = "mlflow"
os.environ["AWS_SECRET_ACCESS_KEY"] = "password"
# NOTE(review): the endpoint URL is set to an empty string - confirm the
# intended value.
os.environ["MLFLOW_S3_ENDPOINT_URL"] = ""
# List the experiments visible to mlflow; the result is kept for inspection.
experiments = mlflow.search_experiments()
[<Experiment: artifact_location='file:///c:/dev/mlops_course/mlops_course/mlruns/0', creation_time=1715879720338, experiment_id='0', last_update_time=1715879720338, lifecycle_stage='active', name='Default', tags={}>]
# Load the dataset, reading at most the first 50,000 rows.
df = pl.read_csv(DATASET_PATH, n_rows=50000)
shape: (3, 3)
text toxic
i64 str i64
0 "Explanation Why the edits made under my username Hardcore Metallica Fan were reverted? They weren't … 0
1 "D'aww! He matches this background colour I'm seemingly stuck with. Thanks. (talk) 21:51, January 11… 0
2 "Hey man, I'm really not trying to edit war. It's just that this guy is constantly removing relevant … 0
# Text-cleaning resources: the stop-word set and the compiled regular expressions.
# English stop words from NLTK, stored as a set so the per-word membership
# test in text_preprocessing is a hash lookup.
stop_words = set(stopwords.words("english"))
# Despite its name, this pattern matches more than URLs: bracketed spans,
# runs of characters that are neither ASCII letters nor whitespace, and any
# word containing a digit are matched as well.
url_pattern = re.compile(r"https?://\S+|www\.\S+|\[.*?\]|[^a-zA-Z\s]+|\w*\d\w*")
# Raw strings keep "\-" as a regex escape instead of an invalid string escape
# (which triggers a warning on modern Python); the pattern text is unchanged.
spec_chars_pattern = re.compile(r"[0-9 \-_]+")
# Any run of characters other than ASCII letters and the space character.
non_alpha_pattern = re.compile(r"[^a-z A-Z]+")

def text_preprocessing(input_text: str) -> str:
    """Normalize a raw comment into a space-separated string of kept words.

    The text is lower-cased, everything matched by ``url_pattern`` is deleted,
    matches of ``spec_chars_pattern`` and ``non_alpha_pattern`` are replaced by
    a single space, and words found in ``stop_words`` are dropped.

    Args:
        input_text: The raw text to clean.

    Returns:
        The remaining words joined by single spaces (possibly an empty string).
    """
    cleaned = url_pattern.sub("", input_text.lower())
    # Both remaining patterns substitute a space, so apply them in sequence.
    for pattern in (spec_chars_pattern, non_alpha_pattern):
        cleaned = pattern.sub(" ", cleaned)

    kept_words = [token for token in cleaned.split() if token not in stop_words]
    return " ".join(kept_words).strip()

# Build the "corpus" column: clean each text, then split it into tokens.
# return_dtype is given explicitly so polars does not have to infer the
# output type from the first value returned by the Python function.
df = df.with_columns(
    pl.col("text")
    .map_elements(text_preprocessing, return_dtype=pl.Utf8)
    .str.split(" ")
    .alias("corpus")
)

shape: (5, 4)
text toxic corpus
i64 str i64 list[str]
0 "Explanation Why the edits made under my username Hardcore Metallica Fan were reverted? They weren't … 0 ["explanation", "edits", … "retired"]
1 "D'aww! He matches this background colour I'm seemingly stuck with. Thanks. (talk) 21:51, January 11… 0 ["daww", "matches", … "utc"]
2 "Hey man, I'm really not trying to edit war. It's just that this guy is constantly removing relevant … 0 ["hey", "man", … "info"]
3 "" More I can't make any real suggestions on improvement - I wondered if the section statistics shoul… 0 ["cant", "make", … "wikipediagoodarticlenominationstransport"]
4 "You, sir, are my hero. Any chance you remember what page that's on?" 0 ["sir", "hero", … "thats"]
def lemmatize(input_frame: pl.DataFrame) -> pl.DataFrame:
    """Lemmatize every token in the frame's ``corpus`` column.

    Args:
        input_frame: Frame holding a ``corpus`` column of token lists.

    Returns:
        A frame in which ``corpus`` is replaced by the lemmatized token lists;
        the other columns are passed through by ``with_columns``.
    """
    # One lemmatizer instance is shared by all rows.
    lemmatizer = WordNetLemmatizer()

    # NOTE(review): the original call was truncated; the column expression is
    # restored from the lambda that survived and from the fact that the
    # function is applied to a frame whose token lists live in "corpus".
    return input_frame.with_columns(
        pl.col("corpus").map_elements(
            lambda input_list: [lemmatizer.lemmatize(token) for token in input_list],
            return_dtype=pl.List(pl.Utf8),
        )
    )

# Rebind df to the frame returned by lemmatize.
df = lemmatize(df)
shape: (5, 4)
text toxic corpus
i64 str i64 list[str]
0 "Explanation Why the edits made under my username Hardcore Metallica Fan were reverted? They weren't … 0 ["explanation", "edits", … "retired"]
1 "D'aww! He matches this background colour I'm seemingly stuck with. Thanks. (talk) 21:51, January 11… 0 ["daww", "match", … "utc"]
2 "Hey man, I'm really not trying to edit war. It's just that this guy is constantly removing relevant … 0 ["hey", "man", … "info"]
3 "" More I can't make any real suggestions on improvement - I wondered if the section statistics shoul… 0 ["cant", "make", … "wikipediagoodarticlenominationstransport"]
4 "You, sir, are my hero. Any chance you remember what page that's on?" 0 ["sir", "hero", … "thats"]
# Keyword arguments for the TF-IDF vectorizer.
tf_idf_params = {"max_features": 10000, "analyzer": "word"}

tf_idf_vectorizer = TfidfVectorizer(**tf_idf_params)

# Stratified 75/25 split on the "toxic" label.
train, test = train_test_split(
    df, test_size=0.25, random_state=RANDOM_STATE, stratify=df["toxic"], shuffle=True
)
# The vectorizer must be fitted before it can transform: learn the vocabulary
# and IDF weights from the training split only, then reuse them for the test
# split. Each token list is cast to its string form before vectorizing.
train_features = tf_idf_vectorizer.fit_transform(train["corpus"].to_pandas().astype(str))
test_features = tf_idf_vectorizer.transform(test["corpus"].to_pandas().astype(str))
def conf_matrix(y_true: np.ndarray, pred: np.ndarray) -> Figure:
    """Draw a confusion matrix for the given labels and predictions.

    Args:
        y_true: Ground-truth labels.
        pred: Predicted labels.

    Returns:
        The 5x5 matplotlib figure holding the plot, titled "Confusion Matrix".
    """
    figure, axes = plt.subplots(figsize=(5, 5))
    # Plot onto our own axes, without the colour bar.
    ConfusionMatrixDisplay.from_predictions(
        y_true,
        pred,
        ax=axes,
        colorbar=False,
    )
    axes.set_title("Confusion Matrix")
    return figure
# Select the "TF_IDF" experiment for the runs below.
embeddings_experiment = mlflow.set_experiment("TF_IDF")
run_name = "logistic_regression"

# Train a logistic regression on the TF-IDF features and log its metrics,
# model and confusion matrix to a single MLflow run.
with mlflow.start_run(run_name=run_name) as run:
    lr_model_params = {
        "multi_class": "multinomial",
        "solver": "saga",
        "random_state": RANDOM_STATE,
    }

    model_lr = LogisticRegression(**lr_model_params).fit(train_features, train["toxic"])
    predicts = model_lr.predict(test_features)
    # ROC AUC is a ranking metric, so it is computed from the positive-class
    # probability rather than from hard 0/1 predictions.
    # Assumes "toxic" is a binary 0/1 label, making column 1 the positive class.
    toxic_scores = model_lr.predict_proba(test_features)[:, 1]

    metrics = {
        "accuracy": accuracy_score(test["toxic"], predicts),
        "recall": recall_score(test["toxic"], predicts),
        "precision": precision_score(test["toxic"], predicts),
        "roc_auc_score": roc_auc_score(test["toxic"], toxic_scores),
    }
    for metric_name, metric_value in metrics.items():
        mlflow.log_metric(metric_name, metric_value)

    mlflow.sklearn.log_model(sk_model=model_lr, artifact_path=f"mlflow/{run_name}/model")

    fig = conf_matrix(test["toxic"], predicts)
    mlflow.log_figure(figure=fig, artifact_file=f"{run_name}_confusion_matrix.png")

run_name = "catboost"

# Train a CatBoost classifier on the TF-IDF features and log its metrics,
# model and confusion matrix to a single MLflow run.
with mlflow.start_run(run_name=run_name) as run:
    catboost_model_params = {
        "random_state": 42,
        "learning_rate": 0.001,
        "auto_class_weights": "Balanced",
        "verbose": False,
        "n_estimators": 500,
    }

    # Labels are converted to NumPy arrays once and reused below.
    train_labels = np.array(train["toxic"])
    test_labels = np.array(test["toxic"])

    model_catboost = CatBoostClassifier(**catboost_model_params).fit(train_features, train_labels)
    predicts = model_catboost.predict(test_features)
    # ROC AUC is a ranking metric, so it is computed from the positive-class
    # probability rather than from hard 0/1 predictions.
    # Assumes "toxic" is a binary 0/1 label, making column 1 the positive class.
    toxic_scores = model_catboost.predict_proba(test_features)[:, 1]

    metrics = {
        "accuracy": accuracy_score(test_labels, predicts),
        "recall": recall_score(test_labels, predicts),
        "precision": precision_score(test_labels, predicts),
        "roc_auc_score": roc_auc_score(test_labels, toxic_scores),
    }
    for metric_name, metric_value in metrics.items():
        mlflow.log_metric(metric_name, metric_value)

    mlflow.catboost.log_model(cb_model=model_catboost, artifact_path=f"mlflow/{run_name}/model")

    fig = conf_matrix(test["toxic"], predicts)
    mlflow.log_figure(figure=fig, artifact_file=f"{run_name}_confusion_matrix.png")
