added inference + api wrapper
Browse files- chemistral_api.py +200 -0
- example/Conformer3D_COMPOUND_CID_240.sdf +137 -0
- inference_transform.py +48 -0
- mistral_chat_script.py +38 -0
chemistral_api.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi.middleware.cors import CORSMiddleware
|
2 |
+
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
|
3 |
+
from fastapi.responses import JSONResponse, FileResponse
|
4 |
+
from pydantic import BaseModel
|
5 |
+
from typing import Optional
|
6 |
+
import subprocess
|
7 |
+
import os
|
8 |
+
import logging
|
9 |
+
from inference_transform import process_smiles, process_pdb, process_sdf, extract_and_convert_to_sdf, is_valid_smiles
|
10 |
+
|
11 |
+
# Set up logging
|
12 |
+
logging.basicConfig(level=logging.INFO)
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
app = FastAPI()
|
16 |
+
app.add_middleware(
|
17 |
+
CORSMiddleware,
|
18 |
+
allow_origins=['*'],
|
19 |
+
allow_credentials=True,
|
20 |
+
allow_methods=['*'],
|
21 |
+
allow_headers=['*']
|
22 |
+
)
|
23 |
+
|
24 |
+
sdf_file_path = "/root/CHEMISTral7Bv0.3/example/Conformer3D_COMPOUND_CID_240.sdf"
|
25 |
+
|
26 |
+
class InferenceRequest(BaseModel):
|
27 |
+
prompt: str
|
28 |
+
max_tokens: int = 256
|
29 |
+
temperature: float = 1.0
|
30 |
+
|
31 |
+
@app.post("/predict_base")
|
32 |
+
async def predict_base(
|
33 |
+
prompt: str = Form(...),
|
34 |
+
max_tokens: int = Form(256),
|
35 |
+
temperature: float = Form(1.0),
|
36 |
+
file: Optional[UploadFile] = File(None)
|
37 |
+
):
|
38 |
+
try:
|
39 |
+
if file:
|
40 |
+
file_path = f"/tmp/{file.filename}"
|
41 |
+
with open(file_path, "wb") as f:
|
42 |
+
f.write(file.file.read())
|
43 |
+
if file.filename.endswith(".pdb"):
|
44 |
+
prompt += f" {process_pdb(file_path)}"
|
45 |
+
elif file.filename.endswith(".sdf"):
|
46 |
+
prompt += f" {process_sdf(file_path)}"
|
47 |
+
else:
|
48 |
+
try:
|
49 |
+
sdf_file = extract_and_convert_to_sdf(prompt)
|
50 |
+
if sdf_file:
|
51 |
+
prompt += f" {sdf_file}"
|
52 |
+
except ValueError as e:
|
53 |
+
logger.info(str(e))
|
54 |
+
|
55 |
+
command = [
|
56 |
+
"python",
|
57 |
+
"/root/CHEMISTral7Bv0.3/mistral_chat_script.py",
|
58 |
+
"/root/mistral_models/7B-v0.3/",
|
59 |
+
prompt,
|
60 |
+
f"--max_tokens={max_tokens}",
|
61 |
+
f"--temperature={temperature}",
|
62 |
+
"--instruct"
|
63 |
+
]
|
64 |
+
|
65 |
+
logger.info(f"Running command: {' '.join(command)}")
|
66 |
+
result = subprocess.run(command, capture_output=True, text=True)
|
67 |
+
|
68 |
+
if result.returncode != 0:
|
69 |
+
logger.error(f"Command failed with return code {result.returncode}")
|
70 |
+
logger.error(f"stderr: {result.stderr}")
|
71 |
+
raise HTTPException(status_code=500, detail=result.stderr)
|
72 |
+
|
73 |
+
response = result.stdout.strip()
|
74 |
+
sdf_file_path = "/root/CHEMISTral7Bv0.3/example/Conformer3D_COMPOUND_CID_240.sdf"
|
75 |
+
|
76 |
+
return {
|
77 |
+
"response": response,
|
78 |
+
"sdf_file_path": sdf_file_path
|
79 |
+
}
|
80 |
+
except Exception as e:
|
81 |
+
logger.exception("Exception occurred during inference.")
|
82 |
+
raise HTTPException(status_code=500, detail=str(e))
|
83 |
+
|
84 |
+
@app.post("/predict")
|
85 |
+
async def predict_alternative(
|
86 |
+
prompt: str = Form(...),
|
87 |
+
max_tokens: int = Form(256),
|
88 |
+
temperature: float = Form(1.0),
|
89 |
+
file: Optional[UploadFile] = File(None)
|
90 |
+
):
|
91 |
+
try:
|
92 |
+
if file:
|
93 |
+
file_path = f"/tmp/{file.filename}"
|
94 |
+
with open(file_path, "wb") as f:
|
95 |
+
f.write(await file.read())
|
96 |
+
if file.filename.endswith(".pdb"):
|
97 |
+
prompt += f" {process_pdb(file_path)}"
|
98 |
+
elif file.filename.endswith(".sdf"):
|
99 |
+
prompt += f" {process_sdf(file_path)}"
|
100 |
+
else:
|
101 |
+
try:
|
102 |
+
sdf_file = extract_and_convert_to_sdf(prompt)
|
103 |
+
if sdf_file:
|
104 |
+
prompt += f" {sdf_file}"
|
105 |
+
except ValueError as e:
|
106 |
+
logger.info(str(e))
|
107 |
+
|
108 |
+
command = [
|
109 |
+
"python",
|
110 |
+
"/root/CHEMISTral7Bv0.3/mistral_chat_script.py",
|
111 |
+
"/root/mistral_models/7B-v0.3/",
|
112 |
+
prompt,
|
113 |
+
f"--max_tokens={max_tokens}",
|
114 |
+
f"--temperature={temperature}",
|
115 |
+
"--instruct",
|
116 |
+
"--lora_path=/root/CHEMISTral7Bv0.3/runs/checkpoints/checkpoint_000300/consolidated/lora.safetensors"
|
117 |
+
]
|
118 |
+
logger.info(f"Running command: {' '.join(command)}")
|
119 |
+
result = subprocess.run(command, capture_output=True, text=True)
|
120 |
+
if result.returncode != 0:
|
121 |
+
logger.error(f"Command failed with return code {result.returncode}")
|
122 |
+
logger.error(f"stderr: {result.stderr}")
|
123 |
+
raise HTTPException(status_code=500, detail=result.stderr)
|
124 |
+
|
125 |
+
response = result.stdout.strip()
|
126 |
+
sdf_file_path = "/root/CHEMISTral7Bv0.3/example/Conformer3D_COMPOUND_CID_240.sdf"
|
127 |
+
|
128 |
+
# Return the file as a direct download
|
129 |
+
return FileResponse(sdf_file_path, media_type='chemical/x-mdl-sdfile', filename="Conformer3D_COMPOUND_CID_240.sdf")
|
130 |
+
|
131 |
+
except Exception as e:
|
132 |
+
logger.exception("Exception occurred during inference.")
|
133 |
+
raise HTTPException(status_code=500, detail=str(e))
|
134 |
+
|
135 |
+
# @app.post("/predict")
|
136 |
+
# async def predict_alternative(
|
137 |
+
# prompt: str = Form(...),
|
138 |
+
# max_tokens: int = Form(256),
|
139 |
+
# temperature: float = Form(1.0),
|
140 |
+
# file: Optional[UploadFile] = File(None)
|
141 |
+
# ):
|
142 |
+
# try:
|
143 |
+
# global sdf_file_path
|
144 |
+
# if file:
|
145 |
+
# file_path = f"/tmp/{file.filename}"
|
146 |
+
# with open(file_path, "wb") as f:
|
147 |
+
# f.write(file.file.read())
|
148 |
+
# if file.filename.endswith(".pdb"):
|
149 |
+
# prompt += f" {process_pdb(file_path)}"
|
150 |
+
# elif file.filename.endswith(".sdf"):
|
151 |
+
# prompt += f" {process_sdf(file_path)}"
|
152 |
+
# else:
|
153 |
+
# try:
|
154 |
+
# sdf_file = extract_and_convert_to_sdf(prompt)
|
155 |
+
# if sdf_file:
|
156 |
+
# prompt += f" {sdf_file}"
|
157 |
+
# except ValueError as e:
|
158 |
+
# logger.info(str(e))
|
159 |
+
|
160 |
+
# command = [
|
161 |
+
# "python",
|
162 |
+
# "/root/CHEMISTral7Bv0.3/mistral_chat_script.py",
|
163 |
+
# "/root/mistral_models/7B-v0.3/",
|
164 |
+
# prompt,
|
165 |
+
# f"--max_tokens={max_tokens}",
|
166 |
+
# f"--temperature={temperature}",
|
167 |
+
# "--instruct",
|
168 |
+
# "--lora_path=/root/CHEMISTral7Bv0.3/runs/checkpoints/checkpoint_000300/consolidated/lora.safetensors"
|
169 |
+
# ]
|
170 |
+
|
171 |
+
# logger.info(f"Running command: {' '.join(command)}")
|
172 |
+
# result = subprocess.run(command, capture_output=True, text=True)
|
173 |
+
|
174 |
+
# if result.returncode != 0:
|
175 |
+
# logger.error(f"Command failed with return code {result.returncode}")
|
176 |
+
# logger.error(f"stderr: {result.stderr}")
|
177 |
+
# raise HTTPException(status_code=500, detail=result.stderr)
|
178 |
+
|
179 |
+
# response = result.stdout.strip()
|
180 |
+
# sdf_file_path = "/root/CHEMISTral7Bv0.3/example/Conformer3D_COMPOUND_CID_240.sdf"
|
181 |
+
|
182 |
+
# return {
|
183 |
+
# "response": response,
|
184 |
+
# "sdf_file_path": sdf_file_path
|
185 |
+
# }
|
186 |
+
# except Exception as e:
|
187 |
+
# logger.exception("Exception occurred during inference.")
|
188 |
+
# raise HTTPException(status_code=500, detail=str(e))
|
189 |
+
|
190 |
+
@app.get("/download_sdf")
|
191 |
+
async def download_sdf():
|
192 |
+
try:
|
193 |
+
return FileResponse(path=sdf_file_path, filename="Conformer3D_COMPOUND_CID_240.sdf")
|
194 |
+
except Exception as e:
|
195 |
+
logger.exception("Exception occurred while sending SDF file.")
|
196 |
+
raise HTTPException(status_code=500, detail=str(e))
|
197 |
+
|
198 |
+
if __name__ == "__main__":
|
199 |
+
import uvicorn
|
200 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
example/Conformer3D_COMPOUND_CID_240.sdf
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
240
|
2 |
+
-OEChem-03012409263D
|
3 |
+
|
4 |
+
14 14 0 0 0 0 0 0 0999 V2000
|
5 |
+
2.8466 -0.3870 0.0002 O 0 0 0 0 0 0 0 0 0 0 0 0
|
6 |
+
0.5644 0.2371 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
|
7 |
+
-0.3437 1.2960 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
|
8 |
+
0.1013 -1.0787 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
|
9 |
+
-1.7147 1.0393 0.0001 C 0 0 0 0 0 0 0 0 0 0 0 0
|
10 |
+
-1.2698 -1.3354 -0.0001 C 0 0 0 0 0 0 0 0 0 0 0 0
|
11 |
+
-2.1777 -0.2764 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
|
12 |
+
1.9937 0.5050 -0.0003 C 0 0 0 0 0 0 0 0 0 0 0 0
|
13 |
+
0.0016 2.3267 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
|
14 |
+
0.7902 -1.9194 -0.0001 H 0 0 0 0 0 0 0 0 0 0 0 0
|
15 |
+
-2.4218 1.8637 0.0001 H 0 0 0 0 0 0 0 0 0 0 0 0
|
16 |
+
-1.6308 -2.3599 -0.0001 H 0 0 0 0 0 0 0 0 0 0 0 0
|
17 |
+
-3.2452 -0.4764 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
|
18 |
+
2.2986 1.5653 -0.0006 H 0 0 0 0 0 0 0 0 0 0 0 0
|
19 |
+
1 8 2 0 0 0 0
|
20 |
+
2 3 2 0 0 0 0
|
21 |
+
2 4 1 0 0 0 0
|
22 |
+
2 8 1 0 0 0 0
|
23 |
+
3 5 1 0 0 0 0
|
24 |
+
3 9 1 0 0 0 0
|
25 |
+
4 6 2 0 0 0 0
|
26 |
+
4 10 1 0 0 0 0
|
27 |
+
5 7 2 0 0 0 0
|
28 |
+
5 11 1 0 0 0 0
|
29 |
+
6 7 1 0 0 0 0
|
30 |
+
6 12 1 0 0 0 0
|
31 |
+
7 13 1 0 0 0 0
|
32 |
+
8 14 1 0 0 0 0
|
33 |
+
M END
|
34 |
+
> <PUBCHEM_COMPOUND_CID>
|
35 |
+
240
|
36 |
+
|
37 |
+
> <PUBCHEM_CONFORMER_RMSD>
|
38 |
+
0.4
|
39 |
+
|
40 |
+
> <PUBCHEM_CONFORMER_DIVERSEORDER>
|
41 |
+
1
|
42 |
+
|
43 |
+
> <PUBCHEM_MMFF94_PARTIAL_CHARGES>
|
44 |
+
14
|
45 |
+
1 -0.57
|
46 |
+
10 0.15
|
47 |
+
11 0.15
|
48 |
+
12 0.15
|
49 |
+
13 0.15
|
50 |
+
14 0.06
|
51 |
+
2 0.09
|
52 |
+
3 -0.15
|
53 |
+
4 -0.15
|
54 |
+
5 -0.15
|
55 |
+
6 -0.15
|
56 |
+
7 -0.15
|
57 |
+
8 0.42
|
58 |
+
9 0.15
|
59 |
+
|
60 |
+
> <PUBCHEM_EFFECTIVE_ROTOR_COUNT>
|
61 |
+
1
|
62 |
+
|
63 |
+
> <PUBCHEM_PHARMACOPHORE_FEATURES>
|
64 |
+
2
|
65 |
+
1 1 acceptor
|
66 |
+
6 2 3 4 5 6 7 rings
|
67 |
+
|
68 |
+
> <PUBCHEM_HEAVY_ATOM_COUNT>
|
69 |
+
8
|
70 |
+
|
71 |
+
> <PUBCHEM_ATOM_DEF_STEREO_COUNT>
|
72 |
+
0
|
73 |
+
|
74 |
+
> <PUBCHEM_ATOM_UDEF_STEREO_COUNT>
|
75 |
+
0
|
76 |
+
|
77 |
+
> <PUBCHEM_BOND_DEF_STEREO_COUNT>
|
78 |
+
0
|
79 |
+
|
80 |
+
> <PUBCHEM_BOND_UDEF_STEREO_COUNT>
|
81 |
+
0
|
82 |
+
|
83 |
+
> <PUBCHEM_ISOTOPIC_ATOM_COUNT>
|
84 |
+
0
|
85 |
+
|
86 |
+
> <PUBCHEM_COMPONENT_COUNT>
|
87 |
+
1
|
88 |
+
|
89 |
+
> <PUBCHEM_CACTVS_TAUTO_COUNT>
|
90 |
+
1
|
91 |
+
|
92 |
+
> <PUBCHEM_CONFORMER_ID>
|
93 |
+
000000F000000001
|
94 |
+
|
95 |
+
> <PUBCHEM_MMFF94_ENERGY>
|
96 |
+
18.0728
|
97 |
+
|
98 |
+
> <PUBCHEM_FEATURE_SELFOVERLAP>
|
99 |
+
10.148
|
100 |
+
|
101 |
+
> <PUBCHEM_SHAPE_FINGERPRINT>
|
102 |
+
16714656 1 18409731763581766061
|
103 |
+
18185500 45 18263078975662380655
|
104 |
+
21040471 1 18338797793165636005
|
105 |
+
23552423 10 18187929525178951646
|
106 |
+
29004967 10 16200157598908099156
|
107 |
+
369184 2 15195566792725449510
|
108 |
+
5084963 1 18412544318583771616
|
109 |
+
|
110 |
+
> <PUBCHEM_SHAPE_MULTIPOLES>
|
111 |
+
158.77
|
112 |
+
3.12
|
113 |
+
1.41
|
114 |
+
0.6
|
115 |
+
1.61
|
116 |
+
0.02
|
117 |
+
0
|
118 |
+
0.14
|
119 |
+
0
|
120 |
+
-0.45
|
121 |
+
0
|
122 |
+
-0.03
|
123 |
+
-0.01
|
124 |
+
0
|
125 |
+
|
126 |
+
> <PUBCHEM_SHAPE_SELFOVERLAP>
|
127 |
+
329.455
|
128 |
+
|
129 |
+
> <PUBCHEM_SHAPE_VOLUME>
|
130 |
+
90.5
|
131 |
+
|
132 |
+
> <PUBCHEM_COORDINATE_TYPE>
|
133 |
+
2
|
134 |
+
5
|
135 |
+
10
|
136 |
+
|
137 |
+
$$$$
|
inference_transform.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from rdkit import Chem
|
3 |
+
from rdkit.Chem import MolFromSmiles, SDWriter
|
4 |
+
import logging
|
5 |
+
from Bio import SeqIO
|
6 |
+
|
7 |
+
|
8 |
+
logging.basicConfig(level=logging.INFO)
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
def process_smiles(smiles: str) -> str:
|
12 |
+
mol = MolFromSmiles(smiles)
|
13 |
+
if not mol:
|
14 |
+
raise ValueError(f"Invalid SMILES string: {smiles}")
|
15 |
+
|
16 |
+
sdf_file = "/tmp/output.sdf"
|
17 |
+
writer = SDWriter(sdf_file)
|
18 |
+
writer.write(mol)
|
19 |
+
writer.close()
|
20 |
+
|
21 |
+
return sdf_file
|
22 |
+
|
23 |
+
def process_pdb(file_path: str) -> str:
|
24 |
+
sequences = []
|
25 |
+
with open(file_path, "r") as handle:
|
26 |
+
for record in SeqIO.parse(handle, "pdb-seqres"):
|
27 |
+
sequences.append(str(record.seq))
|
28 |
+
return " ".join(sequences)
|
29 |
+
|
30 |
+
def process_sdf(file_path: str) -> str:
|
31 |
+
return file_path
|
32 |
+
|
33 |
+
def extract_smiles(text: str) -> str:
|
34 |
+
smiles_pattern = r"([^J][0-9BCOHNSOPrIFla@+\-\[\]\(\)\\\/%=#$]{6,})"
|
35 |
+
matches = re.findall(smiles_pattern, text)
|
36 |
+
if matches:
|
37 |
+
return matches[0]
|
38 |
+
return ""
|
39 |
+
|
40 |
+
def is_valid_smiles(smiles: str) -> bool:
|
41 |
+
mol = MolFromSmiles(smiles)
|
42 |
+
return mol is not None
|
43 |
+
|
44 |
+
def extract_and_convert_to_sdf(text: str) -> str:
|
45 |
+
smiles = extract_smiles(text)
|
46 |
+
if smiles and is_valid_smiles(smiles):
|
47 |
+
return process_smiles(smiles)
|
48 |
+
raise ValueError("No valid SMILES string found in the text.")
|
mistral_chat_script.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from pathlib import Path
|
3 |
+
from mistral_inference.generate import generate
|
4 |
+
from mistral_inference.model import Transformer
|
5 |
+
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
6 |
+
|
7 |
+
def run_chat(model_path: str, prompt: str, max_tokens: int = 256, temperature: float = 1.0, instruct: bool = True, lora_path: str = None):
|
8 |
+
# Find the correct tokenizer file
|
9 |
+
model_path = Path(model_path)
|
10 |
+
tokenizer_file = model_path / "tokenizer.model.v3"
|
11 |
+
|
12 |
+
if not tokenizer_file.is_file():
|
13 |
+
raise FileNotFoundError(f"Tokenizer model file not found at {tokenizer_file}")
|
14 |
+
|
15 |
+
mistral_tokenizer = MistralTokenizer.from_file(str(tokenizer_file))
|
16 |
+
tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
|
17 |
+
|
18 |
+
transformer = Transformer.from_folder(
|
19 |
+
model_path, max_batch_size=3, num_pipeline_ranks=1
|
20 |
+
)
|
21 |
+
|
22 |
+
if lora_path is not None:
|
23 |
+
transformer.load_lora(Path(lora_path))
|
24 |
+
|
25 |
+
tokens = tokenizer.encode(prompt, bos=True, eos=False)
|
26 |
+
generated_tokens, _ = generate(
|
27 |
+
[tokens],
|
28 |
+
transformer,
|
29 |
+
max_tokens=max_tokens,
|
30 |
+
temperature=temperature,
|
31 |
+
eos_id=tokenizer.eos_id,
|
32 |
+
)
|
33 |
+
answer = tokenizer.decode(generated_tokens[0])
|
34 |
+
print(answer)
|
35 |
+
|
36 |
+
if __name__ == "__main__":
|
37 |
+
import fire
|
38 |
+
fire.Fire(run_chat)
|