Spaces:
Runtime error
Runtime error
init
Browse files- __init__.py +0 -0
- app.py +25 -0
- infer.py +243 -0
- packages.txt +3 -0
- requirements.txt +7 -0
- webui.py +198 -0
__init__.py
ADDED
File without changes
|
app.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os import getenv
|
2 |
+
from textwrap import dedent
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
from torch import cuda
|
6 |
+
|
7 |
+
from webui import build_ui, remove_darkness, get_banner
|
8 |
+
|
9 |
+
# True only when running as the hosted public Space identified by SPACE_ID.
PUBLIC_DEMO = getenv("SPACE_ID") == "waleko/TikZ-Assistant"

if PUBLIC_DEMO and not cuda.is_available():
    # The public Space without a GPU only shows a notice asking users to
    # duplicate the Space instead of building the full UI.
    center = ".gradio-container {text-align: center}"
    with gr.Blocks(css=center, theme=remove_darkness(gr.themes.Soft()), title="AutomaTikZ") as demo:
        badge = "https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-xl.svg"
        # The badge reads "Duplicate this Space", so the link targets the
        # Space that PUBLIC_DEMO identifies rather than a different one.
        link = "https://huggingface.co/spaces/waleko/TikZ-Assistant?duplicate=true"
        html = f'<a style="display:inline-block" href="{link}"> <img src="{badge}" alt="Duplicate this Space"> </a>'
        message = dedent("""\
        The size of our models exceeds the resource constraints offered by the
        free tier of Hugging Face Spaces. For full functionality, we recommend
        duplicating this space on a paid private GPU runtime.
        """)
        gr.Markdown(f'{get_banner()}\n{message}\n{html}')
    demo.launch()
else:
    # Full UI, bound to all interfaces on port 7860.
    build_ui(lock=False, force_light=True).queue().launch(server_name="0.0.0.0", server_port=7860)
|
infer.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
from functools import cache, cached_property
|
3 |
+
from io import BytesIO
|
4 |
+
from os import environ
|
5 |
+
from os.path import isfile, join
|
6 |
+
from re import MULTILINE, escape, search, sub
|
7 |
+
from subprocess import CalledProcessError, DEVNULL, TimeoutExpired
|
8 |
+
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
9 |
+
from typing import Optional, Union
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
from PIL import Image, ImageOps
|
13 |
+
import requests
|
14 |
+
import torch
|
15 |
+
from torch.cuda import current_device, is_available as has_cuda
|
16 |
+
from transformers import TextGenerationPipeline as TGP, TextStreamer, pipeline, ImageToTextPipeline as ITP
|
17 |
+
from transformers.utils import logging
|
18 |
+
from transformers.utils.hub import is_remote_url
|
19 |
+
|
20 |
+
from pdf2image.pdf2image import convert_from_bytes
|
21 |
+
from pdfCropMargins import crop
|
22 |
+
import fitz
|
23 |
+
|
24 |
+
logger = logging.get_logger("transformers")
|
25 |
+
|
26 |
+
from os import killpg, getpgid
|
27 |
+
from subprocess import Popen, TimeoutExpired, CalledProcessError, CompletedProcess, PIPE
|
28 |
+
from signal import SIGKILL
|
29 |
+
|
30 |
+
def run(*popenargs, input=None, timeout=None, check=False, **kwargs):
    """
    Run a command in its own session and kill its whole process group on failure.

    The child is started with ``start_new_session=True`` so that it and any
    processes it spawns can be killed together via ``killpg``.

    Args:
        *popenargs: positional arguments forwarded to ``Popen``.
        input: data passed to ``Popen.communicate``.
        timeout: seconds to wait in ``communicate`` before giving up.
        check: if true, raise ``CalledProcessError`` on a non-zero exit code.
        **kwargs: keyword arguments forwarded to ``Popen``.

    Returns:
        A ``CompletedProcess`` holding args, return code, stdout and stderr.

    Raises:
        TimeoutExpired: if ``communicate`` times out (the group is killed first).
        CalledProcessError: if ``check`` is true and the exit code is non-zero.
    """
    with Popen(*popenargs, start_new_session=True, **kwargs) as process:
        try:
            stdout, stderr = process.communicate(input, timeout=timeout)
        except TimeoutExpired:
            killpg(getpgid(process.pid), SIGKILL)
            # reap the killed child before propagating the timeout
            process.wait()
            raise
        except BaseException:
            # any other failure (including KeyboardInterrupt): kill the group
            # and re-raise; leaving the ``with`` block waits for the child
            killpg(getpgid(process.pid), SIGKILL)
            raise
        retcode = process.poll()
        if check and retcode:
            raise CalledProcessError(retcode, process.args,
                                     output=stdout, stderr=stderr)
    return CompletedProcess(process.args, retcode, stdout, stderr) # type: ignore
