Kuberwastaken committed on
Commit 3b069b9 · 1 Parent(s): 9ee7507

First Attempt to shorten the model and make it workable on Spaces

Files changed (1)
  1. model/model.py +147 -190
model/model.py CHANGED
@@ -1,223 +1,180 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 from datetime import datetime

-def analyze_script(script):
-    # Starting the script analysis
-    print("\n=== Starting Analysis ===")
-    print(f"Time: {datetime.now()}") # Outputting the current timestamp
-    print("Loading model and tokenizer...")
-
-    try:
-        # Load the tokenizer and model, selecting the appropriate device (CPU or CUDA)
-        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B", use_fast=True)
-        device = "cuda" if torch.cuda.is_available() else "cpu" # Use CUDA if available, else use CPU
-        print(f"Using device: {device}")

-        model = AutoModelForCausalLM.from_pretrained(
-            "meta-llama/Llama-3.2-1B",
-            torch_dtype=torch.float16 if device == "cuda" else torch.float32, # Use 16-bit precision for CUDA, 32-bit for CPU
-            device_map="auto" # Automatically map model to available device
-        )
-        print("Model loaded successfully")

-        # Define trigger categories with their descriptions
         trigger_categories = {
             "Violence": {
                 "mapped_name": "Violence",
-                "description": (
-                    "Any act involving physical force or aggression intended to cause harm, injury, or death to a person, animal, or object. "
-                    "Includes direct physical confrontations (e.g., fights, beatings, or assaults), implied violence (e.g., very graphical threats or descriptions of injuries), "
-                    "or large-scale events like wars, riots, or violent protests."
-                )
             },
             "Death": {
                 "mapped_name": "Death References",
-                "description": (
-                    "Any mention, implication, or depiction of the loss of life, including direct deaths of characters, including mentions of deceased individuals, "
-                    "or abstract references to mortality (e.g., 'facing the end' or 'gone forever'). This also covers depictions of funerals, mourning, "
-                    "grieving, or any dialogue that centers around death, do not take metaphors into context that don't actually lead to death."
-                )
             },
-            "Substance Use": {
                 "mapped_name": "Substance Use",
-                "description": (
-                    "Any explicit or implied reference to the consumption, misuse, or abuse of drugs, alcohol, or other intoxicating substances. "
-                    "Includes scenes of drinking, smoking, or drug use, whether recreational or addictive. May also cover references to withdrawal symptoms, "
-                    "rehabilitation, or substance-related paraphernalia (e.g., needles, bottles, pipes)."
-                )
             },
             "Gore": {
                 "mapped_name": "Gore",
-                "description": (
-                    "Extremely detailed and graphic depictions of highly severe physical injuries, mutilation, or extreme bodily harm, often accompanied by descriptions of heavy blood, exposed organs, "
-                    "or dismemberment. This includes war scenes with severe casualties, horror scenarios involving grotesque creatures, or medical procedures depicted with excessive detail."
-                    "only answer yes if you're completely certain."
-                )
             },
-            "Vomit": {
-                "mapped_name": "Vomit",
-                "description": (
-                    "Any explicit reference to vomiting, whether directly described, implied, or depicted. This includes detailed sounds, visual descriptions, mentions of nausea explicitly leading to vomiting, or any aftermath involving vomit."
-                    "Respond 'yes' only if the scene unambiguously and clearly involves vomiting, with no room for doubt."
-                )
-            },
-            "Sexual Content": {
                 "mapped_name": "Sexual Content",
-                "description": (
-                    "Any depiction or mention of sexual activity, intimacy, or sexual behavior, ranging from implied scenes to explicit descriptions. "
-                    "This includes romantic encounters, physical descriptions of characters in a sexual context, sexual dialogue, or references to sexual themes (e.g., harassment, innuendos)."
-                )
             },
-            "Sexual Abuse": {
-                "mapped_name": "Sexual Abuse",
-                "description": (
-                    "Any form of non-consensual sexual act, behavior, or interaction, involving coercion, manipulation, or physical force. "
-                    "This includes incidents of sexual assault, molestation, exploitation, harassment, and any acts where an individual is subjected to sexual acts against their will or without their consent. "
-                    "It also covers discussions or depictions of the aftermath of such abuse, such as trauma, emotional distress, legal proceedings, or therapy. "
-                    "References to inappropriate sexual advances, groping, or any other form of sexual misconduct are also included, as well as the psychological and emotional impact on survivors. "
-                    "Scenes where individuals are placed in sexually compromising situations, even if not directly acted upon, may also fall under this category."
-                    "only answer yes if you're completely certain of it's presence."
-                )
-            },
-            "Self-Harm": {
                 "mapped_name": "Self-Harm",
-                "description": (
-                    "Any mention or depiction of behaviors where an individual intentionally causes harm to themselves. This includes cutting, burning, or other forms of physical injury, "
-                    "as well as suicidal ideation, suicide attempts, or discussions of self-destructive thoughts and actions. References to scars, bruises, or other lasting signs of self-harm are also included."
-                    "only answer yes if you're completely certain."
-                )
             },
-            "Gun Use": {
-                "mapped_name": "Gun Use",
-                "description": (
-                    "Any explicit or implied mention of firearms being handled, fired, or used in a threatening manner. This includes scenes of gun violence, references to shootings, "
-                    "gun-related accidents, or the presence of firearms in a tense or dangerous context (e.g., holstered weapons during an argument)."
-                )
-            },
-            "Animal Cruelty": {
-                "mapped_name": "Animal Cruelty",
-                "description": (
-                    "Any act of harm or abuse toward animals, whether intentional or accidental. This includes physical abuse (e.g., hitting, injuring, or killing animals), "
-                    "mental or emotional mistreatment (e.g., starvation, isolation), and scenes where animals are subjected to pain or suffering for human entertainment or experimentation."
-                    "Respond 'yes' only if the scene unambiguously and clearly involves Animal Cruelty, with no room for doubt"
-                )
-            },
-            "Mental Health Issues": {
                 "mapped_name": "Mental Health Issues",
-                "description": (
-                    "Any reference to mental health struggles, disorders, or psychological distress. This includes mentions of depression, anxiety, PTSD, bipolar disorder, schizophrenia, "
-                    "or other conditions. Scenes depicting destructive coping mechanisms are also included."
-                    "like a character expressing feelings of worthlessness, hopelessness, or detachment from reality."
-                )
             }
         }

-        print("\nProcessing text...") # Output indicating the text is being processed
-        chunk_size = 256 # Set the chunk size for text processing
-        overlap = 15 # Overlap between chunks for context preservation
-        script_chunks = [] # List to store script chunks
-
-        # Split the script into smaller chunks
-        for i in range(0, len(script), chunk_size - overlap):
-            chunk = script[i:i + chunk_size]
-            script_chunks.append(chunk)
-
-        print(f"Split into {len(script_chunks)} chunks with {overlap} token overlap") # Inform about the chunking
-
-        identified_triggers = {} # Dictionary to store the identified triggers
-
-        # Process each chunk of the script
-        for chunk_idx, chunk in enumerate(script_chunks, 1):
-            print(f"\n--- Processing Chunk {chunk_idx}/{len(script_chunks)} ---")
-            print(f"Chunk text (preview): {chunk[:50]}...") # Preview of the current chunk

-            # Check each category for triggers
             for category, info in trigger_categories.items():
-                mapped_name = info["mapped_name"]
-                description = info["description"]
-
-                print(f"\nAnalyzing for {mapped_name}...")
-                prompt = f"""
-                Check this text for any indication of {mapped_name} ({description}).
-                Be sensitive to subtle references or implications, make sure the text is not metaphorical.
-                Respond concisely with: YES, NO, or MAYBE.
-                Text: {chunk}
-                Answer:
-                """
-
-                print(f"Sending prompt to model...") # Indicate that prompt is being sent to the model
-                inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) # Tokenize the prompt
-                inputs = {k: v.to(device) for k, v in inputs.items()} # Send inputs to the chosen device
-
-                with torch.no_grad(): # Disable gradient calculation for inference
-                    print("Generating response...") # Indicate that the model is generating a response
-                    outputs = model.generate(
-                        **inputs,
-                        max_new_tokens=10, # Limit response length
-                        do_sample=True, # Enable sampling for more diverse output
-                        temperature=0.5, # Control randomness of the output
-                        top_p=0.9, # Use nucleus sampling
-                        pad_token_id=tokenizer.eos_token_id # Pad token ID
-                    )

-                response_text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip().upper() # Decode and format the response
-                first_word = response_text.split("\n")[-1].split()[0] if response_text else "NO" # Get the first word of the response
-                print(f"Model response for {mapped_name}: {first_word}")
-
-                # Update identified triggers based on model response
-                if first_word == "YES":
-                    print(f"Detected {mapped_name} in this chunk!") # Trigger detected
-                    identified_triggers[mapped_name] = identified_triggers.get(mapped_name, 0) + 1
-                elif first_word == "MAYBE":
-                    print(f"Possible {mapped_name} detected, marking for further review.") # Possible trigger detected
-                    identified_triggers[mapped_name] = identified_triggers.get(mapped_name, 0) + 0.5
-                else:
-                    print(f"No {mapped_name} detected in this chunk.") # No trigger detected
-
-        print("\n=== Analysis Complete ===") # Indicate that analysis is complete
-        print("Final Results:")
-        final_triggers = [] # List to store final triggers
-
-        # Filter and output the final trigger results
-        for mapped_name, count in identified_triggers.items():
-            if count > 0.5:
-                final_triggers.append(mapped_name)
-                print(f"- {mapped_name}: found in {count} chunks")
-
-        if not final_triggers:
-            print("No triggers detected") # No triggers detected
-            final_triggers = ["None"]

-        print("\nReturning results...")
-        return final_triggers # Return the list of detected triggers

-    except Exception as e:
-        # Handle errors and provide stack trace
-        print(f"\nERROR OCCURRED: {str(e)}")
-        print("Stack trace:")
-        import traceback
-        traceback.print_exc()
-        return {"error": str(e)}

 def get_detailed_analysis(script):
-    print("\n=== Starting Detailed Analysis ===")
-    triggers = analyze_script(script) # Call the analyze_script function
-
-    if isinstance(triggers, list) and triggers != ["None"]:
-        result = {
-            "detected_triggers": triggers,
-            "confidence": "High - Content detected",
-            "model": "Llama-3.2-1B",
-            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-        }
-    else:
-        result = {
-            "detected_triggers": ["None"],
-            "confidence": "High - No concerning content detected",
-            "model": "Llama-3.2-1B",
-            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-        }
-
-    print("\nFinal Result Dictionary:", result) # Output the final result dictionary
-    return result
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 from datetime import datetime
+import gc
+
+class ContentAnalyzer:
+    def __init__(self):
+        self.model_name = "meta-llama/Llama-3.2-1B"
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.tokenizer = None
+        self.model = None
+
+    def load_model(self):
+        """Load model with memory optimization"""
+        try:
+            print("Loading tokenizer...")
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
+
+            print(f"Loading model on {self.device}...")
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_name,
+                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+                low_cpu_mem_usage=True,
+                device_map="auto"
+            )
+            return True
+        except Exception as e:
+            print(f"Model loading error: {str(e)}")
+            return False
+
+    def cleanup(self):
+        """Clean up GPU memory"""
+        if self.device == "cuda":
+            torch.cuda.empty_cache()
+            gc.collect()
+
+    def analyze_chunk(self, chunk, category_info):
+        """Analyze a single chunk of text for a specific trigger"""
+        mapped_name = category_info["mapped_name"]
+        description = category_info["description"]
+
+        prompt = f"""Check this text for any indication of {mapped_name} ({description}).
+        Be sensitive to subtle references or implications, make sure the text is not metaphorical.
+        Respond concisely with: YES, NO, or MAYBE.
+        Text: {chunk}
+        Answer:"""
+
+        try:
+            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    **inputs,
+                    max_new_tokens=10,
+                    do_sample=True,
+                    temperature=0.5,
+                    top_p=0.9,
+                    pad_token_id=self.tokenizer.eos_token_id
+                )

+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip().upper()
+            first_word = response.split("\n")[-1].split()[0] if response else "NO"
+
+            score = 1 if first_word == "YES" else 0.5 if first_word == "MAYBE" else 0
+            return score, first_word
+
+        except Exception as e:
+            print(f"Chunk analysis error: {str(e)}")
+            return 0, "NO"

+    def analyze_text(self, text):
+        """Main analysis function"""
+        if not self.load_model():
+            return {"error": "Model loading failed"}

+        # Original trigger categories
         trigger_categories = {
             "Violence": {
                 "mapped_name": "Violence",
+                "description": "Any act involving physical force or aggression intended to cause harm, injury, or death."
             },
             "Death": {
                 "mapped_name": "Death References",
+                "description": "Any mention, implication, or depiction of the loss of life, including direct deaths or abstract references to mortality."
             },
+            "Substance_Use": {
                 "mapped_name": "Substance Use",
+                "description": "References to consumption, misuse, or abuse of drugs, alcohol, or other intoxicating substances."
             },
             "Gore": {
                 "mapped_name": "Gore",
+                "description": "Graphic depictions of severe physical injuries, mutilation, or extreme bodily harm."
             },
+            "Sexual_Content": {
                 "mapped_name": "Sexual Content",
+                "description": "Depictions or mentions of sexual activity, intimacy, or sexual behavior."
             },
+            "Self_Harm": {
                 "mapped_name": "Self-Harm",
+                "description": "Behaviors where an individual intentionally causes harm to themselves."
             },
+            "Mental_Health": {
                 "mapped_name": "Mental Health Issues",
+                "description": "References to mental health struggles, disorders, or psychological distress."
             }
         }

+        try:
+            # Optimize chunk processing
+            chunk_size = 200 # Reduced chunk size for better memory management
+            overlap = 10
+            chunks = []
+
+            # Create chunks with overlap
+            for i in range(0, len(text), chunk_size - overlap):
+                chunk = text[i:i + chunk_size]
+                chunks.append(chunk)

+            trigger_scores = {}
+            trigger_occurrences = {}
+
+            # Initialize tracking dictionaries
             for category, info in trigger_categories.items():
+                trigger_scores[info["mapped_name"]] = 0
+                trigger_occurrences[info["mapped_name"]] = []
+
+            # Process all chunks for all categories
+            for chunk_idx, chunk in enumerate(chunks):
+                print(f"\nProcessing chunk {chunk_idx + 1}/{len(chunks)}")
+                chunk_triggers = {}

+                for category, info in trigger_categories.items():
+                    score, response = self.analyze_chunk(chunk, info)
+
+                    if score > 0:
+                        mapped_name = info["mapped_name"]
+                        trigger_scores[mapped_name] += score
+                        trigger_occurrences[mapped_name].append({
+                            'chunk_idx': chunk_idx,
+                            'response': response,
+                            'score': score
+                        })
+                        print(f"Found {mapped_name} in chunk {chunk_idx + 1} (Response: {response})")
+
+                # Cleanup after processing each chunk
+                if self.device == "cuda":
+                    self.cleanup()
+
+            # Collect all triggers that meet the threshold
+            detected_triggers = []
+            for name, score in trigger_scores.items():
+                if score >= 0.5: # Threshold for considering a trigger as detected
+                    occurrences = len(trigger_occurrences[name])
+                    detected_triggers.append(name)
+                    print(f"\nTrigger '{name}' detected in {occurrences} chunks with total score {score}")
+
+            result = {
+                "detected_triggers": detected_triggers if detected_triggers else ["None"],
+                "confidence": "High - Content detected" if detected_triggers else "High - No concerning content detected",
+                "model": self.model_name,
+                "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+                "trigger_details": {
+                    name: {
+                        "total_score": trigger_scores[name],
+                        "occurrences": trigger_occurrences[name]
+                    } for name in detected_triggers if name != "None"
+                }
+            }

+            return result

+        except Exception as e:
+            return {"error": str(e)}
+        finally:
+            self.cleanup()

 def get_detailed_analysis(script):
+    analyzer = ContentAnalyzer()
+    return analyzer.analyze_text(script)
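
For reference, a minimal usage sketch of the refactored module (not part of the commit). It assumes the file is importable as model.model and that access to the gated meta-llama/Llama-3.2-1B checkpoint is already configured; the sample text and variable names below are illustrative only.

# Hypothetical usage sketch, not part of this commit.
# Assumes model/model.py is importable as model.model and that the gated
# meta-llama/Llama-3.2-1B weights are accessible in the environment.
from model.model import ContentAnalyzer, get_detailed_analysis

sample_script = (
    "INT. HOSPITAL ROOM - NIGHT. The heart monitor flatlines as the doctor "
    "quietly pulls the sheet over the patient."
)

# Module-level wrapper: builds a ContentAnalyzer and runs the full pipeline.
result = get_detailed_analysis(sample_script)
if "error" not in result:
    print(result["detected_triggers"])   # e.g. ["Death References"] if the model flags it
    print(result["trigger_details"])     # per-chunk scores and raw YES/MAYBE responses

# Equivalent explicit form using the class directly.
analyzer = ContentAnalyzer()
result = analyzer.analyze_text(sample_script)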