File size: 13,656 Bytes
379b837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6013c6d
379b837
 
 
 
 
 
6013c6d
379b837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import pandas
import numpy
import pandas.io.formats.style
import random
import functools
from typing import Callable, Literal

# Directory the ``.feather`` data files are read from.
DATA_FOLDER = "."

# Numeric tag-category ids, numbered 0 through 8 in exactly this order.
(
    CAT_GENERAL,
    CAT_ARTIST,
    CAT_UNUSED,
    CAT_COPYRIGHT,
    CAT_CHARACTER,
    CAT_SPECIES,
    CAT_INVALID,
    CAT_META,
    CAT_LORE,
) = range(9)

# CSS colour used when rendering a tag name, keyed by category id.
CATEGORY_COLORS = {
    CAT_GENERAL: "#808080",
    CAT_ARTIST: "#f2ac08",
    CAT_UNUSED: "#ff3d3d",
    CAT_COPYRIGHT: "#d0d",
    CAT_CHARACTER: "#0a0",
    CAT_SPECIES: "#ed5d1f",
    CAT_INVALID: "#ff3d3d",
    CAT_META: "#04f",
    CAT_LORE: "#282",
}

def get_feather(filename: str) -> pandas.DataFrame:
    """Read ``<DATA_FOLDER>/<filename>.feather`` and return it as a DataFrame.

    ``filename`` is given without the ``.feather`` extension.
    """
    feather_path = "{}/{}.feather".format(DATA_FOLDER, filename)
    return pandas.read_feather(feather_path)

# Tables are read once, at import time, from the feather files in DATA_FOLDER.
tags = get_feather("tags")
posts_by_tag = get_feather("posts_by_tag")
posts_by_tag = posts_by_tag.set_index("tag_id")
tags_by_post = get_feather("tags_by_post")
tags_by_post = tags_by_post.set_index("post_id")
tag_ratings = get_feather("tag_ratings")
implications = get_feather("implications")

# Two views of the tag table: an independent deep copy indexed by tag name
# (taken first, while "tag_id" is still a regular column) and the table itself
# indexed by tag id.
tags_by_name = tags.copy(deep=True).set_index("name")
tags = tags.set_index("tag_id")

