ampehta committed on
Commit
98b26e2
·
2 Parent(s): 699df87 b22b27a

Merge branch 'main' of https://huggingface.co/spaces/flax-community/koclip into main

Browse files
Files changed (5) hide show
  1. app.py +0 -1
  2. embed.py +16 -11
  3. image2text.py +4 -2
  4. text2image.py +11 -8
  5. utils.py +11 -13
app.py CHANGED
@@ -3,7 +3,6 @@ import streamlit as st
3
  import image2text
4
  import text2image
5
 
6
-
7
  PAGES = {"Text to Image": text2image, "Image to Text": image2text}
8
 
9
  st.sidebar.title("Navigation")
 
3
  import image2text
4
  import text2image
5
 
 
6
  PAGES = {"Text to Image": text2image, "Image to Text": image2text}
7
 
8
  st.sidebar.title("Navigation")
embed.py CHANGED
@@ -2,21 +2,20 @@ import argparse
2
  import csv
3
  import os
4
 
 
5
  from PIL import Image
 
6
 
7
  from utils import load_model
8
- import jax.numpy as jnp
9
- from jax import jit
10
-
11
- from tqdm import tqdm
12
 
13
 
14
  def main(args):
15
  root = args.image_path
16
  files = list(os.listdir(root))
17
  for f in files:
18
- assert(f[-4:] == ".jpg")
19
  for model_name in ["koclip-base", "koclip-large"]:
 
20
  model, processor = load_model(f"koclip/{model_name}")
21
  with tqdm(total=len(files)) as pbar:
22
  for counter in range(0, len(files), args.batch_size):
@@ -24,28 +23,34 @@ def main(args):
24
  image_ids = []
25
  for idx in range(counter, min(len(files), counter + args.batch_size)):
26
  file_ = files[idx]
27
- image = Image.open(os.path.join(root, file_)).convert('RGB')
28
  images.append(image)
29
  image_ids.append(file_)
30
 
31
  pbar.update(args.batch_size)
32
  try:
33
- inputs = processor(text=[""], images=images, return_tensors="jax", padding=True)
 
 
34
  except:
35
  print(image_ids)
36
  break
37
- inputs['pixel_values'] = jnp.transpose(inputs['pixel_values'], axes=[0, 2, 3, 1])
 
 
38
  features = model(**inputs).image_embeds
39
  with open(os.path.join(args.out_path, f"{model_name}.tsv"), "a+") as f:
40
  writer = csv.writer(f, delimiter="\t")
41
  for image_id, feature in zip(image_ids, features):
42
- writer.writerow([image_id, ",".join(map(lambda x: str(x), feature))])
 
 
43
 
44
 
45
  if __name__ == "__main__":
46
  parser = argparse.ArgumentParser()
47
  parser.add_argument("--batch_size", default=16)
48
- parser.add_argument("--image_path", default="images")
49
- parser.add_argument("--out_path", default="features")
50
  args = parser.parse_args()
51
  main(args)
 
2
  import csv
3
  import os
4
 
5
+ import jax.numpy as jnp
6
  from PIL import Image
7
+ from tqdm import tqdm
8
 
9
  from utils import load_model
 
 
 
 
10
 
11
 
12
  def main(args):
13
  root = args.image_path
14
  files = list(os.listdir(root))
15
  for f in files:
16
+ assert f[-4:] == ".jpg"
17
  for model_name in ["koclip-base", "koclip-large"]:
18
+ # for model_name in ["koclip-large"]:
19
  model, processor = load_model(f"koclip/{model_name}")
20
  with tqdm(total=len(files)) as pbar:
21
  for counter in range(0, len(files), args.batch_size):
 
23
  image_ids = []
24
  for idx in range(counter, min(len(files), counter + args.batch_size)):
25
  file_ = files[idx]
26
+ image = Image.open(os.path.join(root, file_)).convert("RGB")
27
  images.append(image)
28
  image_ids.append(file_)
29
 
30
  pbar.update(args.batch_size)
31
  try:
32
+ inputs = processor(
33
+ text=[""], images=images, return_tensors="jax", padding=True
34
+ )
35
  except:
36
  print(image_ids)
37
  break
38
+ inputs["pixel_values"] = jnp.transpose(
39
+ inputs["pixel_values"], axes=[0, 2, 3, 1]
40
+ )
41
  features = model(**inputs).image_embeds
42
  with open(os.path.join(args.out_path, f"{model_name}.tsv"), "a+") as f:
43
  writer = csv.writer(f, delimiter="\t")
44
  for image_id, feature in zip(image_ids, features):
45
+ writer.writerow(
46
+ [image_id, ",".join(map(lambda x: str(x), feature))]
47
+ )
48
 
49
 
50
  if __name__ == "__main__":
51
  parser = argparse.ArgumentParser()
52
  parser.add_argument("--batch_size", default=16)
53
+ parser.add_argument("--image_path", default="images/val2017")
54
+ parser.add_argument("--out_path", default="features/val2017")
55
  args = parser.parse_args()
56
  main(args)
image2text.py CHANGED
@@ -7,6 +7,8 @@ def app(model_name):
7
  model, processor = load_model(model_name)
8
 
9
  st.title("Image to Text")
10
- st.markdown("""
 
11
  Some text goes in here.
12
- """)
 
 
7
  model, processor = load_model(model_name)
8
 
9
  st.title("Image to Text")
10
+ st.markdown(
11
+ """
12
  Some text goes in here.
13
+ """
14
+ )
text2image.py CHANGED
@@ -1,21 +1,22 @@
1
  import os
2
 
 
 
3
  import streamlit as st
4
 
5
- from utils import load_model, load_index
6
- import numpy as np
7
- import matplotlib.pyplot as plt
8
 
