Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -8,11 +8,12 @@ dataset = load_dataset("cifar100")
|
|
8 |
image = dataset["train"]["fine_label"]
|
9 |
print("load and train dataset \n")
|
10 |
|
|
|
|
|
|
|
|
|
|
|
11 |
def classify(image):
|
12 |
-
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
|
13 |
-
print("feature extractor \n")
|
14 |
-
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
|
15 |
-
print("load model \n")
|
16 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
17 |
print("define input \n")
|
18 |
with torch.no_grad():
|
|
|
8 |
image = dataset["train"]["fine_label"]
|
9 |
print("load and train dataset \n")
|
10 |
|
11 |
+
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
|
12 |
+
print("feature extractor \n")
|
13 |
+
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
|
14 |
+
print("load model \n")
|
15 |
+
|
16 |
def classify(image):
|
|
|
|
|
|
|
|
|
17 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
18 |
print("define input \n")
|
19 |
with torch.no_grad():
|