|
46 |
+
|
47 |
+
def check_output(*popenargs, timeout=None, **kwargs):
    """
    Run a command via ``run`` with stdout captured and return that stdout.

    ``check=True`` is always passed, so a non-zero exit code raises
    ``CalledProcessError``; a timeout raises ``TimeoutExpired``.
    """
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True, **kwargs).stdout
|
49 |
+
|
50 |
+
class PdfDocument:
    """Thin wrapper around the raw bytes of a PDF file."""

    def __init__(self, raw: bytes):
        # raw: the complete PDF file contents
        self.raw = raw

    def save(self, filename):
        """Write the raw PDF bytes to ``filename``."""
        with open(filename, "wb") as f:
            f.write(self.raw)
|
57 |
+
|
58 |
+
|
59 |
+
class TikzDocument:
|
60 |
+
"""
|
61 |
+
Faciliate some operations with TikZ code. To compile the images a full
|
62 |
+
TeXLive installation is assumed to be on the PATH. Cropping additionally
|
63 |
+
requires Ghostscript, and rasterization needs poppler (apart from the 'pdf'
|
64 |
+
optional dependencies).
|
65 |
+
"""
|
66 |
+
# engines to try, could also try: https://tex.stackexchange.com/a/495999
|
67 |
+
engines = ["pdflatex", "lualatex", "xelatex"]
|
68 |
+
Output = namedtuple("Output", ['pdf', 'status', 'log'], defaults=[None, -1, ""])
|
69 |
+
|
70 |
+
def __init__(self, code: str, timeout=120):
|
71 |
+
self.code = code
|
72 |
+
self.timeout = timeout
|
73 |
+
|
74 |
+
@property
|
75 |
+
def status(self) -> int:
|
76 |
+
return self.compile().status
|
77 |
+
|
78 |
+
@property
|
79 |
+
def pdf(self) -> Optional[PdfDocument]:
|
80 |
+
return self.compile().pdf
|
81 |
+
|
82 |
+
@property
|
83 |
+
def log(self) -> str:
|
84 |
+
return self.compile().log
|
85 |
+
|
86 |
+
@property
|
87 |
+
def compiled_with_errors(self) -> bool:
|
88 |
+
return self.status != 0
|
89 |
+
|
90 |
+
@cached_property
|
91 |
+
def has_content(self) -> bool:
|
92 |
+
"""true if we have an image that isn't empty"""
|
93 |
+
return (img:=self.rasterize()) is not None and img.getcolors(1) is None
|
94 |
+
|
95 |
+
@classmethod
|
96 |
+
def set_engines(cls, engines: Union[str, list]):
|
97 |
+
cls.engines = [engines] if isinstance(engines, str) else engines
|
98 |
+
|
99 |
+
@cache
|
100 |
+
def compile(self) -> "Output":
|
101 |
+
output = dict()
|
102 |
+
with TemporaryDirectory() as tmpdirname:
|
103 |
+
with NamedTemporaryFile(dir=tmpdirname, buffering=0) as tmpfile:
|
104 |
+
codelines = self.code.split("\n")
|
105 |
+
# make sure we don't have page numbers in compiled pdf (for cropping)
|
106 |
+
codelines.insert(1, r"{cmd}\AtBeginDocument{{{cmd}}}".format(cmd=r"\thispagestyle{empty}\pagestyle{empty}"))
|
107 |
+
tmpfile.write("\n".join(codelines).encode())
|
108 |
+
|
109 |
+
try:
|
110 |
+
# compile
|
111 |
+
errorln, tmppdf, outpdf = 0, f"{tmpfile.name}.pdf", join(tmpdirname, "tikz.pdf")
|
112 |
+
open(f"{tmpfile.name}.bbl", 'a').close() # some classes expect a bibfile
|
113 |
+
|
114 |
+
def try_save_last_page():
|
115 |
+
try:
|
116 |
+
doc = fitz.open(tmppdf) # type: ignore
|
117 |
+
doc.select([len(doc)-1])
|
118 |
+
doc.save(outpdf)
|
119 |
+
except:
|
120 |
+
pass
|
121 |
+
|
122 |
+
for engine in self.engines:
|
123 |
+
try:
|
124 |
+
check_output(
|
125 |
+
cwd=tmpdirname,
|
126 |
+
timeout=self.timeout,
|
127 |
+
stderr=DEVNULL,
|
128 |
+
env=environ | dict(max_print_line="1000"), # improve formatting of log
|
129 |
+
args=["latexmk", "-f", "-nobibtex", "-norc", "-file-line-error", "-interaction=nonstopmode", f"-{engine}", tmpfile.name]
|
130 |
+
)
|
131 |
+
except (CalledProcessError, TimeoutExpired) as proc:
|
132 |
+
log = getattr(proc, "output", b'').decode(errors="ignore")
|
133 |
+
error = search(rf'^{escape(tmpfile.name)}:(\d+):.+$', log, MULTILINE)
|
134 |
+
# only update status and log if first error occurs later than in previous engine
|
135 |
+
if (linenr:=int(error.group(1)) if error else 0) > errorln:
|
136 |
+
errorln = linenr
|
137 |
+
output.update(status=getattr(proc, 'returncode', -1), log=log)
|
138 |
+
try_save_last_page()
|
139 |
+
else:
|
140 |
+
output.update(status=0, log='')
|
141 |
+
try_save_last_page()
|
142 |
+
break
|
143 |
+
|
144 |
+
# crop
|
145 |
+
croppdf = f"{tmpfile.name}.crop"
|
146 |
+
crop(["-gsf", "-c", "gb", "-p", "0", "-a", "-1", "-o", croppdf, outpdf], quiet=True)
|
147 |
+
if isfile(croppdf):
|
148 |
+
with open(croppdf, "rb") as pdf:
|
149 |
+
output['pdf'] = PdfDocument(pdf.read())
|
150 |
+
|
151 |
+
except (FileNotFoundError, NameError) as e:
|
152 |
+
logger.error("Missing dependencies: " + (
|
153 |
+
"Install this project with the [pdf] feature name!" if isinstance(e, NameError)
|
154 |
+
else "Did you install TeX Live?"
|
155 |
+
))
|
156 |
+
except RuntimeError: # pdf error during cropping
|
157 |
+
pass
|
158 |
+
|
159 |
+
if output.get("status") == 0 and not output.get("pdf", None):
|
160 |
+
logger.warning("Could compile document but something seems to have gone wrong during cropping!")
|
161 |
+
|
162 |
+
return self.Output(**output)
|
163 |
+
|
164 |
+
def rasterize(self, size=336, expand_to_square=True) -> Optional[Image.Image]:
|
165 |
+
if self.pdf:
|
166 |
+
image = convert_from_bytes(self.pdf.raw, size=size, single_file=True)[0]
|
167 |
+
if expand_to_square:
|
168 |
+
image = ImageOps.pad(image, (size, size), color='white')
|
169 |
+
|
170 |
+
return image
|
171 |
+
|
172 |
+
def save(self, filename: str, *args, **kwargs):
|
173 |
+
match filename.split(".")[-1]:
|
174 |
+
case "tex": content = self.code.encode()
|
175 |
+
case "pdf": content = getattr(self.pdf, "raw", bytes())
|
176 |
+
case fmt if img := self.rasterize(*args, **kwargs):
|
177 |
+
img.save(imgByteArr:=BytesIO(), format=fmt)
|
178 |
+
content = imgByteArr.getvalue()
|
179 |
+
case fmt: raise ValueError(f"Couldn't save with format '{fmt}'!")
|
180 |
+
|
181 |
+
with open(filename, "wb") as f:
|
182 |
+
f.write(content)
|
183 |
+
|
184 |
+
|
185 |
+
class TikzGenerator:
    """Generate TikZ code for an image with an image-to-text pipeline."""

    def __init__(
        self,
        pipe: ITP,
        temperature: float = 0.8, # based on "a systematic evaluation of large language models of code"
        top_p: float = 0.95,
        top_k: int = 0,
        stream: bool = False,
        expand_to_square: bool = False,
        clean_up_output: bool = True,
    ):
        """
        Args:
            pipe: the image-to-text pipeline used for generation.
            temperature, top_p, top_k: default sampling parameters.
            stream: if true, keep a ``TextStreamer`` in the default kwargs.
            expand_to_square: stored on the instance (not read by ``generate``).
            clean_up_output: post-process the generated text in ``generate``.
        """
        self.expand_to_square = expand_to_square
        self.clean_up_output = clean_up_output
        self.pipeline = pipe
        # compile the pipeline's own model (there is no local `model` name here)
        self.pipeline.model = torch.compile(self.pipeline.model) # type: ignore

        self.default_kwargs = dict(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_return_sequences=1,
            max_length=self.pipeline.tokenizer.model_max_length, # type: ignore
            do_sample=True,
            return_full_text=False,
            streamer=TextStreamer(self.pipeline.tokenizer, # type: ignore
                skip_prompt=True,
                skip_special_tokens=True
            ),
        )

        if not stream:
            self.default_kwargs.pop("streamer")

    def generate(self, image: Image.Image, **generate_kwargs):
        """
        Generate TikZ code for ``image``.

        Args:
            image: the input image handed to the pipeline.
            **generate_kwargs: overrides merged over the default kwargs.

        Returns:
            The generated text (a ``str``), cleaned up if ``clean_up_output``.
        """
        # NOTE(review): "shown in the lol" looks like a typo for "image", but
        # the text is left untouched because the model may depend on it - confirm.
        prompt = "Assistant helps to write down the TikZ code for the user's image. USER: <image>\nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:"
        tokenizer = self.pipeline.tokenizer
        text = self.pipeline(image, prompt=prompt, generate_kwargs=(self.default_kwargs | generate_kwargs))[0]["generated_text"] # type: ignore

        if self.clean_up_output:
            for token in reversed(tokenizer.tokenize(prompt)): # type: ignore
                # remove leading characters because skip_special_tokens in pipeline
                # adds unwanted prefix spaces if prompt ends with a special tokens
                if text and text[0].isspace() and token in tokenizer.all_special_tokens: # type: ignore
                    text = text[1:]
                else:
                    break

            # occasionally observed artifacts
            artifacts = {
                r'\bamsop\b': 'amsopn'
            }
            for artifact, replacement in artifacts.items():
                text = sub(artifact, replacement, text) # type: ignore

        return text

    def __call__(self, *args, **kwargs):
        """Shorthand for ``generate``."""
        return self.generate(*args, **kwargs)
