adapt classifier to outlines v1.0.4
- classifier/classifier.py  +27 -15
- classifier/requirements.txt  +2 -1

classifier/classifier.py  CHANGED
@@ -2,11 +2,12 @@ import logging
 import os
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from typing import List
+from typing import List
+import json
 
 import outlines
 from outlines.models import openai
-from outlines.
+from outlines.types import Choice
 
 # Configure logger
 logging.basicConfig(level=logging.DEBUG)
@@ -15,7 +16,9 @@ logger = logging.getLogger("classifier")
 app = FastAPI()
 
 # Global variables for shared config and classifier
-
+llm = None
+clf_output = None
+prompt_template = None
 config_set = False
 
 class Config(BaseModel):
@@ -33,7 +36,7 @@ class Resp(BaseModel):
 @app.post("/config")
 def configure(req: Config):
     """Receive and initialize classifier configuration."""
-    global
+    global llm, clf_output, prompt_template, config_set
 
     if config_set:
         logger.warning("Classifier already configured. Ignoring new config.")
@@ -43,11 +46,12 @@ def configure(req: Config):
     logger.debug(f"Using API_KEY: {'set' if api_key else 'missing'}")
 
     try:
+        # Instantiate model and choice output type
         llm = openai(req.model_name, api_key=api_key, base_url=req.base_url)
-
-
-        clf.prompt_template = req.prompt_template
+        clf_output = Choice(req.class_set)
+        prompt_template = req.prompt_template
         config_set = True
+
         logger.info("Classifier configured successfully.")
         return {"status": "configured"}
     except Exception as e:
@@ -56,24 +60,32 @@ def configure(req: Config):
 
 @app.post("/classify", response_model=Resp)
 def classify(req: Req):
-
-
+    """Run text classification using the configured LLM."""
+    global llm, clf_output, prompt_template, config_set
+
+    if not config_set or llm is None or clf_output is None:
         raise HTTPException(status_code=503, detail="Classifier not configured yet")
 
-    # Render
+    # Render prompt
    try:
-        prompt =
+        prompt = prompt_template.replace("{message}", req.message)
         logger.debug(f"Rendered prompt: {prompt!r}")
     except Exception as e:
         logger.warning(f"Prompt rendering failed: {e}")
         prompt = req.message
 
-    #
+    # Invoke LLM classifier
     try:
-
+        raw = llm(prompt, clf_output)
+        try:
+            # llm returned '{"result": "Label"}' → parse it
+            unwrapped = json.loads(raw).get("result", raw)
+        except Exception:
+            unwrapped = raw
+        result = unwrapped
         logger.debug(f"Classification result: {result}")
     except Exception as e:
-        logger.error(f"Classification error: {e}. Falling back to: {
-        result =
+        logger.error(f"Classification error: {e}. Falling back to: {req.class_set[-1]}")
+        result = req.class_set[-1]
 
     return Resp(result=result)
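For quick verification, a minimal client-side smoke test of the two endpoints is sketched below. It assumes the service runs at http://localhost:8000 and that the Config fields visible in this diff (model_name, base_url, prompt_template, class_set) are the whole schema; the Req model is not shown in the hunks, so the /classify payload may need more than "message". The model name and labels are placeholders, not values from the repo.

import requests

BASE = "http://localhost:8000"  # assumed host/port for the FastAPI app

# Configure once; field names match the Config model used above.
config = {
    "model_name": "gpt-4o-mini",            # placeholder model
    "base_url": None,                       # None -> default OpenAI endpoint (assumed)
    "prompt_template": "Classify this message: {message}",
    "class_set": ["positive", "negative"],  # placeholder labels
}
print(requests.post(f"{BASE}/config", json=config).json())
# expected: {"status": "configured"}

# Classify; Resp carries the chosen label in "result".
resp = requests.post(f"{BASE}/classify", json={"message": "Great product!"})
print(resp.json()["result"])  # one of class_set on success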
classifier/requirements.txt  CHANGED
@@ -1,6 +1,7 @@
 fastapi
 uvicorn
-outlines
+outlines
+numpy
 
 # Force CPU-only torch
 torch==2.0.1+cpu
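Note: the pinned torch==2.0.1+cpu wheel is not hosted on PyPI, so installing this file usually also requires the PyTorch CPU wheel index, e.g. an `--extra-index-url https://download.pytorch.org/whl/cpu` line; how the image actually installs it is not shown in this commit.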