Spaces:
Sleeping
Sleeping
FoodDesert
commited on
Commit
•
b4bf2a9
1
Parent(s):
503dc78
Upload 2 files
Browse files- app.py +81 -26
- tf_idf_files_418.joblib +3 -0
app.py
CHANGED
@@ -21,6 +21,7 @@ import os
|
|
21 |
import glob
|
22 |
import itertools
|
23 |
from itertools import islice
|
|
|
24 |
|
25 |
|
26 |
|
@@ -159,6 +160,26 @@ def remove_special_tags(original_string):
|
|
159 |
removed_tags = [tag for tag in tags if tag in special_tags]
|
160 |
return ", ".join(remaining_tags), removed_tags
|
161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
# Load the model and data once at startup
|
164 |
with h5py.File('complete_artist_data.hdf5', 'r') as f:
|
@@ -204,6 +225,24 @@ with open("word_rating_probabilities.csv", 'r', newline='', encoding='utf-8') as
|
|
204 |
nsfw_tags.add(word)
|
205 |
|
206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
sample_images_directory_path = 'sampleimages'
|
208 |
def generate_artist_image_tuples(top_artists, image_directory):
|
209 |
json_files = glob.glob(f'{image_directory}/*.json')
|
@@ -404,6 +443,7 @@ def construct_pseudo_vector(pseudo_doc_terms, idf_loaded, tag_to_row_loaded):
|
|
404 |
|
405 |
# Return the vector as a 2D array for compatibility with SVD transform
|
406 |
return pseudo_vector.reshape(1, -1)
|
|
|
407 |
|
408 |
def get_top_indices(reduced_pseudo_vector, reduced_matrix):
|
409 |
# Compute cosine similarities
|
@@ -415,35 +455,42 @@ def get_top_indices(reduced_pseudo_vector, reduced_matrix):
|
|
415 |
# Return the top N indices
|
416 |
return sorted_indices
|
417 |
|
|
|
418 |
def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags):
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
# Remaining part of the function
|
431 |
-
pseudo_vector = construct_pseudo_vector(pseudo_doc_terms, idf_loaded, tag_to_row_loaded)
|
432 |
-
reduced_pseudo_vector = svd_loaded.transform(pseudo_vector)
|
433 |
-
# Compute cosine similarities
|
434 |
-
similarities = cosine_similarity(reduced_pseudo_vector, reduced_matrix_loaded).flatten()
|
435 |
|
436 |
-
#
|
437 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
438 |
|
439 |
-
# Create the initial tag_similarity_dict
|
440 |
-
tag_similarity_dict = {list(tag_to_row_loaded.keys())[i]: similarities[i] for i in top_indices_reduced}
|
441 |
if not allow_nsfw_tags:
|
442 |
-
tag_similarity_dict = {tag:
|
443 |
|
|
|
|
|
|
|
444 |
sorted_tag_similarity_dict = OrderedDict(sorted(tag_similarity_dict.items(), key=lambda x: x[1], reverse=True))
|
|
|
|
|
|
|
|
|
445 |
|
446 |
-
return
|
447 |
|
448 |
|
449 |
def create_html_placeholder(title="", content="", placeholder_height=400, placeholder_width="100%"):
|
@@ -555,6 +602,7 @@ def build_tag_offsets_dicts(new_image_tags_with_positions):
|
|
555 |
# Modify the tag
|
556 |
modified_tag = tag_text.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip()
|
557 |
artist_matrix_tag = tag_text.replace('_', ' ').replace('\\(', '\(').replace('\\)', '\)').strip()
|
|
|
558 |
# Calculate the end position based on the original tag length
|
559 |
end_pos = start_pos + len(tag_text)
|
560 |
# Append the structured data for each tag
|
@@ -564,6 +612,7 @@ def build_tag_offsets_dicts(new_image_tags_with_positions):
|
|
564 |
"end_pos": end_pos,
|
565 |
"modified_tag": modified_tag,
|
566 |
"artist_matrix_tag": artist_matrix_tag,
|
|
|
567 |
"node_type": nodetype
|
568 |
})
|
569 |
return tag_data
|
@@ -619,8 +668,13 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_n
|
|
619 |
suggested_tags_html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
|
620 |
|
621 |
suggested_tags_html_content += "<h1>Suggested Tags</h1>" # Heading for the table
|
622 |
-
suggested_tags = get_tfidf_reduced_similar_tags([item["
|
623 |
-
|
|
|
|
|
|
|
|
|
|
|
624 |
topnsuggestions = list(islice(suggested_tags_filtered.items(), 100))
|
625 |
suggested_tags_html_content += create_html_tables_for_tags("Suggested Tag", topnsuggestions, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
|
626 |
|
@@ -658,8 +712,9 @@ with gr.Blocks(css=css) as app:
|
|
658 |
#gr.Image(label=" ", value=image_path, height=155, width=140)
|
659 |
#gr.HTML('<div style="text-align: center;"><img src={image_path} alt="Cute Mascot" style="max-height: 100px; background: transparent;"></div><br>')
|
660 |
#gr.HTML("<br>" * 2) # Adjust the number of line breaks ("<br>") as needed to push the button down
|
661 |
-
image_path = os.path.join('mascotimages', "transparentsquirrel.png")
|
662 |
-
|
|
|
663 |
gr.Image(value=img,show_label=False, show_download_button=False, show_share_button=False, height=200)
|
664 |
submit_button = gr.Button(variant="primary")
|
665 |
with gr.Row():
|
|
|
21 |
import glob
|
22 |
import itertools
|
23 |
from itertools import islice
|
24 |
+
from pathlib import Path
|
25 |
|
26 |
|
27 |
|
|
|
160 |
removed_tags = [tag for tag in tags if tag in special_tags]
|
161 |
return ", ".join(remaining_tags), removed_tags
|
162 |
|
163 |
+
|
164 |
+
# Define a function to load all necessary components
|
165 |
+
def load_model_components(file_path):
|
166 |
+
# Ensure the file path is a Path object for robust path handling
|
167 |
+
file_path = Path(file_path)
|
168 |
+
|
169 |
+
# Check if the file exists
|
170 |
+
if not file_path.is_file():
|
171 |
+
raise FileNotFoundError(f"The specified joblib file was not found: {file_path}")
|
172 |
+
|
173 |
+
# Load all the model components from the joblib file
|
174 |
+
model_components = joblib.load(file_path)
|
175 |
+
|
176 |
+
# Create a reverse mapping from row index to tag
|
177 |
+
if 'tag_to_row_index' in model_components:
|
178 |
+
model_components['row_to_tag'] = {idx: tag for tag, idx in model_components['tag_to_row_index'].items()}
|
179 |
+
|
180 |
+
return model_components
|
181 |
+
# Load all components at the start
|
182 |
+
tf_idf_components = load_model_components('tf_idf_files_418.joblib')
|
183 |
|
184 |
# Load the model and data once at startup
|
185 |
with h5py.File('complete_artist_data.hdf5', 'r') as f:
|
|
|
225 |
nsfw_tags.add(word)
|
226 |
|
227 |
|
228 |
+
# Read the set of valid artists into memory.
|
229 |
+
artist_set = set()
|
230 |
+
with open("fluffyrock_3m.csv", 'r', newline='', encoding='utf-8') as csvfile:
|
231 |
+
"""
|
232 |
+
Load artist names from a CSV file and store them in the global set.
|
233 |
+
Artist tags start with 'by_' and the prefix will be removed.
|
234 |
+
"""
|
235 |
+
reader = csv.reader(csvfile)
|
236 |
+
for row in reader:
|
237 |
+
tag_name = row[0] # Assuming the first column contains the tag names
|
238 |
+
if tag_name.startswith('by_'):
|
239 |
+
# Strip 'by_' from the start of the tag name and add to the set
|
240 |
+
artist_name = tag_name[3:] # Remove the first three characters 'by_'
|
241 |
+
artist_set.add(artist_name)
|
242 |
+
def is_artist(name):
|
243 |
+
return name in artist_set
|
244 |
+
|
245 |
+
|
246 |
sample_images_directory_path = 'sampleimages'
|
247 |
def generate_artist_image_tuples(top_artists, image_directory):
|
248 |
json_files = glob.glob(f'{image_directory}/*.json')
|
|
|
443 |
|
444 |
# Return the vector as a 2D array for compatibility with SVD transform
|
445 |
return pseudo_vector.reshape(1, -1)
|
446 |
+
|
447 |
|
448 |
def get_top_indices(reduced_pseudo_vector, reduced_matrix):
|
449 |
# Compute cosine similarities
|
|
|
455 |
# Return the top N indices
|
456 |
return sorted_indices
|
457 |
|
458 |
+
|
459 |
def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags):
|
460 |
+
idf = tf_idf_components['idf']
|
461 |
+
term_to_column_index = tf_idf_components['tag_to_column_index']
|
462 |
+
row_to_tag = tf_idf_components['row_to_tag']
|
463 |
+
reduced_matrix = tf_idf_components['reduced_matrix']
|
464 |
+
svd = tf_idf_components['svd_model']
|
465 |
+
|
466 |
+
# Construct the TF-IDF vector
|
467 |
+
pseudo_tfidf_vector = construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index)
|
468 |
+
|
469 |
+
# Reduce the dimensionality of the pseudo-document vector for the reduced matrix
|
470 |
+
reduced_pseudo_vector = svd.transform(pseudo_tfidf_vector)
|
|
|
|
|
|
|
|
|
|
|
471 |
|
472 |
+
# Compute cosine similarities in the reduced space
|
473 |
+
cosine_similarities_reduced = cosine_similarity(reduced_pseudo_vector, reduced_matrix).flatten()
|
474 |
+
|
475 |
+
# Sort the indices by descending cosine similarity
|
476 |
+
top_indices_reduced = np.argsort(cosine_similarities_reduced)
|
477 |
+
|
478 |
+
# Map indices to tags with their similarities
|
479 |
+
tag_similarity_dict = {row_to_tag[i]: cosine_similarities_reduced[i] for i in top_indices_reduced if i in row_to_tag}
|
480 |
|
|
|
|
|
481 |
if not allow_nsfw_tags:
|
482 |
+
tag_similarity_dict = {tag: sim for tag, sim in tag_similarity_dict.items() if tag not in nsfw_tags}
|
483 |
|
484 |
+
tag_similarity_dict = {"by " + tag if is_artist(tag) else tag: sim for tag, sim in tag_similarity_dict.items()}
|
485 |
+
|
486 |
+
# Sort and transform tag names
|
487 |
sorted_tag_similarity_dict = OrderedDict(sorted(tag_similarity_dict.items(), key=lambda x: x[1], reverse=True))
|
488 |
+
transformed_sorted_tag_similarity_dict = OrderedDict(
|
489 |
+
(key.replace('_', ' ').replace('(', '\\(').replace(')', '\\)'), value)
|
490 |
+
for key, value in sorted_tag_similarity_dict.items()
|
491 |
+
)
|
492 |
|
493 |
+
return transformed_sorted_tag_similarity_dict
|
494 |
|
495 |
|
496 |
def create_html_placeholder(title="", content="", placeholder_height=400, placeholder_width="100%"):
|
|
|
602 |
# Modify the tag
|
603 |
modified_tag = tag_text.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip()
|
604 |
artist_matrix_tag = tag_text.replace('_', ' ').replace('\\(', '\(').replace('\\)', '\)').strip()
|
605 |
+
tf_idf_matrix_tag = re.sub(r'\\([()])', r'\1', re.sub(r' ', '_', tag_text.strip().removeprefix('by ').removeprefix('by_')))
|
606 |
# Calculate the end position based on the original tag length
|
607 |
end_pos = start_pos + len(tag_text)
|
608 |
# Append the structured data for each tag
|
|
|
612 |
"end_pos": end_pos,
|
613 |
"modified_tag": modified_tag,
|
614 |
"artist_matrix_tag": artist_matrix_tag,
|
615 |
+
"tf_idf_matrix_tag": tf_idf_matrix_tag,
|
616 |
"node_type": nodetype
|
617 |
})
|
618 |
return tag_data
|
|
|
668 |
suggested_tags_html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
|
669 |
|
670 |
suggested_tags_html_content += "<h1>Suggested Tags</h1>" # Heading for the table
|
671 |
+
suggested_tags = get_tfidf_reduced_similar_tags([item["tf_idf_matrix_tag"] for item in tag_data], allow_nsfw_tags)
|
672 |
+
|
673 |
+
# Create a set of tags that should be filtered out
|
674 |
+
filter_tags = {entry["original_tag"].strip() for entry in tag_data}
|
675 |
+
# Use this set to filter suggested_tags
|
676 |
+
suggested_tags_filtered = OrderedDict((k, v) for k, v in suggested_tags.items() if k not in filter_tags)
|
677 |
+
|
678 |
topnsuggestions = list(islice(suggested_tags_filtered.items(), 100))
|
679 |
suggested_tags_html_content += create_html_tables_for_tags("Suggested Tag", topnsuggestions, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
|
680 |
|
|
|
712 |
#gr.Image(label=" ", value=image_path, height=155, width=140)
|
713 |
#gr.HTML('<div style="text-align: center;"><img src={image_path} alt="Cute Mascot" style="max-height: 100px; background: transparent;"></div><br>')
|
714 |
#gr.HTML("<br>" * 2) # Adjust the number of line breaks ("<br>") as needed to push the button down
|
715 |
+
#image_path = os.path.join('mascotimages', "transparentsquirrel.png")
|
716 |
+
random_image_path = os.path.join('mascotimages', random.choice([f for f in os.listdir('mascotimages') if os.path.isfile(os.path.join('mascotimages', f))]))
|
717 |
+
with Image.open(random_image_path) as img:
|
718 |
gr.Image(value=img,show_label=False, show_download_button=False, show_share_button=False, height=200)
|
719 |
submit_button = gr.Button(variant="primary")
|
720 |
with gr.Row():
|
tf_idf_files_418.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1072321ea307c7b1e9518bb02426bede8d181ce17565721094dee674a3712e8c
|
3 |
+
size 115989585
|