|
packages.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
texlive-full
|
2 |
+
ghostscript
|
3 |
+
poppler-utils
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch<2.1
|
2 |
+
pdfCropMargins~=2.0
|
3 |
+
pdf2image~=1.16
|
4 |
+
PyMuPDF~=1.22
|
5 |
+
peft>=0.2.0
|
6 |
+
transformers
|
7 |
+
gradio
|
webui.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
#!/usr/bin/env python
|
3 |
+
|
4 |
+
from argparse import ArgumentParser
|
5 |
+
from functools import lru_cache
|
6 |
+
from importlib.resources import files
|
7 |
+
from inspect import signature
|
8 |
+
from multiprocessing.pool import ThreadPool
|
9 |
+
from tempfile import NamedTemporaryFile
|
10 |
+
from textwrap import dedent
|
11 |
+
from typing import Optional
|
12 |
+
|
13 |
+
from PIL import Image
|
14 |
+
import fitz
|
15 |
+
import gradio as gr
|
16 |
+
from transformers import TextIteratorStreamer, pipeline, ImageToTextPipeline
|
17 |
+
|
18 |
+
from infer import TikzDocument, TikzGenerator
|
19 |
+
|
20 |
+
# assets = files(__package__) / "assets" if __package__ else files("assets") / "."
|
21 |
+
# Maps the label shown in the model dropdown to the model identifier
# that is handed to `pipeline(...)`.
models = {
    "Fine-tuned Llava": "waleko/TikZ-llava-1.5-7b"
}
|
24 |
+
|
25 |
+
|
26 |
+
@lru_cache(maxsize=1)
def cached_load(model_name, **kwargs) -> ImageToTextPipeline:
    """
    Build the "image-to-text" pipeline for ``model_name``.

    The result is memoized with ``lru_cache(maxsize=1)``, so only the most
    recently requested (model_name, kwargs) combination is kept; all
    arguments must therefore be hashable.
    """
    gr.Info("Instantiating model. Could take a while...") # type: ignore
    # noinspection PyTypeChecker
    return pipeline("image-to-text", model=model_name, **kwargs)
