waleko committed on
Commit
673cd4d
·
1 Parent(s): 0253944
Files changed (6) hide show
  1. __init__.py +0 -0
  2. app.py +25 -0
  3. infer.py +243 -0
  4. packages.txt +3 -0
  5. requirements.txt +7 -0
  6. webui.py +198 -0
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import getenv
2
+ from textwrap import dedent
3
+
4
+ import gradio as gr
5
+ from torch import cuda
6
+
7
+ from webui import build_ui, remove_darkness, get_banner
8
+
9
# True only when running inside the hosted Space with this exact ID.
PUBLIC_DEMO = getenv("SPACE_ID") == "waleko/TikZ-Assistant"

if not PUBLIC_DEMO or cuda.is_available():
    # Anything but the GPU-less public demo: serve the full interface.
    build_ui(lock=False, force_light=True).queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
else:
    # Public demo without CUDA: only show a notice asking to duplicate the space.
    badge = "https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-xl.svg"
    # NOTE(review): the duplicate link targets nllg/AutomaTikZ while PUBLIC_DEMO
    # checks for waleko/TikZ-Assistant -- confirm this is intended.
    link = "https://huggingface.co/spaces/nllg/AutomaTikZ?duplicate=true"
    html = f'<a style="display:inline-block" href="{link}"> <img src="{badge}" alt="Duplicate this Space"> </a>'
    message = dedent("""\
        The size of our models exceeds the resource constraints offered by the
        free tier of Hugging Face Spaces. For full functionality, we recommend
        duplicating this space on a paid private GPU runtime.
        """)
    center = ".gradio-container {text-align: center}"
    with gr.Blocks(css=center, theme=remove_darkness(gr.themes.Soft()), title="AutomaTikZ") as demo:
        gr.Markdown(f'{get_banner()}\n{message}\n{html}')
    demo.launch()
infer.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ from functools import cache, cached_property
3
+ from io import BytesIO
4
+ from os import environ
5
+ from os.path import isfile, join
6
+ from re import MULTILINE, escape, search, sub
7
+ from subprocess import CalledProcessError, DEVNULL, TimeoutExpired
8
+ from tempfile import NamedTemporaryFile, TemporaryDirectory
9
+ from typing import Optional, Union
10
+ import warnings
11
+
12
+ from PIL import Image, ImageOps
13
+ import requests
14
+ import torch
15
+ from torch.cuda import current_device, is_available as has_cuda
16
+ from transformers import TextGenerationPipeline as TGP, TextStreamer, pipeline, ImageToTextPipeline as ITP
17
+ from transformers.utils import logging
18
+ from transformers.utils.hub import is_remote_url
19
+
20
+ from pdf2image.pdf2image import convert_from_bytes
21
+ from pdfCropMargins import crop
22
+ import fitz
23
+
24
+ logger = logging.get_logger("transformers")
25
+
26
+ from os import killpg, getpgid
27
+ from subprocess import Popen, TimeoutExpired, CalledProcessError, CompletedProcess, PIPE
28
+ from signal import SIGKILL
29
+
30
def run(*popenargs, input=None, timeout=None, check=False, **kwargs):
    """
    Variant of subprocess.run that starts the child in its own session and, on
    timeout or any other interruption, sends SIGKILL to the whole process
    group instead of only the direct child.

    Returns a CompletedProcess. Raises CalledProcessError when `check` is set
    and the return code is non-zero, and re-raises TimeoutExpired once the
    group has been killed.
    """
    with Popen(*popenargs, start_new_session=True, **kwargs) as process:

        def kill_group():
            # the child leads its own session, so its pgid covers all descendants
            killpg(getpgid(process.pid), SIGKILL)

        try:
            stdout, stderr = process.communicate(input, timeout=timeout)
        except TimeoutExpired:
            kill_group()
            process.wait()
            raise
        except BaseException:  # same reach as a bare except, e.g. KeyboardInterrupt
            kill_group()
            raise

        retcode = process.poll()
        if check and retcode:
            raise CalledProcessError(retcode, process.args, output=stdout, stderr=stderr)

    return CompletedProcess(process.args, retcode, stdout, stderr)  # type: ignore