@functools.cache
def get_related_tags(targets: tuple[str, ...], exclude: tuple[str, ...] = (), samples: int = 100_000) -> pandas.DataFrame:
    """Score every tag by how strongly it co-occurs with the posts matching ``targets``.

    The matching posts are those that carry *all* of ``targets`` and *none* of
    ``exclude``. At most ``samples`` of them are inspected; overlap counts taken
    from the sample are scaled back up by the sample ratio and capped at each
    tag's total ``post_count``.

    Results are memoised with ``functools.cache``: the arguments must be
    hashable (tuples, not lists), and because sampling is random only on the
    first call, repeated calls return the very same frame. Callers should derive
    new frames from the result rather than modify it in place.

    Args:
        targets: Tag names that must all be present on a post.
        exclude: Tag names; posts carrying any of them are removed.
        samples: Maximum number of matching posts to inspect.

    Returns:
        A frame with one row per tag whose ``post_count`` is positive, with the
        columns ``category``, ``name``, ``overlap``, ``post_count`` and
        ``correlation`` (the phi coefficient between "post matches" and
        "post has this tag"). When no post matches, ``overlap`` is 0 for every
        tag and ``correlation`` is expected to come out as NaN.
    """
    these_tags = tags_by_name.loc[list(targets)]
    # Collapse the per-tag post-id collections into one set: posts that have every target tag.
    posts_with_these_tags = posts_by_tag.loc[these_tags["tag_id"]].map(set).groupby(lambda x: True).agg(lambda x: set.intersection(*x))["post_id"][True]
    if exclude:
        excluded_tags = tags_by_name.loc[list(exclude)]
        # Same trick with a union: posts that have at least one excluded tag.
        posts_with_excluded_tags = posts_by_tag.loc[excluded_tags["tag_id"]].map(set).groupby(lambda x: True).agg(lambda x: set.union(*x))["post_id"][True]
        posts_with_these_tags = posts_with_these_tags - posts_with_excluded_tags
    total_post_count_together = len(posts_with_these_tags)
    sample_posts = random.sample(list(posts_with_these_tags), samples) if total_post_count_together > samples else list(posts_with_these_tags)
    post_count_together = len(sample_posts)
    # With no matching post there is nothing to scale; a ratio of 1 avoids the
    # 0 / 0 division that used to raise ZeroDivisionError here.
    sample_ratio = post_count_together / total_post_count_together if total_post_count_together else 1.0
    tags_in_these_posts = tags_by_post.loc[sample_posts]
    counts_in_these_posts = tags_in_these_posts["tag_id"].explode().value_counts().rename("overlap")
    # Right join keeps every tag with posts; tags never seen in the sample get overlap 0.
    summaries = pandas.DataFrame(counts_in_these_posts).join(tags[tags["post_count"]>0], how="right").fillna(0)
    # Extrapolate from the sample to the full matching set, never exceeding the tag's own post count.
    summaries["overlap"] = numpy.minimum(summaries["overlap"] / sample_ratio, summaries["post_count"])
    summaries = summaries[["category", "name", "overlap", "post_count"]]
    # Old "interestingness" value, didn't give as good results as an actual statistical technique, go figure. Code kept for curiosity's sake.
    #summaries["interestingness"] = summaries["overlap"].pow(2) / (total_post_count_together * summaries["post_count"])
    # Phi coefficient: n is the number of rows in tags_by_post, n1x the posts
    # matching the query, nx1 the posts with the candidate tag, n11 the
    # (estimated) posts with both.
    n = float(len(tags_by_post))
    n11 = summaries["overlap"]
    n1x = float(total_post_count_together)
    nx1 = summaries["post_count"].astype("float64")
    summaries["correlation"] = (n * n11 - n1x * nx1) / numpy.sqrt(n1x * nx1 * (n - n1x) * (n - nx1))
    return summaries

def format_tags(styler: pandas.io.formats.style.Styler):
    """Apply the shared tag-table presentation to ``styler`` and return it.

    The ``name`` cell of each row is coloured by the row's category, index
    level 0 and the ``category`` column are hidden, ``overlap`` is shown
    without decimals, and ``correlation`` / ``score`` (whichever are present)
    are shown with two decimals on a red-yellow-green gradient spanning -1..1.
    """
    def colour_name_cell(row):
        # Only the "name" cell gets a colour; every other cell gets an empty style.
        return numpy.where(row.index == "name", "color:" + CATEGORY_COLORS[row["category"]], "")

    styler.apply(colour_name_cell, axis=1)
    styler.hide(level=0)
    styler.hide("category", axis=1)

    table = styler.data
    if "overlap" in table:
        styler.format("{:.0f}".format, subset=["overlap"])
    for gradient_column in ("correlation", "score"):
        if gradient_column not in table:
            continue
        styler.format("{:.2f}".format, subset=[gradient_column])
        styler.background_gradient(vmin=-1.0, vmax=1.0, cmap="RdYlGn", subset=[gradient_column])
    return styler

def related_tags(*targets: str, exclude: tuple[str, ...] = (), category: int | None = None, samples: int = 100_000, min_overlap: int = 5, min_posts: int = 20, top: int = 30, bottom: int = 0) -> "pandas.io.formats.style.Styler":
    """Return a styled table of the tags most (and least) correlated with ``targets``.

    Args:
        targets: Tag names a post must all carry.
        exclude: Tag names a post must not carry. Any iterable of names is
            accepted; it is converted to a tuple because the underlying
            lookup is cached and needs hashable arguments.
        category: If given, keep only tags of this category id.
        samples: Maximum number of matching posts to inspect.
        min_overlap: Drop tags whose estimated overlap is below this value.
        min_posts: Drop tags with fewer total posts than this.
        top: Number of highest-correlation rows to show.
        bottom: Number of lowest-correlation rows to append below them.

    Returns:
        A pandas ``Styler`` (not a plain DataFrame) formatted by ``format_tags``.
    """
    result = get_related_tags(targets, exclude=tuple(exclude), samples=samples)
    if category is not None:
        result = result[result["category"] == category]
    # The queried tags trivially correlate with themselves; leave them out.
    result = result[~result["name"].isin(targets)]
    result = result[result["overlap"] >= min_overlap]
    result = result[result["post_count"] >= min_posts]
    top_part = result.sort_values("correlation", ascending=False)[:top]
    # Take the lowest `bottom` rows, then flip them so the combined table reads
    # from highest to lowest correlation.
    bottom_part = result.sort_values("correlation", ascending=True)[:bottom].sort_values("correlation", ascending=False)
    return pandas.concat([top_part, bottom_part]).style.pipe(format_tags)

