Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from sklearn.decomposition import PCA | |
from matplotlib.colors import LinearSegmentedColormap | |
from sklearn.metrics.pairwise import cosine_similarity | |
from openai import OpenAI, AuthenticationError, RateLimitError | |
from dotenv import load_dotenv | |
import os | |
# Read variables from a .env file (python-dotenv) into the process environment.
load_dotenv()

# API key for the OpenAI client; None when OPENAI_API_KEY is not set.
openai_api_key = os.environ.get("OPENAI_API_KEY")

# Single shared client used by calculate_embeddings for every request.
oai_client = OpenAI(api_key=openai_api_key)
def calculate_embeddings(words, model="text-embedding-3-small"):
    """Return one embedding vector per entry of *words*, in input order.

    Args:
        words: List of strings to embed. They are sent to the OpenAI
            embeddings endpoint in a single request via ``oai_client``.
        model: Name of the embedding model to request. Defaults to the
            model this app was originally written for.

    Returns:
        A list of embedding vectors as returned by the API. Element ``i`` is
        the embedding of ``words[i]``. An empty *words* yields ``[]``
        without contacting the API.
    """
    if not words:
        return []
    response = oai_client.embeddings.create(input=words, model=model)
    # Each returned item records the position of the input it belongs to.
    # Ordering by it keeps vectors aligned with ``words`` explicitly, rather
    # than relying on the response order.
    ordered = sorted(response.data, key=lambda item: item.index)
    return [item.embedding for item in ordered]
