cjber commited on
Commit
e13ada9
·
1 Parent(s): ea5fd7c

add prompt

Browse files
Files changed (3) hide show
  1. pyproject.toml +3 -0
  2. src/planning_ai/phi.py +28 -14
  3. uv.lock +51 -0
pyproject.toml CHANGED
@@ -12,7 +12,10 @@ dependencies = [
12
  "torch>=2.4.1",
13
  "accelerate>=0.34.0",
14
  "pillow>=10.4.0",
 
15
  "flash-attn>=2.6.3",
 
 
16
  ]
17
 
18
  [tool.uv]
 
12
  "torch>=2.4.1",
13
  "accelerate>=0.34.0",
14
  "pillow>=10.4.0",
15
+ "setuptools",
16
  "flash-attn>=2.6.3",
17
+ "torchvision>=0.19.1",
18
+ "pdf2image>=1.17.0",
19
  ]
20
 
21
  [tool.uv]
src/planning_ai/phi.py CHANGED
@@ -1,4 +1,6 @@
1
- import requests
 
 
2
  from PIL import Image
3
  from transformers import AutoModelForCausalLM, AutoProcessor
4
 
@@ -14,31 +16,43 @@ model = AutoModelForCausalLM.from_pretrained(
14
  )
15
 
16
  # for best performance, use num_crops=4 for multi-frame, num_crops=16 for single-frame.
17
- processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, num_crops=4)
 
 
18
 
19
  images = []
20
  placeholder = ""
21
-
22
- # Note: if OOM, you might consider reduce number of frames in this example.
23
- path = "./data/raw/2024-09-05_10-46.png"
24
- images.append(Image.open(requests.get(path, stream=True).raw))
25
- placeholder += f"<|image_1|>\n"
 
 
 
26
 
27
  messages = [
28
- {"role": "user", "content": placeholder + "Summarize the deck of slides."},
 
 
 
 
 
 
 
 
 
 
 
29
  ]
30
 
31
  prompt = processor.tokenizer.apply_chat_template(
32
  messages, tokenize=False, add_generation_prompt=True
33
  )
34
 
35
- inputs = processor(prompt, images, return_tensors="pt").to("cuda:0")
36
 
37
- generation_args = {
38
- "max_new_tokens": 1000,
39
- "temperature": 0.0,
40
- "do_sample": False,
41
- }
42
 
43
  generate_ids = model.generate(
44
  **inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args
 
1
+ from pathlib import Path
2
+
3
+ from pdf2image import convert_from_path
4
  from PIL import Image
5
  from transformers import AutoModelForCausalLM, AutoProcessor
6
 
 
16
  )
17
 
18
  # for best performance, use num_crops=4 for multi-frame, num_crops=16 for single-frame.
19
+ processor = AutoProcessor.from_pretrained(
20
+ model_id, trust_remote_code=True, num_crops=16
21
+ )
22
 
23
  images = []
24
  placeholder = ""
25
+ path = Path("./data/raw/pdfs")
26
+ i = 1
27
+ for file in path.glob("*.pdf"):
28
+ pdf_images = convert_from_path(file)
29
+ for image in pdf_images:
30
+ images.append(image)
31
+ placeholder += f"<|image_{i}|>\n"
32
+ i += 1
33
 
34
  messages = [
35
+ {
36
+ "role": "user",
37
+ "content": """
38
+ <|image_1|>\nThis image shows an extract from a planning response form filled out by a member of the public. They may be pro or against the planning proposal. These planning applications typically cover the construction of new buildings, or similar infrastructure.
39
+
40
+ Extract all structured information from these documents. For example a section may include a questionnaire that may or may not have been filled in. Please indicate the response from the member of public in a structured format, following the convention:
41
+
42
+ {"<question>": "<response>"}
43
+
44
+ The document may also include hand written notes under the title 'Your comments:', also include these notes verbatim, in a structured format. If a word is unreadable please use the special token <UNKNOWN>. Do not attempt to fill in the word if you are unsure.
45
+ """,
46
+ },
47
  ]
