|
|
|
"""MtGPT2 |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1HMq9Cp_jhqc9HlUipLXi8SC1mHQhSgGn |
|
""" |
|
|
|
|
|
|
|
import os |
|
|
|
# aitextgen is run here against pytorch-lightning 1.7.0, so pin that version
# before anything below imports it.
import importlib
import importlib.util
import subprocess
import sys

# Run pip with the interpreter executing this script rather than whatever
# "python3" resolves to on PATH, so the packages are installed where this
# process imports from. As with a plain os.system call, a failed install is not
# fatal here; the imports further down will surface the problem.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "pytorch-lightning==1.7.0", "aitextgen"],
    check=False,
)

# Replacement import for pytorch-lightning's rich_progress.py: take
# _compare_version from lightning_utilities instead.
# NOTE(review): presumably the import shipped with pytorch-lightning 1.7.0 no
# longer resolves against the dependency versions pip selects — confirm before
# removing this patch.
_COMPARE_VERSION_IMPORT = (
    "from lightning_utilities.core.imports import compare_version as _compare_version\n"
)

# Used when the installed package cannot be located on sys.path; this is the
# location the script historically hard-coded.
_DEFAULT_RICH_PROGRESS_PATH = (
    "/home/user/.local/lib/python3.10/site-packages/"
    "pytorch_lightning/callbacks/progress/rich_progress.py"
)


def _find_rich_progress():
    """Return the path of the rich_progress.py that importing pytorch_lightning loads.

    Only the top-level package is looked up with ``find_spec``, which locates
    it without executing it (importing it could run the broken import that
    this patch is meant to fix).
    """
    importlib.invalidate_caches()  # make sure packages pip just installed are seen
    spec = importlib.util.find_spec("pytorch_lightning")
    if spec is None or not spec.submodule_search_locations:
        return _DEFAULT_RICH_PROGRESS_PATH
    package_dir = list(spec.submodule_search_locations)[0]
    return os.path.join(package_dir, "callbacks", "progress", "rich_progress.py")


def _patch_compare_version_import(path):
    """Point the ``_compare_version`` import in *path* at lightning_utilities.

    The import line is found by its content rather than by a fixed line
    number, keeps its original indentation, and the file is only rewritten
    when something actually changes, so re-running the script is harmless.

    Returns:
        True if the file was rewritten, False otherwise.
    """
    with open(path, "r", encoding="utf-8") as file:
        lines = file.readlines()

    for index, line in enumerate(lines):
        body = line.lstrip()
        if body.startswith(("from ", "import ")) and "_compare_version" in body:
            patched = line[: len(line) - len(body)] + _COMPARE_VERSION_IMPORT
            if patched == line:
                return False  # already patched by an earlier run
            lines[index] = patched
            break
    else:
        print(
            f"warning: no _compare_version import found in {path}; not patched",
            file=sys.stderr,
        )
        return False

    with open(path, "w", encoding="utf-8") as file:
        file.writelines(lines)
    return True


_patch_compare_version_import(_find_rich_progress())
|
|
|
import sys |
|
from jinja2 import Template |
|
from aitextgen import aitextgen |
|
|
|
try: |
|
from google.colab import files |
|
except ImportError: |
|
pass |
|
|
|
|
|
# Fine-tuned GPT-2 MTG card model (presumably fetched from the Hugging Face
# Hub by aitextgen), not moved to a GPU.
ai = aitextgen(to_gpu=False, model="minimaxir/magic-the-gathering")

# Jinja2 layout for one card: the name (followed by the mana cost when set),
# the type line, the rules text, then power/toughness and/or loyalty when
# those fields are set. Written as adjacent literals so each rendered line sits
# on its own source line.
_CARD_LAYOUT = (
    "{{ c.name }}{% if c.manaCost %} {{ c.manaCost }}{% endif %}\n"
    "{{ c.type }}\n"
    "{{ c.text }}{% if c.power %}\n"
    "{{ c.power }}/{{ c.toughness }}{% endif %}{% if c.loyalty %}\n"
    "Loyalty: {{ c.loyalty }}{% endif %}"
)
TEMPLATE = Template(_CARD_LAYOUT)
|
|
|
def render_card(card_dict):
    """Render one generated card as plain text using TEMPLATE.

    Args:
        card_dict: Mapping of card fields for one generated card. TEMPLATE
            reads the keys name, manaCost, type, text, power, toughness and
            loyalty; any of them may be missing (presumably when the model did
            not emit that field — TODO confirm against aitextgen's schema output).

    Returns:
        The rendered card. When the card has a non-empty name, every ``~``
        (the card's self-reference in rules text) is replaced by that name.
    """
    card = TEMPLATE.render(c=card_dict)
    # .get(): the template already tolerates missing fields, so a card without
    # a name must not raise KeyError here either.
    name = card_dict.get("name")
    if name:
        card = card.replace("~", name)
    return card
