Spaces:
Sleeping
Sleeping
FoodDesert
commited on
Upload app.py
Browse files
app.py
CHANGED
@@ -130,16 +130,17 @@ parser = Lark(grammar, start='start')
|
|
130 |
|
131 |
# Function to extract tags
|
132 |
def extract_tags(tree):
|
133 |
-
|
134 |
def _traverse(node):
|
135 |
if isinstance(node, Token) and node.type == '__ANON_1':
|
136 |
-
|
|
|
|
|
137 |
elif not isinstance(node, Token):
|
138 |
for child in node.children:
|
139 |
_traverse(child)
|
140 |
-
|
141 |
_traverse(tree)
|
142 |
-
return
|
143 |
|
144 |
|
145 |
special_tags = ["score:0", "score:1", "score:2", "score:3", "score:4", "score:5", "score:6", "score:7", "score:8", "score:9"]
|
@@ -341,7 +342,7 @@ def geometric_mean_given_words(target_word, context_words, co_occurrence_matrix,
|
|
341 |
|
342 |
def create_html_tables_for_tags(tag, result, tag2count, tag2idwiki):
|
343 |
# Wrap the tag part in a <span> with styles for bold and larger font
|
344 |
-
html_str = f"<div style='display: inline-block; margin:
|
345 |
# Loop through the results and add table rows for each
|
346 |
for word, sim in result:
|
347 |
word_with_underscores = word.replace(' ', '_')
|
@@ -404,24 +405,35 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
|
404 |
if not hasattr(find_similar_tags, "tag2idwiki"):
|
405 |
find_similar_tags.tag2idwiki = build_tag_id_wiki_dict()
|
406 |
|
407 |
-
|
|
|
408 |
|
409 |
# Find similar tags and prepare data for tables
|
410 |
html_content = "<div style='display: inline-block; margin: 20px; text-align: center;'>"
|
411 |
html_content += "<h1>Unknown Tags</h1>" # Heading for the table
|
412 |
tags_added = False
|
413 |
-
|
414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
continue
|
416 |
|
417 |
-
modified_tag_for_search =
|
418 |
similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
|
419 |
result, seen = [], set(transformed_tags)
|
420 |
|
421 |
if modified_tag_for_search in find_similar_tags.tag2aliases:
|
422 |
-
if
|
423 |
result.append(modified_tag_for_search.replace('_',' '), 1)
|
424 |
-
seen.add(
|
425 |
else: #The user correctly did not put underscores in their tag
|
426 |
continue
|
427 |
else:
|
@@ -444,36 +456,60 @@ def find_similar_tags(test_tags, similarity_weight, allow_nsfw_tags):
|
|
444 |
#Adjust score based on context
|
445 |
for i in range(len(result)):
|
446 |
word, score = result[i] # Unpack the tuple
|
447 |
-
geometric_mean = geometric_mean_given_words(word.replace(' ','_'), [context_tag for context_tag in transformed_tags if context_tag != word and context_tag !=
|
448 |
adjusted_score = (similarity_weight * geometric_mean) + ((1-similarity_weight)*score) # Apply the adjustment function
|
449 |
result[i] = (word, adjusted_score) # Update the tuple with the adjusted score
|
450 |
#print(word, score, geometric_mean, adjusted_score)
|
451 |
|
452 |
result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
|
453 |
-
html_content += create_html_tables_for_tags(
|
|
|
|
|
|
|
454 |
tags_added=True
|
455 |
# If no tags were processed, add a message
|
456 |
if not tags_added:
|
457 |
html_content = create_html_placeholder(title="Unknown Tags")
|
458 |
|
459 |
-
return html_content # Return list of lists for Dataframe
|
460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
461 |
|
462 |
-
def find_similar_artists(
|
463 |
try:
|
464 |
-
new_tags_string =
|
465 |
new_tags_string, removed_tags = remove_special_tags(new_tags_string)
|
466 |
|
467 |
# Parse the prompt
|
468 |
parsed = parser.parse(new_tags_string)
|
469 |
# Extract tags from the parsed tree
|
470 |
new_image_tags = extract_tags(parsed)
|
471 |
-
|
472 |
|
473 |
###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
|
474 |
-
unseen_tags_data = find_similar_tags(
|
|
|
|
|
475 |
|
476 |
-
|
|
|
477 |
similarities = cosine_similarity(X_new_image, X_artist)[0]
|
478 |
|
479 |
top_artist_indices = np.argsort(similarities)[-(top_n + 1):][::-1]
|
@@ -490,7 +526,7 @@ def find_similar_artists(new_tags_string, top_n, similarity_weight, allow_nsfw_t
|
|
490 |
image_galleries.append(baseline) # Add baseline as its own gallery item
|
491 |
image_galleries.append(artists) # Extend the list with artist tuples
|
492 |
|
493 |
-
return (unseen_tags_data, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries) #image_galleries[0], image_galleries[1] DOES work. Find a generic alternative.
|
494 |
except ParseError as e:
|
495 |
return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
|
496 |
|
@@ -504,6 +540,8 @@ with gr.Blocks() as app:
|
|
504 |
similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
|
505 |
num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
|
506 |
allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
|
|
|
|
|
507 |
with gr.Row():
|
508 |
with gr.Column(scale=1):
|
509 |
top_artists = gr.HTML(label="Top Artists", value=create_html_placeholder(title="Top Artists"))
|
@@ -521,7 +559,7 @@ with gr.Blocks() as app:
|
|
521 |
submit_button.click(
|
522 |
find_similar_artists,
|
523 |
inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
|
524 |
-
outputs=[unseen_tags, top_artists, dynamic_prompts] + galleries
|
525 |
)
|
526 |
|
527 |
gr.Markdown(faq_content)
|
|
|
130 |
|
131 |
# Function to extract tags
|
132 |
def extract_tags(tree):
|
133 |
+
tags_with_positions = []
|
134 |
def _traverse(node):
|
135 |
if isinstance(node, Token) and node.type == '__ANON_1':
|
136 |
+
tag_position = node.start_pos
|
137 |
+
tag_text = node.value.strip()
|
138 |
+
tags_with_positions.append((tag_text, tag_position))
|
139 |
elif not isinstance(node, Token):
|
140 |
for child in node.children:
|
141 |
_traverse(child)
|
|
|
142 |
_traverse(tree)
|
143 |
+
return tags_with_positions
|
144 |
|
145 |
|
146 |
special_tags = ["score:0", "score:1", "score:2", "score:3", "score:4", "score:5", "score:6", "score:7", "score:8", "score:9"]
|
|
|
342 |
|
343 |
def create_html_tables_for_tags(tag, result, tag2count, tag2idwiki):
|
344 |
# Wrap the tag part in a <span> with styles for bold and larger font
|
345 |
+
html_str = f"<div style='display: inline-block; margin: 20px; vertical-align: top;'><table><thead><tr><th colspan='3' style='text-align: center; padding-bottom: 10px;'><span style='font-weight: bold; font-size: 20px;'>{tag}</span></th></tr></thead><tbody><tr style='border-bottom: 1px solid #000;'><th>Corrected Tag</th><th>Similarity</th><th>Count</th></tr>"
|
346 |
# Loop through the results and add table rows for each
|
347 |
for word, sim in result:
|
348 |
word_with_underscores = word.replace(' ', '_')
|
|
|
405 |
if not hasattr(find_similar_tags, "tag2idwiki"):
|
406 |
find_similar_tags.tag2idwiki = build_tag_id_wiki_dict()
|
407 |
|
408 |
+
modified_tags = [tag_info['modified_tag'] for tag_info in test_tags]
|
409 |
+
transformed_tags = [tag.replace(' ', '_') for tag in modified_tags]
|
410 |
|
411 |
# Find similar tags and prepare data for tables
|
412 |
html_content = "<div style='display: inline-block; margin: 20px; text-align: center;'>"
|
413 |
html_content += "<h1>Unknown Tags</h1>" # Heading for the table
|
414 |
tags_added = False
|
415 |
+
bad_entities = []
|
416 |
+
for tag_info in test_tags:
|
417 |
+
original_tag = tag_info['original_tag']
|
418 |
+
modified_tag = tag_info['modified_tag']
|
419 |
+
start_pos = tag_info['start_pos']
|
420 |
+
end_pos = tag_info['end_pos']
|
421 |
+
|
422 |
+
|
423 |
+
print(original_tag, modified_tag, start_pos, end_pos)
|
424 |
+
|
425 |
+
|
426 |
+
if modified_tag in special_tags:
|
427 |
continue
|
428 |
|
429 |
+
modified_tag_for_search = modified_tag.replace(' ','_')
|
430 |
similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
|
431 |
result, seen = [], set(transformed_tags)
|
432 |
|
433 |
if modified_tag_for_search in find_similar_tags.tag2aliases:
|
434 |
+
if modified_tag in find_similar_tags.tag2aliases and "_" in modified_tag: #Implicitly tell the user that they should get rid of the underscore
|
435 |
result.append(modified_tag_for_search.replace('_',' '), 1)
|
436 |
+
seen.add(modified_tag)
|
437 |
else: #The user correctly did not put underscores in their tag
|
438 |
continue
|
439 |
else:
|
|
|
456 |
#Adjust score based on context
|
457 |
for i in range(len(result)):
|
458 |
word, score = result[i] # Unpack the tuple
|
459 |
+
geometric_mean = geometric_mean_given_words(word.replace(' ','_'), [context_tag for context_tag in transformed_tags if context_tag != word and context_tag != modified_tag], conditional_co_occurrence_matrix, conditional_vocabulary, conditional_doc_count, smoothing_value=conditional_smoothing)
|
460 |
adjusted_score = (similarity_weight * geometric_mean) + ((1-similarity_weight)*score) # Apply the adjustment function
|
461 |
result[i] = (word, adjusted_score) # Update the tuple with the adjusted score
|
462 |
#print(word, score, geometric_mean, adjusted_score)
|
463 |
|
464 |
result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
|
465 |
+
html_content += create_html_tables_for_tags(modified_tag, result, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
|
466 |
+
|
467 |
+
bad_entities.append({"entity":"UNKNOWN", "start":start_pos, "end":end_pos})
|
468 |
+
|
469 |
tags_added=True
|
470 |
# If no tags were processed, add a message
|
471 |
if not tags_added:
|
472 |
html_content = create_html_placeholder(title="Unknown Tags")
|
473 |
|
474 |
+
return html_content, bad_entities # Return list of lists for Dataframe
|
475 |
+
|
476 |
+
|
477 |
+
def build_tag_offsets_dicts(new_image_tags_with_positions):
|
478 |
+
# Structure the data for HighlightedText
|
479 |
+
tag_data = []
|
480 |
+
for tag_text, start_pos in new_image_tags_with_positions:
|
481 |
+
# Modify the tag
|
482 |
+
modified_tag = tag_text.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip()
|
483 |
+
# Calculate the end position based on the original tag length
|
484 |
+
end_pos = start_pos + len(tag_text)
|
485 |
+
# Append the structured data for each tag
|
486 |
+
tag_data.append({
|
487 |
+
"original_tag": tag_text,
|
488 |
+
"start_pos": start_pos,
|
489 |
+
"end_pos": end_pos,
|
490 |
+
"modified_tag": modified_tag
|
491 |
+
})
|
492 |
+
return tag_data
|
493 |
+
|
494 |
|
495 |
+
def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_nsfw_tags):
|
496 |
try:
|
497 |
+
new_tags_string = original_tags_string.lower()
|
498 |
new_tags_string, removed_tags = remove_special_tags(new_tags_string)
|
499 |
|
500 |
# Parse the prompt
|
501 |
parsed = parser.parse(new_tags_string)
|
502 |
# Extract tags from the parsed tree
|
503 |
new_image_tags = extract_tags(parsed)
|
504 |
+
tag_data = build_tag_offsets_dicts(new_image_tags)
|
505 |
|
506 |
###unseen_tags = list(set(OrderedDict.fromkeys(new_image_tags)) - set(vectorizer.vocabulary_.keys())) #We may want this line again later. These are the tags that were not used to calculate the artists list.
|
507 |
+
unseen_tags_data, bad_entities = find_similar_tags(tag_data, similarity_weight, allow_nsfw_tags)
|
508 |
+
|
509 |
+
bad_tags_illustrated_string = {"text":new_tags_string, "entities":bad_entities}
|
510 |
|
511 |
+
modified_tags = [tag_info['modified_tag'] for tag_info in tag_data]
|
512 |
+
X_new_image = vectorizer.transform([','.join(modified_tags + removed_tags)])
|
513 |
similarities = cosine_similarity(X_new_image, X_artist)[0]
|
514 |
|
515 |
top_artist_indices = np.argsort(similarities)[-(top_n + 1):][::-1]
|
|
|
526 |
image_galleries.append(baseline) # Add baseline as its own gallery item
|
527 |
image_galleries.append(artists) # Extend the list with artist tuples
|
528 |
|
529 |
+
return (unseen_tags_data, bad_tags_illustrated_string, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries) #image_galleries[0], image_galleries[1] DOES work. Find a generic alternative.
|
530 |
except ParseError as e:
|
531 |
return [], "Parse Error: Check for mismatched parentheses or something", "", None, None
|
532 |
|
|
|
540 |
similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Similarity weight")
|
541 |
num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
|
542 |
allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
|
543 |
+
with gr.Row():
|
544 |
+
bad_tags_illustrated_string = gr.HighlightedText()
|
545 |
with gr.Row():
|
546 |
with gr.Column(scale=1):
|
547 |
top_artists = gr.HTML(label="Top Artists", value=create_html_placeholder(title="Top Artists"))
|
|
|
559 |
submit_button.click(
|
560 |
find_similar_artists,
|
561 |
inputs=[image_tags, num_artists, similarity_weight, allow_nsfw],
|
562 |
+
outputs=[unseen_tags, bad_tags_illustrated_string, top_artists, dynamic_prompts] + galleries
|
563 |
)
|
564 |
|
565 |
gr.Markdown(faq_content)
|