48
 
49
  prompt = processor.tokenizer.apply_chat_template(
50
  messages, tokenize=False, add_generation_prompt=True
51
  )
52
 
53
+ inputs = processor(prompt, images[0], return_tensors="pt").to("cuda:0")
54
 
55
+ generation_args = {"max_new_tokens": 1000, "do_sample": False}
 
 
 
 
56
 
57
  generate_ids = model.generate(
58
  **inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args
uv.lock CHANGED
@@ -1066,6 +1066,18 @@ wheels = [
1066
  { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 },
1067
  ]
1068
 
 
 
 
 
 
 
 
 
 
 
 
 
1069
  [[package]]
1070
  name = "pexpect"
1071
  version = "4.9.0"
@@ -1147,8 +1159,11 @@ dependencies = [
1147
  { name = "langchain-community" },
1148
  { name = "langchain-core" },
1149
  { name = "langchain-unstructured" },
 
1150
  { name = "pillow" },
 
1151
  { name = "torch" },
 
1152
  { name = "transformers" },
1153
  ]
1154
 
@@ -1165,8 +1180,11 @@ requires-dist = [
1165
  { name = "langchain-community", specifier = ">=0.2.16" },
1166
  { name = "langchain-core", specifier = ">=0.2.38" },
1167
  { name = "langchain-unstructured", specifier = ">=0.1.2" },
 
1168
  { name = "pillow", specifier = ">=10.4.0" },
 
1169
  { name = "torch", specifier = ">=2.4.1" },
 
1170
  { name = "transformers", specifier = ">=4.44.2" },
1171
  ]
1172
 
@@ -1523,6 +1541,15 @@ wheels = [
1523
  { url = "https://files.pythonhosted.org/packages/e7/50/89e5eac4120b55422450d5221c86d526ace14e222ea3f6c0c005f8f011ec/safetensors-0.4.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:2c42e9b277513b81cf507e6121c7b432b3235f980cac04f39f435b7902857f91", size = 606993 },
1524
  ]
1525
 
 
 
 
 
 
 
 
 
 
1526
  [[package]]
1527
  name = "six"
1528
  version = "1.16.0"
@@ -1715,6 +1742,30 @@ wheels = [
1715
  { url = "https://files.pythonhosted.org/packages/ac/30/8b6f77ea4ce84f015ee024b8dfef0dac289396254e8bfd493906d4cbb848/torch-2.4.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:72b484d5b6cec1a735bf3fa5a1c4883d01748698c5e9cfdbeb4ffab7c7987e0d", size = 62123443 },
1716
  ]
1717
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1718
  [[package]]
1719
  name = "tqdm"
1720
  version = "4.66.5"
 
1066
  { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 },
1067
  ]
1068
 
1069
+ [[package]]
1070
+ name = "pdf2image"
1071
+ version = "1.17.0"
1072
+ source = { registry = "https://pypi.org/simple" }
1073
+ dependencies = [
1074
+ { name = "pillow" },
1075
+ ]
1076
+ sdist = { url = "https://files.pythonhosted.org/packages/00/d8/b280f01045555dc257b8153c00dee3bc75830f91a744cd5f84ef3a0a64b1/pdf2image-1.17.0.tar.gz", hash = "sha256:eaa959bc116b420dd7ec415fcae49b98100dda3dd18cd2fdfa86d09f112f6d57", size = 12811 }
1077
+ wheels = [
1078
+ { url = "https://files.pythonhosted.org/packages/62/33/61766ae033518957f877ab246f87ca30a85b778ebaad65b7f74fa7e52988/pdf2image-1.17.0-py3-none-any.whl", hash = "sha256:ecdd58d7afb810dffe21ef2b1bbc057ef434dabbac6c33778a38a3f7744a27e2", size = 11618 },
1079
+ ]
1080
+
1081
  [[package]]
