Commit
·
a7a4721
1
Parent(s):
baa2ff5
Changing weights and fixes
Browse files- model_large_caption.pth +3 -0
- models/blip_decoder.py +2 -2
- pipeline.py +3 -4
model_large_caption.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0d79b3b7c41478b5fe55c35b73ca6f3525a09708289371c6c0fac641e588287e
|
3 |
+
size 1785411505
|
models/blip_decoder.py
CHANGED
@@ -8,8 +8,8 @@
|
|
8 |
import warnings
|
9 |
warnings.filterwarnings("ignore")
|
10 |
|
11 |
-
from vit import VisionTransformer, interpolate_pos_embed
|
12 |
-
from med import BertConfig, BertModel, BertLMHeadModel
|
13 |
from transformers import BertTokenizer
|
14 |
|
15 |
import torch
|
|
|
8 |
import warnings
|
9 |
warnings.filterwarnings("ignore")
|
10 |
|
11 |
+
from models.vit import VisionTransformer, interpolate_pos_embed
|
12 |
+
from models.med import BertConfig, BertModel, BertLMHeadModel
|
13 |
from transformers import BertTokenizer
|
14 |
|
15 |
import torch
|
pipeline.py
CHANGED
@@ -10,12 +10,11 @@ from torchvision import transforms
|
|
10 |
from torchvision.transforms.functional import InterpolationMode
|
11 |
|
12 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
13 |
-
print(device)
|
14 |
|
15 |
class PreTrainedPipeline():
|
16 |
-
def __init__(self):
|
17 |
# load the optimized model
|
18 |
-
self.model_path = '
|
19 |
self.model = blip_decoder(
|
20 |
pretrained=self.model_path,
|
21 |
image_size=384,
|
@@ -34,7 +33,7 @@ class PreTrainedPipeline():
|
|
34 |
|
35 |
|
36 |
|
37 |
-
def __call__(self, data: Any) -> Dict[str]:
|
38 |
"""
|
39 |
Args:
|
40 |
data (:obj:):
|
|
|
10 |
from torchvision.transforms.functional import InterpolationMode
|
11 |
|
12 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
13 |
|
14 |
class PreTrainedPipeline():
|
15 |
+
def __init__(self, path=""):
|
16 |
# load the optimized model
|
17 |
+
self.model_path = 'model_large_caption.pth'
|
18 |
self.model = blip_decoder(
|
19 |
pretrained=self.model_path,
|
20 |
image_size=384,
|
|
|
33 |
|
34 |
|
35 |
|
36 |
+
def __call__(self, data: Any) -> Dict[str, Any]:
|
37 |
"""
|
38 |
Args:
|
39 |
data (:obj:):
|