46
+
47
def check_output(*popenargs, timeout=None, **kwargs):
    """
    Run a command through `run` with stdout captured and `check=True`, and
    return the captured stdout.
    """
    completed = run(*popenargs, stdout=PIPE, timeout=timeout, check=True, **kwargs)
    return completed.stdout
49
+
50
class PdfDocument:
    """
    Thin wrapper around the raw bytes of a PDF file.
    """
    def __init__(self, raw: bytes):
        # the PDF content, kept as-is
        self.raw = raw

    def save(self, filename):
        """
        Write the raw PDF bytes to `filename`.
        """
        with open(filename, "wb") as handle:
            handle.write(self.raw)
57
+
58
+
59
+ class TikzDocument:
60
+ """
61
+ Faciliate some operations with TikZ code. To compile the images a full
62
+ TeXLive installation is assumed to be on the PATH. Cropping additionally
63
+ requires Ghostscript, and rasterization needs poppler (apart from the 'pdf'
64
+ optional dependencies).
65
+ """
66
+ # engines to try, could also try: https://tex.stackexchange.com/a/495999
67
+ engines = ["pdflatex", "lualatex", "xelatex"]
68
+ Output = namedtuple("Output", ['pdf', 'status', 'log'], defaults=[None, -1, ""])
69
+
70
+ def __init__(self, code: str, timeout=120):
71
+ self.code = code
72
+ self.timeout = timeout
73
+
74
+ @property
75
+ def status(self) -> int:
76
+ return self.compile().status
77
+
78
+ @property
79
+ def pdf(self) -> Optional[PdfDocument]:
80
+ return self.compile().pdf
81
+
82
+ @property
83
+ def log(self) -> str:
84
+ return self.compile().log
85
+
86
+ @property
87
+ def compiled_with_errors(self) -> bool:
88
+ return self.status != 0
89
+
90
+ @cached_property
91
+ def has_content(self) -> bool:
92
+ """true if we have an image that isn't empty"""
93
+ return (img:=self.rasterize()) is not None and img.getcolors(1) is None
94
+
95
+ @classmethod
96
+ def set_engines(cls, engines: Union[str, list]):
97
+ cls.engines = [engines] if isinstance(engines, str) else engines
98
+
99
+ @cache
100
+ def compile(self) -> "Output":
101
+ output = dict()
102
+ with TemporaryDirectory() as tmpdirname:
103
+ with NamedTemporaryFile(dir=tmpdirname, buffering=0) as tmpfile:
104
+ codelines = self.code.split("\n")
105
+ # make sure we don't have page numbers in compiled pdf (for cropping)
106
+ codelines.insert(1, r"{cmd}\AtBeginDocument{{{cmd}}}".format(cmd=r"\thispagestyle{empty}\pagestyle{empty}"))
107
+ tmpfile.write("\n".join(codelines).encode())
108
+
109
+ try:
110
+ # compile
111
+ errorln, tmppdf, outpdf = 0, f"{tmpfile.name}.pdf", join(tmpdirname, "tikz.pdf")
112
+ open(f"{tmpfile.name}.bbl", 'a').close() # some classes expect a bibfile
113
+
114
+ def try_save_last_page():
115
+ try:
116
+ doc = fitz.open(tmppdf) # type: ignore
117
+ doc.select([len(doc)-1])
118
+ doc.save(outpdf)
119
+ except:
120
+ pass
121
+
122
+ for engine in self.engines:
123
+ try:
124
+ check_output(
125
+ cwd=tmpdirname,
126
+ timeout=self.timeout,
127
+ stderr=DEVNULL,
128
+ env=environ | dict(max_print_line="1000"), # improve formatting of log
129
+ args=["latexmk", "-f", "-nobibtex", "-norc", "-file-line-error", "-interaction=nonstopmode", f"-{engine}", tmpfile.name]
130
+ )
131
+ except (CalledProcessError, TimeoutExpired) as proc:
132
+ log = getattr(proc, "output", b'').decode(errors="ignore")
133
+ error = search(rf'^{escape(tmpfile.name)}:(\d+):.+$', log, MULTILINE)
134
+ # only update status and log if first error occurs later than in previous engine
135
+ if (linenr:=int(error.group(1)) if error else 0) > errorln:
136
+ errorln = linenr
137
+ output.update(status=getattr(proc, 'returncode', -1), log=log)
138
+ try_save_last_page()
139
+ else:
140
+ output.update(status=0, log='')
141
+ try_save_last_page()
142
+ break
143
+
144
+ # crop
145
+ croppdf = f"{tmpfile.name}.crop"
146
+ crop(["-gsf", "-c", "gb", "-p", "0", "-a", "-1", "-o", croppdf, outpdf], quiet=True)
147
+ if isfile(croppdf):
148
+ with open(croppdf, "rb") as pdf:
149
+ output['pdf'] = PdfDocument(pdf.read())
150
+
151
+ except (FileNotFoundError, NameError) as e:
152
+ logger.error("Missing dependencies: " + (
153
+ "Install this project with the [pdf] feature name!" if isinstance(e, NameError)
154
+ else "Did you install TeX Live?"
155
+ ))
156
+ except RuntimeError: # pdf error during cropping
157
+ pass
158
+
159
+ if output.get("status") == 0 and not output.get("pdf", None):
160
+ logger.warning("Could compile document but something seems to have gone wrong during cropping!")
161
+
162
+ return self.Output(**output)
163
+
164
+ def rasterize(self, size=336, expand_to_square=True) -> Optional[Image.Image]:
165
+ if self.pdf:
166
+ image = convert_from_bytes(self.pdf.raw, size=size, single_file=True)[0]
167
+ if expand_to_square:
168
+ image = ImageOps.pad(image, (size, size), color='white')
169
+
170
+ return image
171
+
172
+ def save(self, filename: str, *args, **kwargs):
173
+ match filename.split(".")[-1]:
174
+ case "tex": content = self.code.encode()
175
+ case "pdf": content = getattr(self.pdf, "raw", bytes())
176
+ case fmt if img := self.rasterize(*args, **kwargs):
177
+ img.save(imgByteArr:=BytesIO(), format=fmt)
178
+ content = imgByteArr.getvalue()
179
+ case fmt: raise ValueError(f"Couldn't save with format '{fmt}'!")
180
+
181
+ with open(filename, "wb") as f:
182
+ f.write(content)
183
+
184
+
185
class TikzGenerator:
    """
    Generate TikZ code for an image with an image-to-text pipeline. Instances
    are callable; calling one is the same as calling `generate`.
    """
    def __init__(
        self,
        pipe: "ITP",
        temperature: float = 0.8, # based on "a systematic evaluation of large language models of code"
        top_p: float = 0.95,
        top_k: int = 0,
        stream: bool = False,
        expand_to_square: bool = False,
        clean_up_output: bool = True,
    ):
        """
        Wrap `pipe` and prepare the default generation arguments. With
        `stream` set, generated text is additionally printed through a
        TextStreamer. `clean_up_output` enables the post-processing done in
        `generate`.
        """
        self.expand_to_square = expand_to_square
        self.clean_up_output = clean_up_output
        self.pipeline = pipe

        model = self.pipeline.model
        # The same pipeline may be wrapped by several generators; compiled
        # modules expose `_orig_mod`, so skip those to avoid nested wrappers.
        if not hasattr(model, "_orig_mod"):
            try:
                self.pipeline.model = torch.compile(model) # type: ignore
            except RuntimeError as err:
                # compilation is only an optimization, keep the eager model
                warnings.warn(f"torch.compile is unavailable, using the uncompiled model: {err}")

        self.default_kwargs = dict(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_return_sequences=1,
            max_length=self.pipeline.tokenizer.model_max_length, # type: ignore
            do_sample=True,
            return_full_text=False,
        )

        if stream:
            self.default_kwargs["streamer"] = TextStreamer(
                self.pipeline.tokenizer, # type: ignore
                skip_prompt=True,
                skip_special_tokens=True
            )

    def generate(self, image: "Image.Image", **generate_kwargs):
        """
        Run the pipeline on `image` and return the generated text. Keyword
        arguments override the default generation arguments for this call.
        """
        # NOTE(review): "shown in the lol" looks like a typo for "image", but the
        # prompt presumably has to match the one used for fine-tuning -- confirm
        # before changing it.
        prompt = "Assistant helps to write down the TikZ code for the user's image. USER: <image>\nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:"
        tokenizer = self.pipeline.tokenizer
        text = self.pipeline(image, prompt=prompt, generate_kwargs=(self.default_kwargs | generate_kwargs))[0]["generated_text"] # type: ignore

        if self.clean_up_output:
            for token in reversed(tokenizer.tokenize(prompt)): # type: ignore
                # remove leading characters because skip_special_tokens in pipeline
                # adds unwanted prefix spaces if prompt ends with a special tokens
                if text and text[0].isspace() and token in tokenizer.all_special_tokens: # type: ignore
                    text = text[1:]
                else:
                    break

            # occasionally observed artifacts
            artifacts = {
                r'\bamsop\b': 'amsopn'
            }
            for artifact, replacement in artifacts.items():
                text = sub(artifact, replacement, text) # type: ignore

        return text

    def __call__(self, *args, **kwargs):
        """Shorthand for `generate`."""
        return self.generate(*args, **kwargs)
packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ texlive-full
2
+ ghostscript
3
+ poppler-utils
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch<2.1
2
+ pdfCropMargins~=2.0
3
+ pdf2image~=1.16
4
+ PyMuPDF~=1.22
5
+ peft>=0.2.0
6
+ transformers
7
+ gradio
webui.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #!/usr/bin/env python
3
+
4
+ from argparse import ArgumentParser
5
+ from functools import lru_cache
6
+ from importlib.resources import files
7
+ from inspect import signature
8
+ from multiprocessing.pool import ThreadPool
9
+ from tempfile import NamedTemporaryFile
10
+ from textwrap import dedent
11
+ from typing import Optional
12
+
13
+ from PIL import Image
14
+ import fitz
15
+ import gradio as gr
16
+ from transformers import TextIteratorStreamer, pipeline, ImageToTextPipeline
17
+
18
+ from infer import TikzDocument, TikzGenerator
19
+
20
+ # assets = files(__package__) / "assets" if __package__ else files("assets") / "."
21
+ models = {
22
+ "Fine-tuned Llava": "waleko/TikZ-llava-1.5-7b"
23
+ }
24
+
25
+
26
@lru_cache(maxsize=1)
def cached_load(model_name, **kwargs) -> ImageToTextPipeline:
    """
    Build the "image-to-text" pipeline for `model_name`, forwarding any extra
    keyword arguments. Only the most recent result is kept by the cache, so
    the info toast is shown whenever a pipeline really has to be created.
    """
    gr.Info("Instantiating model. Could take a while...") # type: ignore
    # noinspection PyTypeChecker
    loaded = pipeline("image-to-text", model=model_name, **kwargs)
    return loaded
