Embeddings are a way to represent text in a numerical format as vectors. They are used in a variety of applications, including search and retrieval, clustering, labeling data and anomaly detection.
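To make the idea concrete, here is a minimal sketch of how vector similarity stands in for text similarity. The three-dimensional vectors below are hand-made toy values chosen for illustration, not output from a real embedding model, and the helper name cosine_similarity is our own.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Hand-made toy vectors (NOT real model output), for illustration only
toy_embeddings = {
    "Denver Broncos": [0.9, 0.1, 0.0],
    "Kansas City Chiefs": [0.8, 0.2, 0.1],
    "Photosynthesis": [0.0, 0.2, 0.9],
}

# Score every text against the "Denver Broncos" vector
query = toy_embeddings["Denver Broncos"]
scores = {text: cosine_similarity(query, vec) for text, vec in toy_embeddings.items()}

for text, score in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{score:.3f}  {text}")
```

Texts whose vectors point in a similar direction score close to 1, while unrelated texts score close to 0. Real models produce vectors with hundreds or thousands of dimensions, but the comparison works the same way.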
Notebooks make it easy and fast to compute embeddings for a dataset on a GPU. If you want to follow along, you can check out this notebook and run it in your own Oxen.ai account. When running this example, try an A10 GPU with 4GB of memory and 4 CPU cores. This will allow us to compute over 1,000 embeddings per second 🔥
Setting Up The Interface
Marimo lets you define UI elements for the input repository, dataset, model name and number of rows to compute embeddings for. First, let's set up a simple form that allows us to kick off the embedding computation.
Use the following code in your first cell to set up the UI.
import marimo as mo
model_name_ui = mo.ui.text(value="BAAI/bge-large-en-v1.5", full_width=True)
oxen_repo_name = mo.ui.text(value="ox/Simple-Wikipedia-50k", full_width=True)
oxen_dataset_name = mo.ui.text(value="train_0_50000.parquet", full_width=True)
num_rows = mo.ui.number(value=10000)
run_form = mo.md(
"""
Model Name
{model_name}
Repo Name
{oxen_repo_name}
File Name
{oxen_dataset_name}
Num Rows
{num_rows}
"""
).batch(
    oxen_repo_name=oxen_repo_name,
    oxen_dataset_name=oxen_dataset_name,
    model_name=model_name_ui,
    num_rows=num_rows
).form(
    submit_button_label="Compute",
    bordered=False,
    show_clear_button=True,
    clear_button_label="Reset"
)
run_form
To wait for the button to be clicked, use the mo.stop function and check if run_form.value is None.
# If the button is not pressed, stop execution
mo.stop(
run_form.value is None
)
Then download the data using the values from the form and the RemoteRepo class.
from oxen import RemoteRepo
import pandas as pd
repo = RemoteRepo(oxen_repo_name.value)
repo.download(oxen_dataset_name.value, revision="main")
df = pd.read_parquet(oxen_dataset_name.value)
Compute Embeddings
This example will use the sentence_transformers library to compute the embeddings, with BAAI/bge-large-en-v1.5 as the default model. Find more information about the model here.
from sentence_transformers import SentenceTransformer
model_name = model_name_ui.value
print(f"Loading: {model_name}")
model = SentenceTransformer(model_name, device="cuda")
print(f"Model Loaded: {model_name}")
Now we can compute the embeddings for the dataset. We will compute them in batches to take full advantage of the GPU. In this example, we are just computing the embeddings for the title column, but you can compute the embeddings for any text column in the dataset. mo.status.progress_bar is used to show a progress bar in the UI as we compute the embeddings.
You should see the model computing over 1,000 embeddings per second 🔥
# How many embeddings to compute at once
batch_size = 128
# Determine how many rows you want to process
rows_to_process = num_rows.value
# Copy the data frame
result_df = df.iloc[:rows_to_process].copy()
computed_embeddings = []
# Process the dataframe in batches
with mo.status.progress_bar(total=len(result_df)) as bar:
    for i in range(0, len(result_df), batch_size):
        if i % 10 == 0:
            print(f"Computed {i} embeddings")
        # Get the current batch
        batch = result_df.iloc[i:i+batch_size]
        # Extract texts from the batch
        texts = batch['title'].tolist()
        # Compute embeddings for the batch
        batch_embeddings = model.encode(texts, normalize_embeddings=True)
        # Add the batch embeddings to the overall list
        computed_embeddings.extend(batch_embeddings)
        # Advance by the actual batch length (the last batch may be smaller)
        bar.update(increment=len(batch))
# Add embeddings to the data frame
result_df['embedding'] = computed_embeddings
The embeddings will now be in the result_df data frame in a new column called embedding.
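The batching pattern can be tried without a GPU. The sketch below is an illustration under stated assumptions: fake_encode is a hypothetical stand-in for model.encode, and iter_batches is a helper name of our own. It shows that the final batch is usually smaller than batch_size, so progress is best counted by the actual batch length, and how throughput in embeddings per second might be measured.

```python
import math
import time

def iter_batches(items, batch_size):
    """Yield consecutive slices of `items`, each at most `batch_size` long."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

def fake_encode(texts):
    """Hypothetical stand-in for model.encode: one tiny vector per text."""
    return [[float(len(t)), 1.0] for t in texts]

texts = [f"title {i}" for i in range(1000)]
batch_size = 128

embeddings = []
processed = 0
start_time = time.perf_counter()
for batch in iter_batches(texts, batch_size):
    embeddings.extend(fake_encode(batch))
    processed += len(batch)  # count real rows; the last batch may be smaller
elapsed = time.perf_counter() - start_time

num_batches = math.ceil(len(texts) / batch_size)
throughput = processed / elapsed if elapsed > 0 else float("inf")
print(f"{processed} texts in {num_batches} batches, {throughput:,.0f} embeddings/sec")
```

With 1,000 texts and a batch size of 128, the loop runs 8 times and the last batch holds only 104 texts. The throughput printed here reflects the fake encoder, not a real model.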
Save the Embeddings
Once you have computed the embeddings, save them to your Oxen.ai repository to share with your team. Oxen.ai will version the embeddings and allow you to track changes so that you can try out different models and configurations without worrying about losing your previous work.
def save_embeddings(df, username="YOUR_USERNAME", repo_name="YOUR_REPO_NAME", filename="embeddings.parquet", branch="main"):
    # Connect to the remote repo
    repo = RemoteRepo(f"{username}/{repo_name}")
    # Checkout the branch
    repo.create_checkout_branch(branch)
    # Save data to disk
    df.to_parquet(filename, index=False)
    # Check if the file exists or has changed on the remote
    if not repo.file_exists(filename) or repo.file_has_changes(filename):
        # Stage/upload data to remote repository
        repo.add(filename)
        # Commit data with a message
        repo.commit(f"Adding {filename}")
    else:
        print("File has no changes")
save_embeddings(result_df, username="YOUR_USERNAME", repo_name="YOUR_REPO_NAME", filename="embeddings.parquet", branch="embeddings")
Search Nearest Neighbors
To check how well the embeddings encode the text, let's build a little search tool. We will use cosine_similarity from sklearn to build a simple nearest neighbor search.
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
def embedding_similarity(df, query_embedding, text_column='text', embedding_column='embedding', top_k=5):
    # Make sure query_embedding is a 2D array for sklearn's cosine_similarity
    if len(query_embedding.shape) == 1:
        query_embedding = query_embedding.reshape(1, -1)
    # Stack all embeddings from the DataFrame into a 2D array
    embeddings_matrix = np.vstack(df[embedding_column].values)
    # Calculate cosine similarity between query and all embeddings
    similarities = cosine_similarity(query_embedding, embeddings_matrix).flatten()
    # Create a results DataFrame with similarities
    results_df = df.copy()
    results_df['similarity_score'] = similarities
    # Sort by similarity score (descending) and get top_k results
    results_df = results_df.sort_values('similarity_score', ascending=False).head(top_k)
    # Keep only text and similarity score for cleaner output
    return results_df[[text_column, 'similarity_score']]
Now we can use the embedding_similarity function to search for the nearest neighbors of a query.
search_term_embedding = model.encode(search_term_ui.value, normalize_embeddings=True)
embedding_similarity(result_df, search_term_embedding, text_column='title')
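Because the embeddings are computed with normalize_embeddings=True, every vector has unit length, and for unit-length vectors cosine similarity reduces to a plain dot product. The sketch below shows that shortcut with NumPy on random toy vectors, which are not real model output. The function name top_k_by_dot_product is our own and is not part of any library.

```python
import numpy as np

def top_k_by_dot_product(embeddings_matrix, query_embedding, top_k=5):
    """Return (indices, scores) of the top_k rows most similar to the query.

    Assumes every row and the query are already unit-length, so the
    dot product equals the cosine similarity.
    """
    scores = embeddings_matrix @ query_embedding
    top_indices = np.argsort(-scores)[:top_k]
    return top_indices, scores[top_indices]

# Toy data (not real model output): random vectors normalized to unit length
rng = np.random.default_rng(0)
matrix = rng.normal(size=(100, 8))
matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)

# Use one of the rows as the query, so it should be its own nearest neighbor
query = matrix[42]

indices, scores = top_k_by_dot_product(matrix, query, top_k=3)
print(indices, scores)
```

With the data frame from this tutorial, the matrix could be built with np.vstack(result_df['embedding'].values), as embedding_similarity does. The shortcut only holds when the vectors are normalized; otherwise stick with cosine_similarity.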
Build a text input so that we can enter any term we want and see similar titles.
search_term_ui = mo.ui.text(value="Denver Broncos", full_width=True)
mo.md(f"""
Enter any term to see its neighbors
{search_term_ui}
""")