# Author: mobrown
# Commit 7702cd1 (verified): "Update app.py"
import numpy as np
from sklearn.decomposition import PCA
import gensim.downloader as api
import gradio as gr
import plotly.graph_objects as go
# Load the pretrained Word2Vec model ("word2vec-google-news-300") through
# gensim's downloader. This runs once at import time, and the resulting object
# is shared by every request handler below: they use its `key_to_index`
# vocabulary map, `model[word]` vector lookup and `most_similar`.
# NOTE(review): the download/load happens before the UI is built, so app
# start-up waits on it — confirm this is acceptable for the deployment.
model = api.load("word2vec-google-news-300")
def gensim_analogy(model, word1, word2, word3):
    """Answer the analogy "word1 is to word2 as word3 is to ?".

    Queries ``model.most_similar`` with ``word2`` and ``word3`` as positive
    terms and ``word1`` as the negative term, keeping only the best hit.

    Args:
        model: Object exposing ``most_similar(positive=, negative=, topn=)``.
        word1, word2, word3: The three analogy terms.

    Returns:
        The top-ranked word. If the lookup raises ``KeyError`` (presumably an
        out-of-vocabulary input), the text of that error is returned instead.
    """
    try:
        nearest = model.most_similar(positive=[word2, word3], negative=[word1], topn=1)
        top_word = nearest[0][0]
    except KeyError as missing:
        return str(missing)
    return top_word
def plot_words_plotly(model, words):
    """Project the vectors of ``words`` to 2-D with PCA and scatter-plot them.

    Words missing from ``model.key_to_index`` are skipped. The word list and
    the vector list are filtered together so every plotted point keeps its
    own label. Previously, labels were zipped against the unfiltered list, so
    a single unknown word shifted every later label onto the wrong point.

    Args:
        model: Word-vector lookup exposing ``key_to_index`` and ``model[word]``.
        words: Words to plot, in display order.

    Returns:
        A plotly ``Figure`` with one labelled marker trace per known word.
        If fewer than two words are known, the figure has no traces.
        A two-component PCA needs at least two samples, so attempting it
        would raise instead of rendering.
    """
    known_words = [word for word in words if word in model.key_to_index]

    fig = go.Figure()
    if len(known_words) >= 2:
        vectors = np.array([model[word] for word in known_words])
        # Reduce dimensions to 2D for plotting
        vectors_2d = PCA(n_components=2).fit_transform(vectors)
        # One trace per word so each gets its own legend entry
        for word, vec in zip(known_words, vectors_2d):
            fig.add_trace(go.Scatter(x=[vec[0]], y=[vec[1]],
                                     text=[word], mode='markers+text',
                                     textposition="bottom center",
                                     name=word))

    fig.update_layout(title="Visualization of Word Vectors",
                      xaxis_title="PCA 1",
                      yaxis_title="PCA 2",
                      showlegend=True,
                      width=600,  # Adjust width as needed
                      height=400)  # Adjust height as needed
    return fig
def _split_words(text):
    """Split a comma-separated string into whitespace-stripped words.

    Accepts "a,b,c", "a, b, c" and "a ,b , c" alike, returning ["a", "b", "c"].
    """
    return [word.strip() for word in text.split(",")]


def gradio_interface(choice, custom_input):
    """Gradio submit handler: solve the selected analogy and visualise it.

    Args:
        choice: The selected radio entry, either a preset "w1, w2, w3" triplet
            or "Custom". It may be empty/None when nothing is selected.
        custom_input: Comma-separated words, used only when ``choice`` is
            "Custom". Spaces around the commas are optional.

    Returns:
        ``(word4, figure, vector_text)`` on success. ``word4`` is the analogy
        result, or the error text from ``gensim_analogy``. ``vector_text`` is
        the word's vector rounded to 2 decimals, or a "not available" notice.
        On invalid input, returns ``(message, None, {"error": "Invalid input"})``.
    """
    if choice == "Custom":
        words = _split_words(custom_input) if custom_input else []
        # Each slot must hold a non-empty word. Otherwise "a, , b" would slip
        # through and fail later as an unknown-vocabulary lookup.
        if len(words) != 3 or not all(words):
            return "Invalid input. Please enter exactly three words, separated by commas.", None, {
                "error": "Invalid input"}
    else:
        if not choice:
            return "Invalid input. Please select or enter words.", None, {
                "error": "Invalid input"}
        words = _split_words(choice)

    word1, word2, word3 = words
    word4 = gensim_analogy(model, word1, word2, word3)
    plot_fig = plot_words_plotly(model, [word1, word2, word3, word4])

    # word4 may be an error message rather than a vocabulary word.
    if word4 in model.key_to_index:
        vector = model[word4]
        vector_display = f"{word4}: {np.round(vector, 2).tolist()}"
    else:
        vector_display = "Vector not available for the resulting word"
    return word4, plot_fig, vector_display
# Options shown in the radio button. Each preset is a comma-separated triplet
# read as "word1 is to word2 as word3 is to ?" (e.g. man : king :: woman : ?).
# The final "Custom" entry tells gradio_interface to read the free-text box.
choices = [
"man, king, woman",
"Paris, France, London",
"strong, stronger, weak",
"pork, pig, beef",
"Custom"
]
def clear_inputs():
    """Return reset values for the Clear button.

    The four text-like components (radio, custom words, output word,
    vectorization text) get an empty string. The plot gets ``None``.
    """
    empty_text = ""
    return (empty_text,) * 4 + (None,)
# Define the layout using Rows and Columns.
# Components appear in the order they are created inside each `with` context.
# NOTE(review): the nesting of the output components under the Row/Column
# contexts should be confirmed against the intended layout.
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            gr.Markdown("# Word Analogy and Vector Visualization")
            gr.Markdown(
                "Select a predefined triplet of words or choose 'Custom' and enter your own (comma-separated) to find a fourth word by analogy, and see their vectors plotted with Plotly.")
            # Input controls: preset/custom selector plus the free-text box
            radio = gr.Radio(choices=choices, label="Choose predefined words or enter custom words")
            custom_words = gr.Textbox(
                label="Custom words (comma-separated, required for custom choice; use only if 'Custom' is selected)",
                placeholder="Enter 3 words separated by commas")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit")
            output_word = gr.Textbox(label="Output Word")
        word_plot = gr.Plot(label="Word Vectors Visualization")
    with gr.Row():
        word_vectorization = gr.Textbox(label="Vectorization of the Output Word", lines=4, max_lines=4)
    # Event wiring is registered inside the Blocks context. The outputs lists
    # must match the order of the tuples returned by clear_inputs /
    # gradio_interface.
    clear_btn.click(fn=clear_inputs, inputs=None,
                    outputs=[radio, custom_words, output_word, word_vectorization, word_plot])
    submit_btn.click(fn=gradio_interface, inputs=[radio, custom_words],
                     outputs=[output_word, word_plot, word_vectorization])

# NOTE(review): share=True asks Gradio for a public share link; confirm this is
# wanted for the deployment target (it may be ignored when hosted).
iface.launch(share=True)