9
 
10
  def app(model_name):
11
- images_directory = 'images/val2017'
12
- features_directory = f'features/val2017/{model_name}.tsv'
13
 
14
  files, index = load_index(features_directory)
15
- model, processor = load_model(f'koclip/{model_name}')
16
 
17
  st.title("Text to Image Search Engine")
18
- st.markdown("""
 
19
  This demonstration explores capability of KoCLIP as a Korean-language Image search engine. Embeddings for each of
20
  5000 images from [MSCOCO](https://cocodataset.org/#home) 2017 validation set was generated using trained KoCLIP
21
  vision model. They are ranked based on cosine similarity distance from input Text query embeddings and top 10 images
@@ -27,9 +28,11 @@ def app(model_name):
27
  Larger model `koclip-large` uses `klue/roberta` as text encoder and bigger `google/vit-large-patch16-224` as image encoder.
28
 
29
  Example Queries : 아파트(Apartment), 자동차(Car), 컴퓨터(Computer)
30
- """)
 
31
 
32
  query = st.text_input("한글 질문을 적어주세요 (Korean Text Query) :", value="아파트")
 
33
  if st.button("질문 (Query)"):
34
  proc = processor(text=[query], images=None, return_tensors="jax", padding=True)
35
  vec = np.asarray(model.get_text_features(**proc))
 
1
  import os
2
 
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
  import streamlit as st
6
 
7
+ from utils import load_index, load_model
 
 
8
 
9
 
10
  def app(model_name):
11
+ images_directory = "images/val2017"
12
+ features_directory = f"features/val2017/{model_name}.tsv"
13
 
14
  files, index = load_index(features_directory)
15
+ model, processor = load_model(f"koclip/{model_name}")
16
 
17
  st.title("Text to Image Search Engine")
18
+ st.markdown(
19
+ """
20
  This demonstration explores capability of KoCLIP as a Korean-language Image search engine. Embeddings for each of
21
  5000 images from [MSCOCO](https://cocodataset.org/#home) 2017 validation set was generated using trained KoCLIP
22
  vision model. They are ranked based on cosine similarity distance from input Text query embeddings and top 10 images
 
28
  Larger model `koclip-large` uses `klue/roberta` as text encoder and bigger `google/vit-large-patch16-224` as image encoder.
29
 
30
  Example Queries : 아파트(Apartment), 자동차(Car), 컴퓨터(Computer)
31
+ """
32
+ )
33
 
34
  query = st.text_input("한글 질문을 적어주세요 (Korean Text Query) :", value="아파트")
35
+
36
  if st.button("질문 (Query)"):
37
  proc = processor(text=[query], images=None, return_tensors="jax", padding=True)
38
  vec = np.asarray(model.get_text_features(**proc))
utils.py CHANGED
@@ -1,26 +1,28 @@
1
  import nmslib
2
- import streamlit as st
3
- from transformers import CLIPProcessor, AutoTokenizer, ViTFeatureExtractor
4
  import numpy as np
 
 
5
 
6
  from koclip import FlaxHybridCLIP
7
 
 
8
  @st.cache(allow_output_mutation=True)
9
  def load_index(img_file):
10
  filenames, embeddings = [], []
11
  lines = open(img_file, "r")
12
  for line in lines:
13
- cols = line.strip().split('\t')
14
  filename = cols[0]
15
- embedding = np.array([float(x) for x in cols[1].split(',')])
16
  filenames.append(filename)
17
  embeddings.append(embedding)
18
  embeddings = np.array(embeddings)
19
- index = nmslib.init(method='hnsw', space='cosinesimil')
20
  index.addDataPointBatch(embeddings)
21
- index.createIndex({'post': 2}, print_progress=True)
22
  return filenames, index
23
 
 
24
  @st.cache(allow_output_mutation=True)
25
  def load_model(model_name="koclip/koclip-base"):
26
  assert model_name in {"koclip/koclip-base", "koclip/koclip-large"}
@@ -28,11 +30,7 @@ def load_model(model_name="koclip/koclip-base"):
28
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
29
  processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
30
  if model_name == "koclip/koclip-large":
31
- processor.feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-large-patch16-224")
32
- return model, processor
33
-
34
- @st.cache(allow_output_mutation=True)
35
- def load_model_v2(model_name="koclip/koclip"):
36
- model = FlaxHybridCLIP.from_pretrained(model_name)
37
- processor = CLIPProcessor.from_pretrained(model_name)
38
  return model, processor
 
1
  import nmslib
 
 
2
  import numpy as np
3
+ import streamlit as st
4
+ from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor
5
 
6
  from koclip import FlaxHybridCLIP
7
 
8
+
9
  @st.cache(allow_output_mutation=True)
10
  def load_index(img_file):
11
  filenames, embeddings = [], []
12
  lines = open(img_file, "r")
13
  for line in lines:
14
+ cols = line.strip().split("\t")
15
  filename = cols[0]
16
+ embedding = [float(x) for x in cols[1].split(",")]
17
  filenames.append(filename)
18
  embeddings.append(embedding)
19
  embeddings = np.array(embeddings)
20
+ index = nmslib.init(method="hnsw", space="cosinesimil")
21
  index.addDataPointBatch(embeddings)
22
+ index.createIndex({"post": 2}, print_progress=True)
23
  return filenames, index
24
 
25
+
26
  @st.cache(allow_output_mutation=True)
27
  def load_model(model_name="koclip/koclip-base"):
28
  assert model_name in {"koclip/koclip-base", "koclip/koclip-large"}
 
30
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
31
  processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
32
  if model_name == "koclip/koclip-large":
33
+ processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
34
+ "google/vit-large-patch16-224"
35
+ )
 
 
 
 
36
  return model, processor