31
+
32
+
33
def convert_to_svg(pdf):
    """
    Return the first page of `pdf` (an object exposing the PDF bytes as
    `.raw`) as an SVG image.
    """
    document = fitz.open("pdf", pdf.raw) # type: ignore
    first_page = document[0]
    return first_page.get_svg_image()
36
+
37
+
38
def inference(
    model_name: str,
    image: Image.Image,
    temperature: float,
    top_p: float,
    top_k: int,
    expand_to_square: bool,
):
    """
    Generate TikZ code for `image` and stream it to the UI.

    Yields `(code, None, finished)` tuples: the text generated so far with
    `finished=False` while tokens arrive, and one last tuple with the final
    result and `finished=True`.
    """
    generate = TikzGenerator(
        cached_load(model_name, device_map="auto"),
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        expand_to_square=expand_to_square,
    )
    streamer = TextIteratorStreamer(
        generate.pipeline.tokenizer, # type: ignore
        skip_prompt=True,
        skip_special_tokens=True
    )

    # generation runs in a worker thread so that this generator can relay the
    # streamed text in the meantime
    pool = ThreadPool(processes=1)
    try:
        async_result = pool.apply_async(generate, kwds=dict(image=image, streamer=streamer))

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text, None, False
        result = async_result.get()
    finally:
        # release the pool's threads, also when the event gets cancelled
        pool.close()

    # the generator may return plain text or an object carrying it as `.code`
    yield getattr(result, "code", result), None, True