def implications_for(*subjects: str, seen: set[str] = None):
    """Yield the names of all tags implied, directly or transitively, by ``subjects``.

    Each implied tag is yielded once, immediately followed by its own
    implications (depth-first). ``seen`` collects the names already yielded; it
    is shared with the recursive calls and may be supplied by the caller to
    suppress names up front.
    """
    if seen is None:
        seen = set()
    for subject in subjects:
        subject_id = tags_by_name.loc[subject, "tag_id"]
        direct = implications[implications["antecedent_id"] == subject_id]
        consequent_ids = list(direct.loc[:, "consequent_id"])
        for implied in tags.loc[consequent_ids, "name"].values:
            if implied in seen:
                continue
            yield implied
            seen.add(implied)
            yield from implications_for(implied, seen=seen)

def parse_tag(potential_tag: str):
    """Normalise one piece of prompt text to a known tag name, or return ``None``.

    Surrounding whitespace is stripped, spaces become underscores and
    backslash-escaped parentheses are unescaped. If the result is not a known
    tag but starts with ``by_`` and the remainder is, the remainder is
    returned. Blank input yields ``None`` silently; an unknown tag yields
    ``None`` after printing a notice.
    """
    candidate = potential_tag.strip().replace(" ", "_")
    candidate = candidate.replace("\\(", "(").replace("\\)", ")")
    if not candidate:
        return None
    known_names = tags_by_name.index
    if candidate in known_names:
        return candidate
    if candidate.startswith("by_"):
        without_prefix = candidate[3:]
        if without_prefix in known_names:
            return without_prefix
    print(f"Couldn't find tag '{candidate}', skipping it.")
    return None

def parse_tags(*parts: str):
    """Yield the known tag names found in one or more comma-separated strings.

    Every comma-separated piece is passed through ``parse_tag``; pieces it
    rejects (returns ``None`` for) are skipped.
    """
    pieces = (piece for part in parts for piece in part.split(","))
    for piece in pieces:
        parsed = parse_tag(piece)
        if parsed is None:
            continue
        yield parsed

def add_suggestions(suggestions: pandas.DataFrame, new_tags: str | list[str], multiplier: int, samples : int, min_posts: int, rating: Literal['s', 'q', 'e']):
    if isinstance(new_tags, str):
        new_tags = [new_tags]
    for new_tag in new_tags:
        related = get_related_tags((new_tag,), samples=samples)
        # Implementing the rating filter this way is horribly inefficient, fix it later
        if rating == 's':
            related = related.join(tag_ratings.set_index("tag_id"), on="tag_id")
            related["post_count"] = related["s"]
            related = related.drop("s", axis=1)
            related = related.drop("q", axis=1)
            related = related.drop("e", axis=1)
        elif rating == 'q':
            related = related.join(tag_ratings.set_index("tag_id"), on="tag_id")
            related["post_count"] = related["s"] + related["q"]
            related = related.drop("s", axis=1)
            related = related.drop("q", axis=1)
            related = related.drop("e", axis=1)
        related = related[related["post_count"] >= min_posts]
        if suggestions is None:
            suggestions = related.rename(columns={"correlation": "score"})
        else:
            suggestions = suggestions.join(related, rsuffix="r")
            # This is a totally made up way to combine correlations. It keeps them from going outside the +/- 1 range, which is nice. It also makes older
            # tags less important every time newer ones are added. That could be considered a feature or not.
            suggestions["score"] = numpy.real(numpy.power((numpy.sqrt(suggestions["score"] + 0j) + numpy.sqrt(multiplier * suggestions["correlation"] + 0j)) / 2, 2))
    return suggestions[["category", "name", "post_count", "score"]]