1082
  name = "pexpect"
1083
  version = "4.9.0"
 
1159
  { name = "langchain-community" },
1160
  { name = "langchain-core" },
1161
  { name = "langchain-unstructured" },
1162
+ { name = "pdf2image" },
1163
  { name = "pillow" },
1164
+ { name = "setuptools" },
1165
  { name = "torch" },
1166
+ { name = "torchvision" },
1167
  { name = "transformers" },
1168
  ]
1169
 
 
1180
  { name = "langchain-community", specifier = ">=0.2.16" },
1181
  { name = "langchain-core", specifier = ">=0.2.38" },
1182
  { name = "langchain-unstructured", specifier = ">=0.1.2" },
1183
+ { name = "pdf2image" },
1184
  { name = "pillow", specifier = ">=10.4.0" },
1185
+ { name = "setuptools" },
1186
  { name = "torch", specifier = ">=2.4.1" },
1187
+ { name = "torchvision", specifier = ">=0.19.1" },
1188
  { name = "transformers", specifier = ">=4.44.2" },
1189
  ]
1190
 
 
1541
  { url = "https://files.pythonhosted.org/packages/e7/50/89e5eac4120b55422450d5221c86d526ace14e222ea3f6c0c005f8f011ec/safetensors-0.4.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:2c42e9b277513b81cf507e6121c7b432b3235f980cac04f39f435b7902857f91", size = 606993 },
1542
  ]
1543
 
1544
+ [[package]]
1545
+ name = "setuptools"
1546
+ version = "74.1.2"
1547
+ source = { registry = "https://pypi.org/simple" }
1548
+ sdist = { url = "https://files.pythonhosted.org/packages/3e/2c/f0a538a2f91ce633a78daaeb34cbfb93a54bd2132a6de1f6cec028eee6ef/setuptools-74.1.2.tar.gz", hash = "sha256:95b40ed940a1c67eb70fc099094bd6e99c6ee7c23aa2306f4d2697ba7916f9c6", size = 1356467 }
1549
+ wheels = [
1550
+ { url = "https://files.pythonhosted.org/packages/cb/9c/9ad11ac06b97e55ada655f8a6bea9d1d3f06e120b178cd578d80e558191d/setuptools-74.1.2-py3-none-any.whl", hash = "sha256:5f4c08aa4d3ebcb57a50c33b1b07e94315d7fc7230f7115e47fc99776c8ce308", size = 1262071 },
1551
+ ]
1552
+
1553
  [[package]]
1554
  name = "six"
1555
  version = "1.16.0"
 
1742
  { url = "https://files.pythonhosted.org/packages/ac/30/8b6f77ea4ce84f015ee024b8dfef0dac289396254e8bfd493906d4cbb848/torch-2.4.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:72b484d5b6cec1a735bf3fa5a1c4883d01748698c5e9cfdbeb4ffab7c7987e0d", size = 62123443 },
1743
  ]
1744
 