67
+
68
def tex_compile(
    code: str,
    timeout: int,
    rasterize: bool
):
    """
    Compile `code` and yield the result for the image component: a rasterized
    image if `rasterize` is set, otherwise the path of a temporary SVG file
    (or None when no PDF was produced). Raises gr.Error when nothing usable
    was compiled and shows warnings for empty or erroneous output.
    """
    tikzdoc = TikzDocument(code, timeout=timeout)
    has_content = tikzdoc.has_content
    has_errors = tikzdoc.compiled_with_errors

    if has_errors and not has_content:
        raise gr.Error("TikZ code did not compile!") # type: ignore
    if not has_content:
        gr.Warning("TikZ code compiled to an empty image!") # type: ignore
    elif has_errors:
        gr.Warning("TikZ code compiled with errors!") # type: ignore

    if rasterize:
        yield tikzdoc.rasterize()
        return

    # the file only exists while the generator is suspended at the yield
    with NamedTemporaryFile(suffix=".svg", buffering=0) as tmpfile:
        pdf = tikzdoc.pdf
        if pdf:
            tmpfile.write(convert_to_svg(pdf).encode())
            yield tmpfile.name
        else:
            yield None
89
+
90
def check_inputs(image: Image.Image):
    """
    Abort the event chain with a gr.Error unless an image was provided.
    """
    if image:
        return
    raise gr.Error("Image is required")
93
+
94
def get_banner():
    """
    Return the Markdown/HTML header shown at the top of the UI: the title
    followed by a row of badge links.
    """
    banner = dedent('''\
    # AutomaTi*k*Z: Text-Guided Synthesis of Scientific Vector Graphics with Ti*k*Z

    <p>
    <a style="display:inline-block" href="https://github.com/potamides/AutomaTikZ">
    <img src="https://img.shields.io/badge/View%20on%20GitHub-green?logo=github&labelColor=gray" alt="View on GitHub">
    </a>
    <a style="display:inline-block" href="https://arxiv.org/abs/2310.00367">
    <img src="https://img.shields.io/badge/View%20on%20arXiv-B31B1B?logo=arxiv&labelColor=gray" alt="View on arXiv">
    </a>
    <a style="display:inline-block" href="https://colab.research.google.com/drive/14S22x_8VohMr9pbnlkB4FqtF4n81khIh">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab">
    </a>
    <a style="display:inline-block" href="https://huggingface.co/spaces/nllg/AutomaTikZ">
    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg" alt="Open in HF Spaces">
    </a>
    </p>
    ''')
    return banner
113
+
114
def remove_darkness(stylable):
    """
    Patch gradio to only contain light mode colors.

    Accepts a theme, whose settable attributes are overwritten with the value
    of their counterpart without the "_dark" suffix, or a Blocks instance,
    which gets a load hook stripping the 'dark' class from the page. Raises
    ValueError for anything else.
    """
    if isinstance(stylable, gr.themes.Base): # remove dark variants from the entire theme
        settable = signature(stylable.set).parameters
        colors = {}
        for color in dir(stylable):
            if color in settable:
                colors[color] = getattr(stylable, color.removesuffix("_dark"))
        return stylable.set(**colors)

    if isinstance(stylable, gr.Blocks): # also handle components which do not use the theme (e.g. modals)
        stylable.load(_js="() => document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'))")
        return stylable

    raise ValueError
