osanseviero
commited on
Commit
·
7a2c78d
1
Parent(s):
d5b4d55
Fix imports
Browse files- pipeline.py +1 -5
- vqgan_jax/__init__.py +1 -0
pipeline.py
CHANGED
@@ -73,14 +73,10 @@ class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
|
|
73 |
|
74 |
class PreTrainedPipeline():
|
75 |
def __init__(self, path=""):
|
76 |
-
|
77 |
-
# Preload all the elements you are going to need at inference.
|
78 |
-
# For instance your model, processors, tokenizer that might be needed.
|
79 |
-
# This function is only called once, so do all the heavy processing I/O here"""
|
80 |
self.tokenizer = BartTokenizer.from_pretrained(path)
|
81 |
self.model = CustomFlaxBartForConditionalGeneration.from_pretrained(path)
|
82 |
|
83 |
-
self.vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384", revision="90cc46addd2dd8f5be21586a9a23e1b95aa506a9")
|
84 |
|
85 |
|
86 |
def __call__(self, inputs: str):
|
|
|
73 |
|
74 |
class PreTrainedPipeline():
|
75 |
def __init__(self, path=""):
|
76 |
+
self.vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384", revision="90cc46addd2dd8f5be21586a9a23e1b95aa506a9")
|
|
|
|
|
|
|
77 |
self.tokenizer = BartTokenizer.from_pretrained(path)
|
78 |
self.model = CustomFlaxBartForConditionalGeneration.from_pretrained(path)
|
79 |
|
|
|
80 |
|
81 |
|
82 |
def __call__(self, inputs: str):
|
vqgan_jax/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import *
|