# Embeddings-UBalt / emb_sim.py
# Author: AlanFeder
# Commit 92802a1 (verified): "embed all at once"
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.figure import Figure
from openai import AuthenticationError, OpenAI, RateLimitError
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
# Load settings from a .env file (python-dotenv) into the process environment,
# then create the one OpenAI client that every embedding request in this
# module goes through.
load_dotenv()
openai_api_key = os.environ.get("OPENAI_API_KEY")
oai_client = OpenAI(api_key=openai_api_key)
def calculate_embeddings(words):
    """Embed each string in *words* with a single batched OpenAI request.

    Args:
        words: list of strings to embed.

    Returns:
        A list of embedding vectors (lists of floats), aligned with *words*.
        An empty *words* returns ``[]`` without calling the API.
    """
    if not words:
        # Nothing to embed; avoid sending an empty batch to the API.
        return []
    response = oai_client.embeddings.create(
        input=words, model="text-embedding-3-small"
    )
    # Each returned item records the position of its input string. Sort on it
    # so the output lines up with *words* whatever order the items arrive in.
    ordered = sorted(response.data, key=lambda item: item.index)
    return [item.embedding for item in ordered]
def process_array(arr):
    """Compact a square similarity matrix into its flipped strict upper triangle.

    For an ``n x n`` input the result is an ``(n-1) x (n-1)`` masked array.
    Row ``i`` belongs to item ``i`` (items ``0 .. n-2``). Column ``j`` belongs
    to item ``n-1-j`` (items ``n-1`` down to ``1``). Cell ``(i, j)`` holds
    ``arr[i, n-1-j]``. Cells with ``i + j >= n-1`` correspond to the diagonal
    or the lower triangle, have no pair, and are masked.

    Args:
        arr: square 2-D array, e.g. the output of ``cosine_similarity``.

    Returns:
        numpy.ma.MaskedArray of shape ``(n-1, n-1)`` (``(0, 0)`` when n <= 1).

    Raises:
        ValueError: if ``arr`` is not a square 2-D array.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
        raise ValueError("Input must be a square array")

    # Zero the diagonal and lower triangle, then mirror left-right so the
    # remaining pairs pack into the top-left corner.
    flipped = np.fliplr(np.triu(arr, k=1))
    # After the flip, the last row and the last column contain no pairs.
    result = flipped[:-1, :-1]

    # Mask by position rather than by value, so a genuine similarity of
    # exactly 0.0 (e.g. from an all-zero vector) stays visible.
    size = result.shape[0]
    idx = np.arange(size)
    no_pair = np.add.outer(idx, idx) >= size
    return np.ma.masked_array(result, mask=no_pair)
def plot_heatmap(masked_result, l1: list[str]):
    """Render the compact similarity matrix as an annotated heatmap.

    Args:
        masked_result: square (masked) array as produced by ``process_array``
            from the similarities of the words in *l1*. Row ``i`` is
            ``l1[i]`` and column ``j`` is ``l1[len(l1) - 1 - j]``. Masked
            cells are drawn white and left unlabeled.
        l1: the words, in the order they were embedded.

    Returns:
        matplotlib.figure.Figure containing the heatmap and a colorbar.
    """
    n = masked_result.shape[0]
    # Boolean mask with the same shape as the data. It is all False when the
    # input has no mask at all.
    hidden = np.ma.getmaskarray(masked_result)

    # Build the figure without pyplot. Pyplot keeps every figure it creates
    # referenced until plt.close(), so a long-running app would accumulate one
    # per request. Its implicit "current figure" state is also process-global.
    fig = Figure(figsize=(12, 10))  # large canvas keeps long labels legible
    ax = fig.subplots()

    # Diverging colormap over [-1, 1]: dark red -> light gray -> blue.
    cmap = LinearSegmentedColormap.from_list(
        "custom", ["darkred", "lightgray", "dodgerblue"], N=100
    )
    cmap.set_bad("white")  # masked cells (no word pair) render blank

    im = ax.imshow(masked_result, cmap=cmap, vmin=-1, vmax=1)

    # Write each visible similarity value into its cell.
    for i in range(n):
        for j in range(n):
            if not hidden[i, j]:
                ax.text(
                    j,
                    i,
                    f"{masked_result[i, j]:.2f}",
                    ha="center",
                    va="center",
                    color="black",
                )

    # Rows are the words l1[:-1]. Columns are l1[1:] in reverse order, which
    # matches the left-right flip done by process_array.
    ax.set_yticks(range(n))
    ax.set_yticklabels(l1[:-1])
    ax.set_xticks(range(n))
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")
    ax.set_xticklabels(
        list(reversed(l1[1:])), rotation=45, ha="left", rotation_mode="anchor"
    )

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_ticks([-1, 0, 1])
    cbar.set_ticklabels(["-1", "0", "1"])

    ax.set_title("Correlation Heatmap", pad=20)
    fig.tight_layout()
    return fig
def plot_pca(embeddings, words):
    """Scatter-plot the embeddings projected onto their first two principal components.

    Args:
        embeddings: one embedding vector per word (2-D array-like). It needs
            at least two rows, since PCA cannot fit two components otherwise.
        words: point labels, in the same order as *embeddings*.

    Returns:
        matplotlib.figure.Figure with one labeled point per word.
    """
    coords = PCA(n_components=2).fit_transform(embeddings)

    # Exactly one figure, created without pyplot so that pyplot's global
    # figure registry does not keep it alive after the request.
    fig = Figure(figsize=(10, 8))
    ax = fig.subplots()
    ax.scatter(coords[:, 0], coords[:, 1])
    for word, (x, y) in zip(words, coords):
        ax.annotate(word, (x, y))
    ax.set_title("PCA of Word Embeddings")
    ax.set_xlabel("First Principal Component")
    ax.set_ylabel("Second Principal Component")
    fig.tight_layout()
    return fig
def word_similarity_heatmap(input_text):
    """Gradio handler: embed comma-separated words and plot their similarities.

    Args:
        input_text: text from the input Textbox, with words separated by commas.

    Returns:
        Tuple ``(heatmap_figure, pca_figure)`` for the two ``gr.Plot`` outputs.

    Raises:
        gr.Error: if fewer than two non-empty words were given, or if the
            OpenAI API rejected the API key or hit its rate limit. Gradio
            shows the message to the user.
        Exception: any other failure is printed and then re-raised unchanged.
    """
    # Skip empty entries so "a, , b" or a trailing comma does not send a
    # blank string to the embeddings API.
    words = [w.strip() for w in (input_text or "").split(",") if w.strip()]
    if len(words) < 2:
        # The interface declares two Plot outputs, so returning a single string
        # does not match them. gr.Error surfaces the message in the UI instead.
        raise gr.Error("Please enter at least two words.")

    try:
        embeddings = calculate_embeddings(words)
        similarities = cosine_similarity(embeddings)
        heatmap = plot_heatmap(process_array(similarities), words)
        pca_plot = plot_pca(embeddings, words)
    except AuthenticationError as e:
        message = "OpenAI API key is invalid. Please check your API key."
        print(message)
        raise gr.Error(message) from e
    except RateLimitError as e:
        message = "OpenAI API rate limit exceeded. Please try again later."
        print(message)
        raise gr.Error(message) from e
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        raise
    return heatmap, pca_plot
# Gradio UI: one comma-separated text box in, two plots out.
_word_input = gr.Textbox(lines=2, placeholder="Enter words separated by commas")
_plot_outputs = [
    gr.Plot(label="Similarity Heatmap"),
    gr.Plot(label="PCA Plot"),
]
iface = gr.Interface(
    fn=word_similarity_heatmap,
    inputs=_word_input,
    outputs=_plot_outputs,
    title="Word Similarity Heatmap and PCA Plot using OpenAI Embeddings",
    description=(
        "Enter a list of words separated by commas. "
        "The app will calculate the cosine similarity between their OpenAI "
        "embeddings, display a compact heatmap of the upper triangle "
        "similarities, and show a PCA plot of the embeddings."
    ),
)

# Start the web server at import time. share=True also requests a public
# Gradio share link.
iface.launch(share=True)