bambadij commited on
Commit
bbf7982
·
1 Parent(s): 003aab4
Files changed (1) hide show
  1. app.py +20 -71
app.py CHANGED
@@ -4,10 +4,7 @@ from pydantic import BaseModel
4
  import torch
5
  from transformers import (
6
  AutoModelForCausalLM,
7
- AutoTokenizer,
8
- StoppingCriteria,
9
- StoppingCriteriaList,
10
- TextIteratorStreamer
11
  )
12
  from typing import List, Tuple
13
  from threading import Thread
@@ -38,19 +35,6 @@ app =FastAPI(
38
  title='Text Summary',
39
  description =Informations
40
  )
41
-
42
- #class to define the input text
43
- logging.basicConfig(level=logging.INFO)
44
- logger =logging.getLogger(__name__)
45
-
46
- class StopOnTokens(StoppingCriteria):
47
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
48
- stop_ids = model.config.eos_token_id
49
- for stop_id in stop_ids:
50
- if input_ids[0][-1] == stop_id:
51
- return True
52
- return False
53
-
54
  default_prompt = """Bonjour,
55
 
56
  En tant qu’expert en gestion des plaintes réseaux, rédige un descriptif clair de la plainte ci-dessous. Résume la situation en 4 ou 5 phrases concises, en mettant l'accent sur :
@@ -63,64 +47,29 @@ Ajoute une recommandation importante pour éviter le mécontentement du client,
63
  Merci !
64
 
65
  """
 
 
 
 
66
  class PredictionRequest(BaseModel):
67
- history: List[Tuple[str, str]] = []
68
- prompt: str = default_prompt
69
- max_length: int = 10240
70
- top_p: float = 0.8
71
- temperature: float = 0.6
72
  @app.post("/predict/")
73
  async def predict(request: PredictionRequest):
74
- history = request.history
75
- prompt = request.prompt
76
- max_length = request.max_length
77
- top_p = request.top_p
78
- temperature = request.temperature
79
-
80
- stop = StopOnTokens()
81
- messages = []
82
- if prompt:
83
- messages.append({"role": "system", "content": prompt})
84
- for idx, (user_msg, model_msg) in enumerate(history):
85
- if prompt and idx == 0:
86
- continue
87
- if idx == len(history) - 1 and not model_msg:
88
- query = user_msg
89
- break
90
- if user_msg:
91
- messages.append({"role": "user", "content": user_msg})
92
- if model_msg:
93
- messages.append({"role": "assistant", "content": model_msg})
94
-
95
- model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
96
- next(model.parameters()).device)
97
- streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
98
- eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
99
- tokenizer.get_command("<|observation|>")]
100
- generate_kwargs = {
101
- "input_ids": model_inputs,
102
- "streamer": streamer,
103
- "max_new_tokens": max_length,
104
- "do_sample": True,
105
- "top_p": top_p,
106
- "temperature": temperature,
107
- "stopping_criteria": StoppingCriteriaList([stop]),
108
- "repetition_penalty": 1,
109
- "eos_token_id": eos_token_id,
110
- }
111
-
112
- t = Thread(target=model.generate, kwargs=generate_kwargs)
113
- t.start()
114
-
115
- generated_text = ""
116
- for new_token in streamer:
117
- if new_token and '<|user|>' in new_token:
118
- new_token = new_token.split('<|user|>')[0]
119
- if new_token:
120
- generated_text += new_token
121
- history[-1][1] = generated_text
122
 
123
- return {"history": history}
124
  if __name__ == "__main__":
125
  uvicorn.run("app:app",reload=True)
126
 
 
4
  import torch
5
  from transformers import (
6
  AutoModelForCausalLM,
7
+ AutoTokenizer
 
 
 
8
  )
9
  from typing import List, Tuple
10
  from threading import Thread
 
35
  title='Text Summary',
36
  description =Informations
37
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  default_prompt = """Bonjour,
39
 
40
  En tant qu’expert en gestion des plaintes réseaux, rédige un descriptif clair de la plainte ci-dessous. Résume la situation en 4 ou 5 phrases concises, en mettant l'accent sur :
 
47
  Merci !
48
 
49
  """
50
+ #class to define the input text
51
+ logging.basicConfig(level=logging.INFO)
52
+ logger =logging.getLogger(__name__)
53
+ # Définir le modèle de requête
54
  class PredictionRequest(BaseModel):
55
+ text: str = None # Texte personnalisé ajouté par l'utilisateur
56
+ max_length: int = 512 # Limite la longueur maximale du texte généré
57
+
 
 
58
  @app.post("/predict/")
59
  async def predict(request: PredictionRequest):
60
+ # Construire le prompt final
61
+ if request.text:
62
+ prompt = default_prompt + "\n\n" + request.text
63
+ else:
64
+ prompt = default_prompt
65
+
66
+ # Générer le texte à partir du prompt
67
+ inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
68
+ outputs = model.generate(inputs, max_length=request.max_length, do_sample=True)
69
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
70
+
71
+ return {"generated_text": generated_text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
 
73
  if __name__ == "__main__":
74
  uvicorn.run("app:app",reload=True)
75