def pick_tags(suggestions: pandas.DataFrame, category: int, count: int, from_top: int, excluding: list[str], weighted: bool = True):
    """Randomly choose ``count`` distinct tag names from the best-scoring suggestions.

    Candidates are the rows with a positive score whose name is not in
    ``excluding`` (and, when ``category`` is not ``None``, of that category),
    limited to the ``from_top`` highest scores. With ``weighted`` the draw is
    without replacement and proportional to score; otherwise it is a uniform
    sample. Asking for more names than there are candidates raises.
    """
    keep = (suggestions["score"] > 0) & ~suggestions["name"].isin(excluding)
    if category is not None:
        keep = (suggestions["category"] == category) & keep
    candidates = suggestions[keep].sort_values("score", ascending=False).iloc[:from_top]

    pool = list(candidates["name"].values)
    if not weighted:
        return random.sample(pool, count)

    pool_weights = list(candidates["score"].values)
    picked = []
    while len(picked) < count:
        winner = random.choices(population=pool, weights=pool_weights, k=1)[0]
        # Remove the winner and its weight so it cannot be drawn again.
        position = pool.index(winner)
        del pool_weights[position]
        del pool[position]
        picked.append(winner)
    return picked

def tag_to_prompt(tag: str) -> str:
    """Render a tag name as prompt text.

    Artist tags are prefixed with ``by ``; underscores become spaces and
    parentheses are backslash-escaped.
    """
    is_artist = tags_by_name.loc[tag]["category"] == CAT_ARTIST
    text = "by " + tag if is_artist else tag
    for raw, rendered in (("_", " "), ("(", "\\("), (")", "\\)")):
        text = text.replace(raw, rendered)
    return text

# A lambda in a for loop doesn't capture variables the way I want it to, so this is a method now
def add_suggestions_later(suggestions: pandas.DataFrame, new_tags: str | list[str], multiplier: int, samples: int, min_posts: int, rating: Literal['s', 'q', 'e']):
    return lambda: add_suggestions(suggestions, new_tags, multiplier, samples, min_posts, rating)


# One prompt under construction: (positive tags, negative tags, zero-argument
# callable producing the suggestion table for this prompt - or None while no
# tag has contributed to it yet).
Prompt = tuple[list[str], list[str], Callable[[], pandas.DataFrame]]


