bambadij commited on
Commit
003aab4
·
1 Parent(s): f26e8b8
Files changed (1) hide show
  1. app.py +77 -48
app.py CHANGED
@@ -1,5 +1,5 @@
1
  #load package
2
- from fastapi import FastAPI,HTTPException
3
  from pydantic import BaseModel
4
  import torch
5
  from transformers import (
@@ -24,7 +24,7 @@ os.environ['HF_HOME'] = '/app/.cache'
24
  model = AutoModelForCausalLM.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True, device_map='auto')
25
  tokenizer = AutoTokenizer.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True)
26
 
27
-
28
  #Additional information
29
 
30
  Informations = """
@@ -43,57 +43,86 @@ app =FastAPI(
43
  logging.basicConfig(level=logging.INFO)
44
  logger =logging.getLogger(__name__)
45
 
46
- @app.get("/")
47
- async def home():
48
- return 'STN BIG DATA'
49
-
50
- # Charger le modèle et le tokenizer
51
- # model_name = "THUDM/longwriter-glm4-9b"
52
- # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
53
- # model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto")
54
-
55
- # Charger le modèle et le tokenizer
56
- model = AutoModelForCausalLM.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, device_map='auto')
57
- tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
58
 
59
- # Le reste de votre code reste le même
60
 
61
- # Prompt par défaut
62
- default_prompt = """Vous êtes un assistant expert en résumé de plaintes. Votre tâche est de résumer la plainte fournie de manière concise et professionnelle, en incluant les points clés suivants :
 
 
63
 
64
- 1. Le problème principal
65
- 2. Les détails pertinents
66
- 3. L'impact sur le plaignant
67
- 4. Toute action ou résolution demandée
68
 
69
- Résumez la plainte suivante en 3-4 phrases :
70
 
71
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- class ComplaintInput(BaseModel):
74
- text: str
75
-
76
- @app.post("/summarize_complaint")
77
- async def summarize_complaint(input: ComplaintInput):
78
- try:
79
- full_prompt = default_prompt + input.text
80
- inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
81
-
82
- with torch.no_grad():
83
- outputs = model.generate(
84
- **inputs,
85
- max_new_tokens=150,
86
- num_return_sequences=1,
87
- no_repeat_ngram_size=2,
88
- temperature=0.7
89
- )
90
-
91
- summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
92
- # Enlever le prompt initial de la sortie
93
- summary = summary.replace(full_prompt, "").strip()
94
- return {"summary": summary}
95
- except Exception as e:
96
- raise HTTPException(status_code=500, detail=str(e))
97
 
98
- if __name__ == "__main__":
99
- uvicorn.run("app:app",reload=True)
 
1
  #load package
2
+ from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  import torch
5
  from transformers import (
 
24
  model = AutoModelForCausalLM.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True, device_map='auto')
25
  tokenizer = AutoTokenizer.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True)
26
 
27
+
28
  #Additional information
29
 
30
  Informations = """
 
43
  logging.basicConfig(level=logging.INFO)
44
  logger =logging.getLogger(__name__)
45
 
46
+ class StopOnTokens(StoppingCriteria):
47
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
48
+ stop_ids = model.config.eos_token_id
49
+ for stop_id in stop_ids:
50
+ if input_ids[0][-1] == stop_id:
51
+ return True
52
+ return False
 
 
 
 
 
53
 
54
+ default_prompt = """Bonjour,
55
 
56
+ En tant qu’expert en gestion des plaintes réseaux, rédige un descriptif clair de la plainte ci-dessous. Résume la situation en 4 ou 5 phrases concises, en mettant l'accent sur :
57
+ 1. **Informations Client** : Indique des détails pertinents sur le client.
58
+ 2. **Dates et Délais** : Mentionne les dates clés et les délais (prise en charge, résolution, etc.).
59
+ 3. **Contexte et Détails** : Inclut les éléments essentiels de la plainte (titre, détails, états d’avancement, qualification, fichiers joints).
60
 
61
+ Ajoute une recommandation importante pour éviter le mécontentement du client, par exemple, en cas de service non fourni malgré le paiement. Adapte le ton pour qu'il soit humain et engageant.
 
 
 
62
 
63
+ Merci !
64
 
65
  """
66
+ class PredictionRequest(BaseModel):
67
+ history: List[Tuple[str, str]] = []
68
+ prompt: str = default_prompt
69
+ max_length: int = 10240
70
+ top_p: float = 0.8
71
+ temperature: float = 0.6
72
+ @app.post("/predict/")
73
+ async def predict(request: PredictionRequest):
74
+ history = request.history
75
+ prompt = request.prompt
76
+ max_length = request.max_length
77
+ top_p = request.top_p
78
+ temperature = request.temperature
79
+
80
+ stop = StopOnTokens()
81
+ messages = []
82
+ if prompt:
83
+ messages.append({"role": "system", "content": prompt})
84
+ for idx, (user_msg, model_msg) in enumerate(history):
85
+ if prompt and idx == 0:
86
+ continue
87
+ if idx == len(history) - 1 and not model_msg:
88
+ query = user_msg
89
+ break
90
+ if user_msg:
91
+ messages.append({"role": "user", "content": user_msg})
92
+ if model_msg:
93
+ messages.append({"role": "assistant", "content": model_msg})
94
+
95
+ model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
96
+ next(model.parameters()).device)
97
+ streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
98
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
99
+ tokenizer.get_command("<|observation|>")]
100
+ generate_kwargs = {
101
+ "input_ids": model_inputs,
102
+ "streamer": streamer,
103
+ "max_new_tokens": max_length,
104
+ "do_sample": True,
105
+ "top_p": top_p,
106
+ "temperature": temperature,
107
+ "stopping_criteria": StoppingCriteriaList([stop]),
108
+ "repetition_penalty": 1,
109
+ "eos_token_id": eos_token_id,
110
+ }
111
+
112
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
113
+ t.start()
114
+
115
+ generated_text = ""
116
+ for new_token in streamer:
117
+ if new_token and '<|user|>' in new_token:
118
+ new_token = new_token.split('<|user|>')[0]
119
+ if new_token:
120
+ generated_text += new_token
121
+ history[-1][1] = generated_text
122
+
123
+ return {"history": history}
124
+ if __name__ == "__main__":
125
+ uvicorn.run("app:app",reload=True)
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+