chanhua commited on
Commit
a814c2f
1 Parent(s): 8b93855

Upload 2 files

Browse files
Files changed (1) hide show
  1. image_feature.py +6 -3
image_feature.py CHANGED
@@ -49,8 +49,10 @@ DEVICE = torch.device('cpu')
49
 
50
 
51
  # 第二种方式推理图片相似度
52
- processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
53
- model = AutoModel.from_pretrained("google/vit-base-patch16-224").to(DEVICE)
 
 
54
  # processor = AutoImageProcessor.from_pretrained("chanhua/autotrain-izefx-v3qh0")
55
  # model = AutoModel.from_pretrained("chanhua/autotrain-izefx-v3qh0").to(DEVICE)
56
 
@@ -59,7 +61,8 @@ model = AutoModel.from_pretrained("google/vit-base-patch16-224").to(DEVICE)
59
 
60
  # pipe = pipeline(task="image-feature-extraction", model_name="google/vit-base-patch16-384", device=DEVICE, pool=True)
61
  # pipe = pipeline(task="image-feature-extraction", model_name="chanhua/autotrain-izefx-v3qh0", device=DEVICE, pool=True)
62
- pipe = pipeline(task="image-feature-extraction", model_name="google/vit-base-patch16-224", device=DEVICE, pool=True, revision="29e7a1e183")
 
63
 
64
 
65
  # 推理
 
49
 
50
 
51
  # 第二种方式推理图片相似度
52
+ # processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
53
+ # model = AutoModel.from_pretrained("google/vit-base-patch16-224").to(DEVICE)
54
+ processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
55
+ model = AutoModel.from_pretrained("google/vit-base-patch16-224-in21k").to(DEVICE)
56
  # processor = AutoImageProcessor.from_pretrained("chanhua/autotrain-izefx-v3qh0")
57
  # model = AutoModel.from_pretrained("chanhua/autotrain-izefx-v3qh0").to(DEVICE)
58
 
 
61
 
62
  # pipe = pipeline(task="image-feature-extraction", model_name="google/vit-base-patch16-384", device=DEVICE, pool=True)
63
  # pipe = pipeline(task="image-feature-extraction", model_name="chanhua/autotrain-izefx-v3qh0", device=DEVICE, pool=True)
64
+ # pipe = pipeline(task="image-feature-extraction", model_name="google/vit-base-patch16-224", device=DEVICE, pool=True, revision="29e7a1e183")
65
+ pipe = pipeline(task="image-feature-extraction", model_name="google/vit-base-patch16-224-in21k", device=DEVICE, pool=True)
66
 
67
 
68
  # 推理