class PromptBuilder:
    """Builder for one or more prompts, grown by adding or picking tags.

    Every operation returns a new ``PromptBuilder``; the receiver is left
    unchanged. Each prompt carries a deferred suggestion table, so the most
    recently added tag's correlations are only computed if a later step
    actually needs them.
    """

    prompts: list[Prompt]
    samples: int
    min_posts: int
    rating: Literal['s', 'q', 'e']
    skip_list: list[str]

    def __init__(self, prompts=None, skip=None, samples=100_000, min_posts=20, rating: Literal['s', 'q', 'e'] = 'e'):
        """Create a builder.

        Args:
            prompts: Initial prompts; defaults to a single empty prompt.
            skip: Tag names that must never be picked; defaults to none.
            samples: Forwarded to ``add_suggestions``.
            min_posts: Forwarded to ``add_suggestions``.
            rating: Forwarded to ``add_suggestions``.
        """
        # Fresh containers per instance: list defaults in the signature would be
        # shared by every builder created without explicit arguments.
        self.prompts = [([], [], lambda: None)] if prompts is None else prompts
        self.samples = samples
        self.min_posts = min_posts
        self.rating = rating
        self.skip_list = [] if skip is None else skip

    def _derive(self, prompts):
        """Return a builder with this builder's settings and the given prompts."""
        return PromptBuilder(prompts, skip=self.skip_list, samples=self.samples, min_posts=self.min_posts, rating=self.rating)

    def _later(self, suggestions, new_tags, multiplier):
        """Defer an ``add_suggestions`` call that uses this builder's settings."""
        return add_suggestions_later(suggestions, new_tags, multiplier, self.samples, self.min_posts, self.rating)

    def _with_tag(self, tag, multiplier, to_positive=False, to_negative=False):
        """Fold ``tag`` into every prompt's suggestions, optionally listing it in the prompt."""
        return self._derive([
            (
                tag_list + [tag] if to_positive else tag_list,
                negative_list + [tag] if to_negative else negative_list,
                # The previous table is evaluated now; only the new tag is deferred.
                self._later(suggestions(), tag, multiplier),
            )
            for (tag_list, negative_list, suggestions) in self.prompts
        ])

    def include(self, tag: str):
        """Add ``tag`` to the positive prompt and steer suggestions towards it."""
        return self._with_tag(tag, 1, to_positive=True)

    def focus(self, tag: str):
        """Steer suggestions towards ``tag`` without putting it in the prompt."""
        return self._with_tag(tag, 1)

    def exclude(self, tag: str):
        """Add ``tag`` to the negative prompt and steer suggestions away from it."""
        return self._with_tag(tag, -1, to_negative=True)

    def avoid(self, tag: str):
        """Steer suggestions away from ``tag`` without putting it in the prompt."""
        return self._with_tag(tag, -1)

    def pick(self, category: int, count: int, from_top: int):
        """Pick ``count`` tags per prompt, one at a time.

        Suggestions are refreshed after every single pick, so each pick is
        influenced by the ones before it.
        """
        new_prompts = self.prompts
        for _ in range(count):
            round_prompts = []
            for (tag_list, negative_list, suggestions) in new_prompts:
                s = suggestions()
                unavailable = tag_list + negative_list + self.skip_list
                for tag in pick_tags(s, category, 1, from_top, unavailable):
                    round_prompts.append((tag_list + [tag], negative_list, self._later(s, tag, 1)))
            new_prompts = round_prompts
        return self._derive(new_prompts)

    def foreach_pick(self, category: int, count: int, from_top: int):
        """Pick ``count`` tags per prompt and fan out: one new prompt per picked tag."""
        fanned_out = []
        for (tag_list, negative_list, suggestions) in self.prompts:
            s = suggestions()
            unavailable = tag_list + negative_list + self.skip_list
            for tag in pick_tags(s, category, count, from_top, unavailable):
                fanned_out.append((tag_list + [tag], negative_list, self._later(s, tag, 1)))
        return self._derive(fanned_out)

    def pick_fast(self, category: int, count: int, from_top: int):
        """Pick ``count`` tags per prompt in a single draw.

        All picks come from the same suggestion table and are added to the
        prompt together, with one deferred suggestion update for the batch.
        """
        prompts = []
        for (tag_list, negative_list, suggestions) in self.prompts:
            s = suggestions()
            new_tags = pick_tags(s, category, count, from_top, tag_list + negative_list + self.skip_list)
            prompts.append((tag_list + new_tags, negative_list, self._later(s, new_tags, 1)))
        return self._derive(prompts)

    def branch(self, count: int):
        """Repeat every prompt ``count`` times so later picks can diverge."""
        return self._derive([prompt for prompt in self.prompts for _ in range(count)])

    def build(self):
        """Yield each prompt as text; a non-empty negative list goes on a second line."""
        for (tag_list, negative_list, _) in self.prompts:
            positive_prompt = ", ".join(tag_to_prompt(tag) for tag in tag_list)
            negative_prompt = ", ".join(tag_to_prompt(tag) for tag in negative_list)
            if negative_prompt:
                yield f"{positive_prompt}\nNegative prompt: {negative_prompt}"
            else:
                yield positive_prompt

    def print(self):
        """Print every built prompt."""
        for prompt in self.build():
            print(prompt)

    def get_one(self):
        """Return the first built prompt, or ``None`` when there are no prompts."""
        return next(self.build(), None)