|
31 |
+
|
32 |
+
|
33 |
+
def convert_to_svg(pdf):
    """Return the first page of ``pdf`` (an object with a ``raw`` bytes attribute) as SVG."""
    doc = fitz.open("pdf", pdf.raw) # type: ignore
    return doc[0].get_svg_image()
|
36 |
+
|
37 |
+
|
38 |
+
def inference(
    model_name: str,
    image: Image.Image,
    temperature: float,
    top_p: float,
    top_k: int,
    expand_to_square: bool,
):
    """
    Generate TikZ code for ``image`` and stream the partial output.

    Generation runs in a worker thread while this generator forwards the
    text produced so far.

    Yields:
        Tuples ``(text, None, finished)``: the accumulated text with
        ``finished=False`` while streaming, then the final text returned by
        the generator with ``finished=True``.
    """
    generate = TikzGenerator(
        cached_load(model_name, device_map="auto"),
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        expand_to_square=expand_to_square,
    )
    streamer = TextIteratorStreamer(
        generate.pipeline.tokenizer, # type: ignore
        skip_prompt=True,
        skip_special_tokens=True
    )

    # the context manager releases the pool once streaming ends or is cancelled
    with ThreadPool(processes=1) as pool:
        async_result = pool.apply_async(generate, kwds=dict(image=image, streamer=streamer))

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text, None, False
        # TikzGenerator.generate returns the text itself (a str), which has
        # no `.code` attribute
        yield async_result.get(), None, True
