Luigi commited on
Commit
9ffbce4
·
1 Parent(s): ef8c7e3

adapt classifier to outlines v1.0.4

Browse files
classifier/classifier.py CHANGED
@@ -2,11 +2,12 @@ import logging
2
  import os
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
- from typing import List, Optional
 
6
 
7
  import outlines
8
  from outlines.models import openai
9
- from outlines.generate import choice
10
 
11
  # Configure logger
12
  logging.basicConfig(level=logging.DEBUG)
@@ -15,7 +16,9 @@ logger = logging.getLogger("classifier")
15
  app = FastAPI()
16
 
17
  # Global variables for shared config and classifier
18
- clf = None
 
 
19
  config_set = False
20
 
21
  class Config(BaseModel):
@@ -33,7 +36,7 @@ class Resp(BaseModel):
33
  @app.post("/config")
34
  def configure(req: Config):
35
  """Receive and initialize classifier configuration."""
36
- global clf, config_set
37
 
38
  if config_set:
39
  logger.warning("Classifier already configured. Ignoring new config.")
@@ -43,11 +46,12 @@ def configure(req: Config):
43
  logger.debug(f"Using API_KEY: {'set' if api_key else 'missing'}")
44
 
45
  try:
 
46
  llm = openai(req.model_name, api_key=api_key, base_url=req.base_url)
47
- clf = choice(llm, req.class_set)
48
- clf.class_set = req.class_set
49
- clf.prompt_template = req.prompt_template
50
  config_set = True
 
51
  logger.info("Classifier configured successfully.")
52
  return {"status": "configured"}
53
  except Exception as e:
@@ -56,24 +60,32 @@ def configure(req: Config):
56
 
57
  @app.post("/classify", response_model=Resp)
58
  def classify(req: Req):
59
- global clf
60
- if clf is None or not config_set:
 
 
61
  raise HTTPException(status_code=503, detail="Classifier not configured yet")
62
 
63
- # Render the prompt using the template
64
  try:
65
- prompt = clf.prompt_template.replace("{message}", req.message)
66
  logger.debug(f"Rendered prompt: {prompt!r}")
67
  except Exception as e:
68
  logger.warning(f"Prompt rendering failed: {e}")
69
  prompt = req.message
70
 
71
- # Run classifier
72
  try:
73
- result = clf(prompt)
 
 
 
 
 
 
74
  logger.debug(f"Classification result: {result}")
75
  except Exception as e:
76
- logger.error(f"Classification error: {e}. Falling back to: {clf.class_set[-1]}")
77
- result = clf.class_set[-1]
78
 
79
  return Resp(result=result)
 
2
  import os
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
+ from typing import List
6
+ import json
7
 
8
  import outlines
9
  from outlines.models import openai
10
+ from outlines.types import Choice
11
 
12
  # Configure logger
13
  logging.basicConfig(level=logging.DEBUG)
 
16
  app = FastAPI()
17
 
18
  # Global variables for shared config and classifier
19
+ llm = None
20
+ clf_output = None
21
+ prompt_template = None
22
  config_set = False
23
 
24
  class Config(BaseModel):
 
36
  @app.post("/config")
37
  def configure(req: Config):
38
  """Receive and initialize classifier configuration."""
39
+ global llm, clf_output, prompt_template, config_set
40
 
41
  if config_set:
42
  logger.warning("Classifier already configured. Ignoring new config.")
 
46
  logger.debug(f"Using API_KEY: {'set' if api_key else 'missing'}")
47
 
48
  try:
49
+ # Instantiate model and choice output type
50
  llm = openai(req.model_name, api_key=api_key, base_url=req.base_url)
51
+ clf_output = Choice(req.class_set)
52
+ prompt_template = req.prompt_template
 
53
  config_set = True
54
+
55
  logger.info("Classifier configured successfully.")
56
  return {"status": "configured"}
57
  except Exception as e:
 
60
 
61
  @app.post("/classify", response_model=Resp)
62
  def classify(req: Req):
63
+ """Run text classification using the configured LLM."""
64
+ global llm, clf_output, prompt_template, config_set
65
+
66
+ if not config_set or llm is None or clf_output is None:
67
  raise HTTPException(status_code=503, detail="Classifier not configured yet")
68
 
69
+ # Render prompt
70
  try:
71
+ prompt = prompt_template.replace("{message}", req.message)
72
  logger.debug(f"Rendered prompt: {prompt!r}")
73
  except Exception as e:
74
  logger.warning(f"Prompt rendering failed: {e}")
75
  prompt = req.message
76
 
77
+ # Invoke LLM classifier
78
  try:
79
+ raw = llm(prompt, clf_output)
80
+ try:
81
+ # llm returned '{"result": "Label"}' → parse it
82
+ unwrapped = json.loads(raw).get("result", raw)
83
+ except Exception:
84
+ unwrapped = raw
85
+ result = unwrapped
86
  logger.debug(f"Classification result: {result}")
87
  except Exception as e:
88
+ logger.error(f"Classification error: {e}. Falling back to: {req.class_set[-1]}")
89
+ result = req.class_set[-1]
90
 
91
  return Resp(result=result)
classifier/requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  fastapi
2
  uvicorn
3
- outlines==0.2.1
 
4
 
5
  # Force CPU-only torch
6
  torch==2.0.1+cpu
 
1
  fastapi
2
  uvicorn
3
+ outlines
4
+ numpy
5
 
6
  # Force CPU-only torch
7
  torch==2.0.1+cpu