Kuberwastaken committed on
Commit 2b6c9d9 · Parent(s): 45cb4ff

Increased Model Efficiency

Files changed (1):
  model/analyzer.py (+68, -118)
model/analyzer.py CHANGED
@@ -1,13 +1,13 @@
 import os
-from transformers import AutoTokenizer, AutoModelForCausalLM
+import asyncio
 import torch
 from datetime import datetime
 import gradio as gr
 from typing import Dict, List, Union, Optional
 import logging
 import traceback
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
-# Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
@@ -19,38 +19,36 @@ class ContentAnalyzer:
         logger.info(f"Initialized analyzer with device: {self.device}")
 
     async def load_model(self, progress=None) -> None:
-        """Load the model and tokenizer with progress updates and detailed logging."""
+        """Load quantized model with optimized configuration."""
         try:
-            print("\n=== Starting Model Loading ===")
-            print(f"Time: {datetime.now()}")
-
             if progress:
                 progress(0.1, "Loading tokenizer...")
 
-            print("Loading tokenizer...")
+            # Quantization configuration
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_quant_type="nf4"
+            )
+
             self.tokenizer = AutoTokenizer.from_pretrained(
                 "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
                 use_fast=True
             )
 
             if progress:
-                progress(0.3, "Loading model...")
-
-            print(f"Loading model on {self.device}...")
+                progress(0.3, "Loading quantized model...")
+
             self.model = AutoModelForCausalLM.from_pretrained(
                 "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
-                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+                quantization_config=quantization_config,
                 device_map="auto"
             )
 
             if progress:
                 progress(0.5, "Model loaded successfully")
-
-            print("Model and tokenizer loaded successfully")
-            logger.info(f"Model loaded successfully on {self.device}")
+
         except Exception as e:
-            logger.error(f"Error loading model: {str(e)}")
-            print(f"\nERROR DURING MODEL LOADING: {str(e)}")
-            print("Stack trace:")
+            logger.error(f"Model loading error: {str(e)}")
             traceback.print_exc()
             raise
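
The switch from fp16 weights to 4-bit NF4 quantization via bitsandbytes is where the bulk of the memory saving comes from. A minimal standalone sketch of the same loading path, handy for checking the footprint outside the app (assumes a CUDA machine with bitsandbytes installed; get_memory_footprint is the stock transformers helper):

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Same 4-bit NF4 setup as the commit: weights stored in 4-bit,
    # matmuls computed in fp16.
    config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
    )

    model = AutoModelForCausalLM.from_pretrained(
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        quantization_config=config,
        device_map="auto",
    )

    # Expect roughly a quarter of the fp16 weight footprint.
    print(f"{model.get_memory_footprint() / 1e9:.2f} GB")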
@@ -57,25 +55,26 @@ class ContentAnalyzer:
 
-    def _chunk_text(self, text: str, chunk_size: int = 2048, overlap: int = 256) -> List[str]:
-        """Split text into overlapping chunks for processing."""
+    def _semantic_chunk_text(self, text: str, max_chunk_size: int = 4096) -> List[str]:
+        """Semantic chunking with dynamic sizing."""
         chunks = []
-        for i in range(0, len(text), chunk_size - overlap):
-            chunk = text[i:i + chunk_size]
-            chunks.append(chunk)
-        print(f"Split text into {len(chunks)} chunks with {overlap} token overlap")
+        current_chunk = ""
+        for sentence in text.split('.'):
+            if len(current_chunk) + len(sentence) < max_chunk_size:
+                current_chunk += sentence + '.'
+            else:
+                chunks.append(current_chunk.strip())
+                current_chunk = sentence + '.'
+
+        if current_chunk:
+            chunks.append(current_chunk.strip())
+
         return chunks
 
     async def analyze_chunk(
-        self,
-        chunk: str,
-        progress: Optional[gr.Progress] = None,
-        current_progress: float = 0,
-        progress_step: float = 0
+        self,
+        chunk: str,
+        progress: Optional[gr.Progress] = None
     ) -> List[str]:
-        """Analyze a single chunk of text for triggers with detailed logging."""
-        print(f"\n--- Processing Chunk ---")
-        print(f"Chunk text (preview): {chunk[:50]}...")
-
-        # Comprehensive trigger categories
+        """Optimized single-pass chunk analysis."""
         categories = [
             "Violence", "Death", "Substance Use", "Gore",
             "Vomit", "Sexual Content", "Sexual Abuse",
@@ -83,61 +82,36 @@ class ContentAnalyzer:
             "Mental Health Issues"
         ]
 
-        # Comprehensive prompt for single-pass analysis
-        prompt = f"""Comprehensive Content Sensitivity Analysis
-
-        Carefully analyze the following text for sensitive content categories:
-        {', '.join(categories)}
-
-        Detailed Requirements:
-        1. Thoroughly examine entire text chunk
-        2. Identify presence of ANY of these categories
-        3. Provide clear, objective assessment
-        4. Minimal subjective interpretation
-
-        TEXT CHUNK:
-        {chunk}
-
-        RESPONSE FORMAT:
-        - List categories DEFINITIVELY present
-        - Brief objective justification for each
-        - Strict YES/NO categorization"""
+        prompt = f"""Analyze this text for sensitive content.
+        Categories: {', '.join(categories)}
+        Identify ALL present categories.
+        Be precise and direct.
+        Chunk: {chunk}
+        Output Format: Comma-separated category names if present."""
 
         try:
-            print("Sending prompt to model...")
-            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
+            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
-            with torch.no_grad():
-                print("Generating response...")
-                outputs = self.model.generate(
-                    **inputs,
-                    max_new_tokens=256,
-                    do_sample=True,
-                    temperature=0.2,
-                    top_p=0.9,
-                    pad_token_id=self.tokenizer.eos_token_id
-                )
-
-            response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
-            print("Full Model Response:", response_text)
-
-            # Parse detected triggers
-            detected_triggers = []
-            for category in categories:
-                if category.upper() in response_text.upper():
-                    detected_triggers.append(category)
-
-            print(f"Detected triggers in chunk: {detected_triggers}")
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=128,
+                do_sample=True,
+                temperature=0.2,
+                top_p=0.9,
+                pad_token_id=self.tokenizer.eos_token_id
+            )
 
-            if progress:
-                current_progress += progress_step
-                progress(min(current_progress, 0.9), "Analyzing chunk...")
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+            # Extract detected categories
+            detected = [
+                cat for cat in categories
+                if cat.upper() in response.upper()
+            ]
 
-            return detected_triggers
+            return detected
 
         except Exception as e:
-            logger.error(f"Error analyzing chunk: {str(e)}")
-            print(f"Error during chunk analysis: {str(e)}")
-            traceback.print_exc()
+            logger.error(f"Chunk analysis error: {str(e)}")
             return []
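
One detail to watch in analyze_chunk: generate() returns the prompt tokens followed by the completion, and the prompt itself lists every category name, so matching cat.upper() against the full decoded sequence can flag all categories regardless of what the model actually answered. A common fix (a sketch, not part of this commit) is to decode only the newly generated tokens:

    # Slice off the echoed prompt before matching categories.
    prompt_len = inputs["input_ids"].shape[1]
    completion = self.tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    )
    detected = [cat for cat in categories if cat.upper() in completion.upper()]

Dropping the old explicit torch.no_grad() block is harmless, by the way: generate() already runs under no-grad in current transformers releases.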
@@ -144,50 +118,31 @@ class ContentAnalyzer:
 
     async def analyze_script(self, script: str, progress: Optional[gr.Progress] = None) -> List[str]:
-        """Analyze the entire script for triggers with progress updates."""
-        print("\n=== Starting Script Analysis ===")
-        print(f"Time: {datetime.now()}")
-
         if not self.model or not self.tokenizer:
             await self.load_model(progress)
 
-        chunks = self._chunk_text(script)
-        identified_triggers = set()
-        progress_step = 0.4 / len(chunks)
-        current_progress = 0.5  # Starting after model loading
-
-        for chunk_idx, chunk in enumerate(chunks, 1):
-            chunk_triggers = await self.analyze_chunk(
-                chunk,
-                progress,
-                current_progress,
-                progress_step
-            )
-            identified_triggers.update(chunk_triggers)
-
-        if progress:
-            progress(0.95, "Finalizing results...")
+        chunks = self._semantic_chunk_text(script)
+
+        # Concurrent chunk processing
+        tasks = [self.analyze_chunk(chunk) for chunk in chunks]
+        chunk_results = await asyncio.gather(*tasks)
 
-        final_triggers = list(identified_triggers)
-        print("\n=== Analysis Complete ===")
-        print("Final Results:", final_triggers)
+        # Flatten and deduplicate results
+        identified_triggers = set(
+            trigger
+            for chunk_triggers in chunk_results
+            for trigger in chunk_triggers
+        )
 
-        return final_triggers if final_triggers else ["None"]
+        return list(identified_triggers) or ["None"]
 
 async def analyze_content(
     script: str,
     progress: Optional[gr.Progress] = None
 ) -> Dict[str, Union[List[str], str]]:
-    """Main analysis function for the Gradio interface."""
-    print("\n=== Starting Content Analysis ===")
-    print(f"Time: {datetime.now()}")
-
     analyzer = ContentAnalyzer()
 
     try:
         triggers = await analyzer.analyze_script(script, progress)
-
-        if progress:
-            progress(1.0, "Analysis complete!")
 
         result = {
             "detected_triggers": triggers,
@@ -196,14 +151,10 @@ async def analyze_content
             "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         }
 
-        print("\nFinal Result Dictionary:", result)
         return result
 
     except Exception as e:
         logger.error(f"Analysis error: {str(e)}")
-        print(f"\nERROR OCCURRED: {str(e)}")
-        print("Stack trace:")
-        traceback.print_exc()
         return {
             "detected_triggers": ["Error occurred during analysis"],
             "confidence": "Error",
@@ -213,7 +164,6 @@ async def analyze_content
         }
 
 if __name__ == "__main__":
-    # Gradio interface
     iface = gr.Interface(
         fn=analyze_content,
         inputs=gr.Textbox(lines=8, label="Input Text"),
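
Since analyze_content is a coroutine, Gradio can run it as the interface fn directly, and it can also be exercised without the UI by wrapping it in asyncio.run (a usage sketch, not part of the commit):

    import asyncio

    sample = "He poured another drink as the fight spilled into the street."
    result = asyncio.run(analyze_content(sample))
    print(result["detected_triggers"])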
 