|
67 |
+
|
68 |
+
def tex_compile(
    code: str,
    timeout: int,
    rasterize: bool
):
    """
    Compile TikZ ``code`` and yield the result for display.

    Args:
        code: the TikZ source.
        timeout: compilation timeout forwarded to ``TikzDocument``.
        rasterize: yield a rasterized image instead of an SVG file path.

    Yields:
        The rasterized image, or the path of a temporary SVG file (None if
        there is no PDF). The temporary file only exists while the generator
        is suspended at the yield.

    Raises:
        gr.Error: if the code produced no content and compiled with errors.
    """
    tikzdoc = TikzDocument(code, timeout=timeout)
    if not tikzdoc.has_content:
        if tikzdoc.compiled_with_errors:
            raise gr.Error("TikZ code did not compile!") # type: ignore
        else:
            gr.Warning("TikZ code compiled to an empty image!") # type: ignore
    elif tikzdoc.compiled_with_errors:
        gr.Warning("TikZ code compiled with errors!") # type: ignore

    if rasterize:
        yield tikzdoc.rasterize()
    else:
        with NamedTemporaryFile(suffix=".svg", buffering=0) as tmpfile:
            if pdf:=tikzdoc.pdf:
                tmpfile.write(convert_to_svg(pdf).encode())
            yield tmpfile.name if pdf else None