127
+
128
def build_ui(model=list(models)[0], lock=False, rasterize=False, force_light=False, lock_reason="locked", timeout=120):
    """
    Assemble the gradio interface and return the Blocks instance.

    `model` is the key of the initially selected entry in `models`; `lock`
    makes the model dropdown read-only and shows `lock_reason` in its label.
    `rasterize` and `timeout` are forwarded to `tex_compile`, and
    `force_light` strips dark mode from the theme and the page.
    """
    # NOTE(review): the nesting of the layout containers below should be
    # double-checked against the rendered page.
    theme = gr.themes.Soft()
    if force_light:
        theme = remove_darkness(theme)

    with gr.Blocks(theme=theme, title="AutomaTikZ") as demo: # type: ignore
        if force_light:
            remove_darkness(demo)
        gr.Markdown(get_banner())

        with gr.Row(variant="panel"):
            # left column: inputs and controls
            with gr.Column():
                info = (
                    "Describe what you want to generate. "
                    "Scientific graphics benefit from captions with at least 30 tokens (see examples below), "
                    "while simple objects work best with shorter, 2-3 word captions."
                )
                # caption = gr.Textbox(label="Caption", info=info, placeholder="Type a caption...")
                image = gr.Image(label="Image Input", type="pil", info=info)
                label = "Model"
                if lock:
                    label += f" ({lock_reason})"
                model = gr.Dropdown(label=label, choices=list(models.items()), value=models[model], interactive=not lock) # type: ignore
                with gr.Accordion(label="Advanced Options", open=False):
                    temperature = gr.Slider(minimum=0, maximum=2, step=0.05, value=0.8, label="Temperature")
                    top_p = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.95, label="Top-P")
                    top_k = gr.Slider(minimum=0, maximum=100, step=10, value=0, label="Top-K")
                    expand_to_square = gr.Checkbox(value=True, label="Expand image to square")
                with gr.Row():
                    run_btn = gr.Button("Run", variant="primary")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.ClearButton([image])

            # right column: generated code and compiled image
            with gr.Column():
                with gr.Tabs() as tabs:
                    label = "TikZ Code"
                    with gr.TabItem(label, id=0):
                        info = "Source code of the generated image."
                        tikz_code = gr.Code(label=label, show_label=False, info=info, interactive=False)
                    label = "Compiled Image"
                    with gr.TabItem(label, id=1):
                        result_image = gr.Image(label=label, show_label=False, show_share_button=rasterize)
                clear_btn.add([tikz_code, result_image])
        # TODO: gr.Examples(examples=str(assets), inputs=[image, tikz_code, result_image])

        finished = gr.Textbox(visible=False) # hack to cancel compile on canceled inference

        def show_code_tab():
            return gr.Tabs(selected=0)

        def show_image_tab_if_finished(finished):
            if finished == "True":
                return gr.Tabs(selected=1)
            return gr.Tabs()

        def tex_compile_if_finished(finished, *args):
            # only compile once inference reported completion
            if finished == "True":
                yield from tex_compile(*args, timeout=timeout, rasterize=rasterize)

        # validate input -> switch to the code tab -> stream generated code
        generate_event = run_btn.click(
            check_inputs,
            inputs=[image],
            queue=False
        ).success(
            show_code_tab,
            outputs=tabs, # type: ignore
            queue=False
        ).then(
            inference,
            inputs=[model, image, temperature, top_p, top_k, expand_to_square],
            outputs=[tikz_code, result_image, finished]
        )

        # afterwards switch to the image tab and compile the code
        compile_event = generate_event.then(
            show_image_tab_if_finished,
            inputs=finished,
            outputs=tabs, # type: ignore
            queue=False
        ).then(
            tex_compile_if_finished,
            inputs=[finished, tikz_code],
            outputs=result_image
        )
        events = [generate_event, compile_event]

        # model.select(lambda model_name: gr.Image(visible="clima" in model_name), inputs=model, outputs=image, queue=False)
        for btn in (clear_btn, stop_btn):
            btn.click(fn=None, cancels=events, queue=False)

    return demo