from bigdl.orca import init_orca_context, stop_orca_context
import tensorflow as tf
import numpy as np
import pandas as pd

init_orca_context(cluster_mode="local", cores=4, memory="10g")

df = pd.read_csv(
    "/home/ubuntu/Downloads/spark-20231031T155920Z-001/spark/mail_data.csv",
    index_col=False
)
print(df.head())

df["Category"] = df["Category"].map({"ham": 0, "spam": 1})


from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    df.Message, df.Category, test_size=0.2, random_state=10
)
print(X_train.shape, X_test.shape)


from tensorflow.keras.preprocessing.text import Tokenizer

# Vectorize the text: turn each message into a sequence of integer word indices
tokenizer = Tokenizer(lower=True, oov_token="<OOV>")
tokenizer.fit_on_texts(X_train)

X_train = tokenizer.texts_to_sequences(X_train)
X_test = tokenizer.texts_to_sequences(X_test)

VOCAB_SIZE = len(tokenizer.word_index) + 1  # +1 for the reserved padding index 0
MAXLEN = max([len(x) for x in X_train])

print("Vocabulary size:", VOCAB_SIZE)
print("Maximum length:", MAXLEN)


from tensorflow.keras.preprocessing.sequence import pad_sequences

X_train = pad_sequences(X_train, maxlen=MAXLEN, padding="pre")
X_test = pad_sequences(X_test, maxlen=MAXLEN, padding="pre")
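
# Sanity check (not in the original script): both splits are now dense
# integer matrices of shape (n_samples, MAXLEN).
print("Padded shapes:", X_train.shape, X_test.shape)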


from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout

# Hyperparameters
EMBEDDING_SIZE = 128
BATCH_SIZE = 320
NUM_EPOCHS = 5


# Define the model architecture
def model_creator(config):
    model = Sequential(
        [
            Embedding(VOCAB_SIZE, EMBEDDING_SIZE, input_length=MAXLEN, mask_zero=True),
            LSTM(256, recurrent_dropout=0.2),
            Dropout(0.2),
            Dense(1, activation="sigmoid"),  # Spam or Ham
        ]
    )

    # Compile model
    model.compile(loss="binary_crossentropy", optimizer="nadam", metrics=["accuracy"])
    return model
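
# Local smoke test (a sketch, not in the original script): build the model
# once on the driver and print its summary before handing model_creator to Orca.
model_creator({}).summary()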


from bigdl.orca.learn.tf2 import Estimator

# Orca's TF2 Estimator takes creator functions: a data creator receives the
# config dict and a batch size and must return a tf.data.Dataset.
def train_data_creator(config, batch_size):
    ds = tf.data.Dataset.from_tensor_slices((X_train, y_train.to_numpy().astype("float32")))
    return ds.shuffle(len(X_train)).batch(batch_size)

def val_data_creator(config, batch_size):
    ds = tf.data.Dataset.from_tensor_slices((X_test, y_test.to_numpy().astype("float32")))
    return ds.batch(batch_size)

est = Estimator.from_keras(model_creator=model_creator, workers_per_node=1)

stats = est.fit(train_data_creator, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE,
                validation_data=val_data_creator)
print(stats)

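# Follow-up sketch (not in the original script): score the held-out split with
# Estimator.evaluate, then stop the Orca context started at the top.
eval_stats = est.evaluate(val_data_creator, batch_size=BATCH_SIZE)
print(eval_stats)

stop_orca_context()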