|
89 |
+
|
90 |
+
def check_inputs(image: Image.Image):
    """Raise ``gr.Error`` if no image was provided (``image`` is falsy)."""
    if not image:
        raise gr.Error("Image is required")
|
93 |
+
|
94 |
+
def get_banner():
    """Return the Markdown/HTML banner (title plus badge links) shown at the top of the UI."""
    return dedent('''\
    # AutomaTi*k*Z: Text-Guided Synthesis of Scientific Vector Graphics with Ti*k*Z

    <p>
      <a style="display:inline-block" href="https://github.com/potamides/AutomaTikZ">
        <img src="https://img.shields.io/badge/View%20on%20GitHub-green?logo=github&labelColor=gray" alt="View on GitHub">
      </a>
      <a style="display:inline-block" href="https://arxiv.org/abs/2310.00367">
        <img src="https://img.shields.io/badge/View%20on%20arXiv-B31B1B?logo=arxiv&labelColor=gray" alt="View on arXiv">
      </a>
      <a style="display:inline-block" href="https://colab.research.google.com/drive/14S22x_8VohMr9pbnlkB4FqtF4n81khIh">
        <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab">
      </a>
      <a style="display:inline-block" href="https://huggingface.co/spaces/nllg/AutomaTikZ">
        <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg" alt="Open in HF Spaces">
      </a>
    </p>
    ''')
|
113 |
+
|
114 |
+
def remove_darkness(stylable):
    """
    Patch gradio to only contain light mode colors.

    Args:
        stylable: either a theme (``gr.themes.Base``) or a ``gr.Blocks`` app.

    Returns:
        The patched theme or the same blocks instance.

    Raises:
        ValueError: if ``stylable`` is neither a theme nor a blocks instance.
    """
    if isinstance(stylable, gr.themes.Base): # remove dark variants from the entire theme
        params = signature(stylable.set).parameters
        # every settable attribute (incl. "*_dark") gets the value of its light variant
        colors = {color: getattr(stylable, color.removesuffix("_dark")) for color in dir(stylable) if color in params}
        return stylable.set(**colors)
    elif isinstance(stylable, gr.Blocks): # also handle components which do not use the theme (e.g. modals)
        script = "() => document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'))"
        # the keyword for custom JavaScript is `_js` in older gradio releases
        # and `js` in newer ones; pick whichever the installed version accepts
        js_param = "js" if "js" in signature(stylable.load).parameters else "_js"
        stylable.load(**{js_param: script})
        return stylable
    else:
        raise ValueError