|
|
|
# Emoji shown in place of MTG mana and tap symbols in the generated text.
# Written as \N{...} escapes so the source stays ASCII: the previous literals
# had been mangled by an encoding round-trip into strings such as "π²".
# NOTE(review): the {U} and {B} glyphs were reconstructed from garbled bytes
# (WATER WAVE and SKULL are the most likely originals) — confirm.
_SYMBOL_EMOJI = {
    "{G}": "\N{EVERGREEN TREE}",
    "{U}": "\N{WATER WAVE}",
    "{R}": "\N{FIRE}",
    "{B}": "\N{SKULL}",
    "{W}": "\N{BLACK SUN WITH RAYS}\N{VARIATION SELECTOR-16}",
    "{T}": "\N{LEFTWARDS ARROW WITH HOOK}\N{VARIATION SELECTOR-16}",
}


def _format_mana_cost(mana_cost):
    """Convert mana-cost shorthand such as ``"2UG"`` into ``"{2}{U}{G}"``.

    Whitespace is dropped and letters are upper-cased. Consecutive digits form
    one generic cost (``"10G"`` -> ``"{10}{G}"``). Input that already uses
    brace notation (contains ``{``) is returned after that normalisation
    without further wrapping. Empty input yields ``""``.
    """
    if not mana_cost:
        return ""
    compact = "".join(mana_cost.split()).upper()
    if "{" in compact:
        return compact
    symbols = []
    for char in compact:
        if char.isdigit() and symbols and symbols[-1].isdigit():
            symbols[-1] += char  # extend a multi-digit generic cost
        else:
            symbols.append(char)
    return "".join("{" + symbol + "}" for symbol in symbols)


def generate_cards(
    n_cards: int = 8,
    temperature: float = 0.75,
    name: str = "",
    manaCost: str = "",
    type: str = "",
    text: str = "",
    power: str = "",
    toughness: str = "",
    loyalty: str = "",
) -> str:
    """Generate MTG cards with the fine-tuned GPT-2 model and render them as text.

    Each non-empty field is appended to the prompt after its schema token
    (e.g. ``<|name|>Goblin King``). The prompt is then passed to
    ``ai.generate`` with ``schema=True``, and every returned card is rendered
    with ``render_card``.

    Args:
        n_cards: Number of cards to generate; converted with ``int`` and raised
            to at least 1.
        temperature: Sampling temperature forwarded to ``ai.generate``.
        name, manaCost, type, text, power, toughness, loyalty: Optional values
            used to seed the corresponding card fields. ``manaCost`` accepts
            shorthand such as ``"2UG"`` (see ``_format_mana_cost``). The
            parameter names are part of the interface (``type`` shadows the
            builtin but is kept for keyword callers).

    Returns:
        The rendered cards joined by ``"\\n=====\\n"``, with the symbols in
        ``_SYMBOL_EMOJI`` replaced by emoji.
    """
    n_cards = max(int(n_cards), 1)

    # Schema tokens in prompt order, each paired with the value that seeds it.
    fields = {
        "<|name|>": name,
        "<|manaCost|>": _format_mana_cost(manaCost),
        "<|type|>": type,
        "<|text|>": text,
        "<|power|>": power,
        "<|toughness|>": toughness,
        "<|loyalty|>": loyalty,
    }
    prompt_str = "".join(f"{token}{value}" for token, value in fields.items() if value)

    cards = ai.generate(
        n=n_cards,
        schema=True,
        prompt=prompt_str,
        temperature=temperature,
        return_as_list=True,
    )

    out_str = "\n=====\n".join(render_card(card) for card in cards)
    for symbol, emoji in _SYMBOL_EMOJI.items():
        out_str = out_str.replace(symbol, emoji)
    return out_str
|
|
|
|
|
|
|
import gradio as gr |
|
|
|
# Web UI. Gradio passes the input components' values to generate_cards
# positionally, so their order must match its parameters: n_cards,
# temperature, name, manaCost, type, text, power, toughness, loyalty.
# Every component is labelled so users can tell the card fields apart.
iface = gr.Interface(
    fn=generate_cards,
    inputs=[
        gr.Slider(minimum=2, maximum=16, step=1, value=8, label="Number of cards"),
        gr.Slider(minimum=0.1, maximum=1.5, step=0.01, value=0.75, label="Temperature"),
        gr.Textbox(label="Name"),
        gr.Textbox(label="Mana cost", placeholder="e.g. 2UG"),
        gr.Textbox(label="Type"),
        gr.Textbox(label="Text"),
        gr.Textbox(label="Power"),
        gr.Textbox(label="Toughness"),
        gr.Textbox(label="Loyalty"),
    ],
    outputs=gr.Textbox(label="Generated cards"),
    title="GPT-2 Powered MTG Card Generator",
    description="Enter Manacost as '2UG' for 2 colorless + Blue + Green mana. \n\n Temperature is recommended between 0.4 and 0.9. Anything above 1 will lead to random Chaos and very low values will just be boring.",
    show_api=True,
)

iface.launch()