def process_array(arr):
    """Rearrange a square similarity matrix into a compact triangular grid.

    The strict upper triangle of *arr* (cells ``arr[i, j]`` with ``j > i``)
    is mirrored left-to-right and trimmed. Row ``i`` of the result then
    belongs to item ``i`` (``i`` in ``0..n-2``), and column ``c`` belongs to
    item ``n-1-c``. Every cell outside that triangle is masked.

    Args:
        arr: Square 2-D array-like of shape ``(n, n)``.

    Returns:
        A ``numpy.ma.MaskedArray`` of shape ``(n-1, n-1)``. It satisfies
        ``result[i, c] == arr[i, n-1-c]`` for ``c <= n-2-i``, and all other
        cells are masked.

    Raises:
        ValueError: If *arr* is not a square 2-D array.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
        raise ValueError("Input must be a square array")
    n = arr.shape[0]

    # Mask by position, not by value: a genuine similarity of exactly 0.0
    # inside the triangle must remain visible.
    keep = np.triu(np.ones((n, n), dtype=bool), k=1)
    values = np.where(keep, arr, 0)

    # Mirror horizontally, then drop the last row and last column. The last
    # row has no cells above the diagonal, and the last column is the
    # mirrored first column, which is also empty.
    values = np.fliplr(values)[:-1, :-1]
    keep = np.fliplr(keep)[:-1, :-1]
    return np.ma.masked_array(values, mask=~keep)
def plot_heatmap(masked_result, l1: list[str]):
    """Render the compact similarity grid from ``process_array`` as a heatmap.

    Args:
        masked_result: ``(n, n)`` (masked) array. Row ``i`` belongs to
            ``l1[i]`` and column ``c`` belongs to ``l1[n - c]``. Masked cells
            are drawn white and left unannotated.
        l1: The ``n + 1`` labels (words) the grid was built from, in their
            original order.

    Returns:
        A matplotlib ``Figure`` containing the annotated heatmap.
    """
    # Build the Figure directly rather than through pyplot. pyplot keeps every
    # figure it creates registered in global state until it is closed, which
    # leaks one figure per request in a long-running app.
    from matplotlib.figure import Figure

    n, _ = masked_result.shape
    fig = Figure(figsize=(12, 10))  # large canvas so labels stay readable
    ax = fig.subplots()

    # Diverging colormap: -1 -> dark red, 0 -> light gray, +1 -> blue.
    cmap = LinearSegmentedColormap.from_list(
        "custom", ["darkred", "lightgray", "dodgerblue"], N=100
    )
    cmap.set_bad("white")  # masked cells render white

    im = ax.imshow(masked_result, cmap=cmap, vmin=-1, vmax=1)

    # Print every visible similarity value inside its cell.
    for i in range(n):
        for j in range(n):
            value = masked_result[i, j]
            if not np.ma.is_masked(value):
                ax.text(j, i, f"{value:.2f}", ha="center", va="center", color="black")

    # Row i <-> l1[i]; column c <-> l1[n - c], i.e. l1[1:] reversed.
    ax.set_yticks(range(n))
    ax.set_yticklabels(l1[:-1])
    ax.set_xticks(range(n))
    ax.set_xticklabels(list(reversed(l1[1:])))

    # Column labels along the top, rotated for readability.
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")
    for label in ax.get_xticklabels():
        label.set(rotation=45, horizontalalignment="left", rotation_mode="anchor")

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_ticks([-1, 0, 1])
    cbar.set_ticklabels(["-1", "0", "1"])

    ax.set_title("Correlation Heatmap", pad=20)
    fig.tight_layout()
    return fig
def plot_pca(embeddings, words):
    """Scatter-plot embeddings projected onto their first two principal components.

    Args:
        embeddings: One embedding vector per word, as produced by
            ``calculate_embeddings``. A two-component PCA needs at least
            two rows.
        words: Labels annotated next to each point, aligned with *embeddings*.

    Returns:
        A matplotlib ``Figure`` with the labelled 2-D projection.
    """
    # A standalone Figure (not pyplot) leaves nothing registered in pyplot's
    # global figure list, so no figure is leaked per request.
    from matplotlib.figure import Figure

    embeddings_2d = PCA(n_components=2).fit_transform(embeddings)

    fig = Figure(figsize=(10, 8))
    ax = fig.subplots()
    ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1])
    for i, word in enumerate(words):
        ax.annotate(word, (embeddings_2d[i, 0], embeddings_2d[i, 1]))

    ax.set_title("PCA of Word Embeddings")
    ax.set_xlabel("First Principal Component")
    ax.set_ylabel("Second Principal Component")
    fig.tight_layout()
    return fig
def word_similarity_heatmap(input_text):
    """Gradio handler: embed comma-separated words and plot their similarities.

    Args:
        input_text: Raw textbox contents, e.g. ``"cat, dog, car"``.

    Returns:
        A ``(heatmap_figure, pca_figure)`` tuple matching the interface's two
        ``gr.Plot`` outputs.

    Raises:
        gr.Error: If fewer than two non-blank words were entered. Gradio
            shows the message to the user.
        Exception: Any error from the embedding or plotting pipeline is
            printed to stdout and re-raised unchanged.
    """
    # Skip blank entries (from "a,,b" or a trailing comma). The OpenAI
    # embeddings endpoint does not accept empty strings as input.
    parts = [part.strip() for part in (input_text or "").split(",")]
    words = [word for word in parts if word]
    if len(words) < 2:
        # A single returned string cannot fill two Plot outputs. gr.Error
        # shows this message in the UI instead of a generic failure.
        raise gr.Error("Please enter at least two words.")
    try:
        embeddings = calculate_embeddings(words)
        similarities = cosine_similarity(embeddings)
        heatmap = plot_heatmap(process_array(similarities), words)
        pca_plot = plot_pca(embeddings, words)
    except AuthenticationError:
        print("OpenAI API key is invalid. Please check your API key.")
        raise
    except RateLimitError:
        print("OpenAI API rate limit exceeded. Please try again later.")
        raise
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        raise
    return heatmap, pca_plot
# Single free-text input: the words to compare, separated by commas.
word_input = gr.Textbox(lines=2, placeholder="Enter words separated by commas")

# One Plot per value returned by word_similarity_heatmap: (heatmap, pca_plot).
plot_outputs = [
    gr.Plot(label="Similarity Heatmap"),
    gr.Plot(label="PCA Plot"),
]

iface = gr.Interface(
    fn=word_similarity_heatmap,
    inputs=word_input,
    outputs=plot_outputs,
    title="Word Similarity Heatmap and PCA Plot using OpenAI Embeddings",
    description="Enter a list of words separated by commas. The app will calculate the cosine similarity between their OpenAI embeddings, display a compact heatmap of the upper triangle similarities, and show a PCA plot of the embeddings.",
)

# Start the Gradio server; share=True asks Gradio for a public link.
iface.launch(share=True)