|
127 |
+
|
128 |
+
def build_ui(model=list(models)[0], lock=False, rasterize=False, force_light=False, lock_reason="locked", timeout=120):
    """
    Build the gradio interface.

    Args:
        model: key into ``models`` selecting the initially chosen model.
        lock: make the model dropdown non-interactive.
        rasterize: show a rasterized image instead of an SVG after compiling.
        force_light: strip dark-mode styling via ``remove_darkness``.
        lock_reason: text appended to the dropdown label when ``lock`` is set.
        timeout: compilation timeout forwarded to ``tex_compile``.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    theme = remove_darkness(gr.themes.Soft()) if force_light else gr.themes.Soft()
    with gr.Blocks(theme=theme, title="AutomaTikZ") as demo: # type: ignore
        if force_light: remove_darkness(demo)
        gr.Markdown(get_banner())
        with gr.Row(variant="panel"):
            with gr.Column():
                info = (
                    "Describe what you want to generate. "
                    "Scientific graphics benefit from captions with at least 30 tokens (see examples below), "
                    "while simple objects work best with shorter, 2-3 word captions."
                )
                # caption = gr.Textbox(label="Caption", info=info, placeholder="Type a caption...")
                # NOTE(review): `info` still describes captions although the input is
                # an image, and not every gradio version accepts `info` on gr.Image - confirm.
                image = gr.Image(label="Image Input", type="pil", info=info)
                label = "Model" + (f" ({lock_reason})" if lock else "")
                # from here on `model` is the dropdown component, not the key passed in
                model = gr.Dropdown(label=label, choices=list(models.items()), value=models[model], interactive=not lock) # type: ignore
                with gr.Accordion(label="Advanced Options", open=False):
                    temperature = gr.Slider(minimum=0, maximum=2, step=0.05, value=0.8, label="Temperature")
                    top_p = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.95, label="Top-P")
                    top_k = gr.Slider(minimum=0, maximum=100, step=10, value=0, label="Top-K")
                    expand_to_square = gr.Checkbox(value=True, label="Expand image to square")
                with gr.Row():
                    run_btn = gr.Button("Run", variant="primary")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.ClearButton([image])
            with gr.Column():
                with gr.Tabs() as tabs:
                    with gr.TabItem(label:="TikZ Code", id=0):
                        info = "Source code of the generated image."
                        tikz_code = gr.Code(label=label, show_label=False, info=info, interactive=False)
                    with gr.TabItem(label:="Compiled Image", id=1):
                        result_image = gr.Image(label=label, show_label=False, show_share_button=rasterize)
                    clear_btn.add([tikz_code, result_image])
        # TODO: gr.Examples(examples=str(assets), inputs=[image, tikz_code, result_image])

        events = list()
        finished = gr.Textbox(visible=False) # hack to cancel compile on canceled inference
        for listener in [run_btn.click]:
            # validate input -> switch to the code tab -> stream the generation
            generate_event = listener(
                check_inputs,
                inputs=[image],
                queue=False
            ).success(
                lambda: gr.Tabs(selected=0),
                outputs=tabs, # type: ignore
                queue=False
            ).then(
                inference,
                inputs=[model, image, temperature, top_p, top_k, expand_to_square],
                outputs=[tikz_code, result_image, finished]
            )

            def tex_compile_if_finished(finished, *args):
                # the hidden textbox holds the string form of the flag
                yield from (tex_compile(*args, timeout=timeout, rasterize=rasterize) if finished == "True" else [])

            # once generation finished: switch to the image tab and compile
            compile_event = generate_event.then(
                lambda finished: gr.Tabs(selected=1) if finished == "True" else gr.Tabs(),
                inputs=finished,
                outputs=tabs, # type: ignore
                queue=False
            ).then(
                tex_compile_if_finished,
                inputs=[finished, tikz_code],
                outputs=result_image
            )
            events.extend([generate_event, compile_event])

        # model.select(lambda model_name: gr.Image(visible="clima" in model_name), inputs=model, outputs=image, queue=False)
        for btn in [clear_btn, stop_btn]:
            # both buttons cancel any running generation/compilation
            btn.click(fn=None, cancels=events, queue=False)
        return demo
|