1745
+ [[package]]
1746
+ name = "torchvision"
1747
+ version = "0.19.1"
1748
+ source = { registry = "https://pypi.org/simple" }
1749
+ dependencies = [
1750
+ { name = "numpy" },
1751
+ { name = "pillow" },
1752
+ { name = "torch" },
1753
+ ]
1754
+ wheels = [
1755
+ { url = "https://files.pythonhosted.org/packages/d4/90/cab820b96d4d1a36b088774209d2379cf49eda8210c8fee13552383860b7/torchvision-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:54e8513099e6f586356c70f809d34f391af71ad182fe071cc328a28af2c40608", size = 1660236 },
1756
+ { url = "https://files.pythonhosted.org/packages/72/55/e0b3821c5595a9a2c8ec98d234b4a0d1142d91daac61f007503d3158f857/torchvision-0.19.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:20a1f5e02bfdad7714e55fa3fa698347c11d829fa65e11e5a84df07d93350eed", size = 7026373 },
1757
+ { url = "https://files.pythonhosted.org/packages/db/71/da0f71c2765feee125b1dc280a6432aa88c510aedf9a36987f3fe7ed05ea/torchvision-0.19.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:7b063116164be52fc6deb4762de7f8c90bfa3a65f8d5caf17f8e2d5aadc75a04", size = 14072253 },
1758
+ { url = "https://files.pythonhosted.org/packages/f7/8e/cbae11f8046d433881b478afc9e7589a76158124779cbc3a40163ec716bf/torchvision-0.19.1-cp310-cp310-win_amd64.whl", hash = "sha256:f40b6acabfa886da1bc3768f47679c61feee6bde90deb979d9f300df8c8a0145", size = 1288329 },
1759
+ { url = "https://files.pythonhosted.org/packages/66/f6/a2f07a3f5385b37c45b8e14448b8610a8618dfad18ea437cb23b4edc50c5/torchvision-0.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:40514282b4896d62765b8e26d7091c32e17c35817d00ec4be2362ea3ba3d1787", size = 1660235 },
1760
+ { url = "https://files.pythonhosted.org/packages/28/9d/40d1b943bbbd02a30d6b4f691d6de37a7e4c92f90bed0f8f47379e90eec6/torchvision-0.19.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:5a91be061ae5d6d5b95e833b93e57ca4d3c56c5a57444dd15da2e3e7fba96050", size = 7026152 },
1761
+ { url = "https://files.pythonhosted.org/packages/36/04/36e1d35b864f4a7c8f3056a427542b14b3bcdbc66edd36faadee109b86c5/torchvision-0.19.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d71a6a6fe3a5281ca3487d4c56ad4aad20ff70f82f1d7c79bcb6e7b0c2af00c8", size = 14072255 },
1762
+ { url = "https://files.pythonhosted.org/packages/f8/69/dc769cf54df8e828c0b8957b4521f35178f5bd4cc5b8fbe8a37ffd89a27c/torchvision-0.19.1-cp311-cp311-win_amd64.whl", hash = "sha256:70dea324174f5e9981b68e4b7cd524512c106ba64aedef560a86a0bbf2fbf62c", size = 1288330 },
1763
+ { url = "https://files.pythonhosted.org/packages/a4/d0/b1029ab95d9219cac2dfc0d835e9ab4cebb01f5cb6b48e736778020fb995/torchvision-0.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:27ece277ff0f6cdc7fed0627279c632dcb2e58187da771eca24b0fbcf3f8590d", size = 1660230 },
1764
+ { url = "https://files.pythonhosted.org/packages/8b/34/fdd2d9e01228a069b28473a7c020bf1812c8ecab8565666feb247659ed30/torchvision-0.19.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:c659ff92a61f188a1a7baef2850f3c0b6c85685447453c03d0e645ba8f1dcc1c", size = 7026404 },
1765
+ { url = "https://files.pythonhosted.org/packages/da/b2/9da42d67dfc30d9e3b161f7a37f6c7eca86a80e6caef4a9aa11727faa4f5/torchvision-0.19.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:c07bf43c2a145d792ecd9d0503d6c73577147ece508d45600d8aac77e4cdfcf9", size = 14072022 },
1766
+ { url = "https://files.pythonhosted.org/packages/6b/b2/fd577e1622b43cdeb74782a60cea4909f88f471813c215ea7b4e7ea84a74/torchvision-0.19.1-cp312-cp312-win_amd64.whl", hash = "sha256:b4283d283675556bb0eae31d29996f53861b17cbdcdf3509e6bc050414ac9289", size = 1288328 },
1767
+ ]
1768
+
1769
  [[package]]
1770
  name = "tqdm"
1771
  version = "4.66.5"