frankjosh commited on
Commit
5c717fb
·
verified ·
1 Parent(s): 0c9f09c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -289
app.py CHANGED
@@ -7,11 +7,9 @@ import numpy as np
7
  from sklearn.metrics.pairwise import cosine_similarity
8
  from transformers import AutoTokenizer, AutoModel
9
  import torch
10
- from tqdm import tqdm
11
- from datasets import load_dataset
12
  from datetime import datetime
13
  from typing import List, Dict, Any
14
- from torch.utils.data import DataLoader, Dataset
15
  from functools import partial
16
 
17
  # Configure GPU if available
@@ -20,325 +18,149 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
  # Initialize session state
21
  if 'history' not in st.session_state:
22
  st.session_state.history = []
 
23
  if 'feedback' not in st.session_state:
24
  st.session_state.feedback = {}
25
 
26
  # Define subset size
27
- SUBSET_SIZE = 1000 # Starting with 500 items for quick testing
28
-
29
- class TextDataset(Dataset):
30
- def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
31
- self.texts = texts
32
- self.tokenizer = tokenizer
33
- self.max_length = max_length
34
-
35
- def __len__(self):
36
- return len(self.texts)
37
-
38
- def __getitem__(self, idx):
39
- return self.tokenizer(
40
- self.texts[idx],
41
- padding='max_length',
42
- truncation=True,
43
- max_length=self.max_length,
44
- return_tensors="pt"
45
- )
46
 
47
- def generate_case_study(row: Dict[str, Any]) -> str:
48
- """Generate a detailed case study for a repository using available metadata"""
49
- # Extract relevant information from the row
50
- summary = row.get('summary', '').strip()
51
- docstring = row.get('docstring', '').strip()
52
- repo_name = row.get('repo', '').strip()
53
-
54
- # Generate a more detailed overview using available information
55
- overview = summary if summary else "This repository provides a software implementation"
56
- if docstring:
57
- # Extract the first paragraph of the docstring for additional context
58
- first_para = docstring.split('\n\n')[0].strip()
59
- overview = f"{overview}. {first_para}"
60
-
61
- # Analyze the repository path to infer technology stack
62
- path_components = row.get('path', '').lower().split('/')
63
- tech_stack = []
64
-
65
- # Common technology indicators in paths
66
- if any('python' in comp for comp in path_components):
67
- tech_stack.append("Python")
68
- if any('tensorflow' in comp or 'tf' in comp for comp in path_components):
69
- tech_stack.append("TensorFlow")
70
- if any('pytorch' in comp for comp in path_components):
71
- tech_stack.append("PyTorch")
72
- if any('react' in comp for comp in path_components):
73
- tech_stack.append("React")
74
-
75
- tech_stack_str = ", ".join(tech_stack) if tech_stack else "various technologies"
76
-
77
- case_study = f"""
78
- ### Overview
79
- {overview}
80
-
81
- ### Technical Implementation
82
- This project is built using {tech_stack_str}. The implementation focuses on providing a robust and maintainable solution for {summary.lower() if summary else 'the specified requirements'}.
83
-
84
- ### Key Features
85
- - Primary functionality: {summary if summary else 'Implementation of core project requirements'}
86
- - Complete documentation and code examples
87
- - Well-structured implementation following best practices
88
- - Modular design for easy integration and customization
89
-
90
- ### Use Cases
91
- This repository is particularly valuable for:
92
- - Developers implementing similar functionality in their projects
93
- - Teams looking for reference implementations and best practices
94
- - Projects requiring similar technical capabilities
95
- - Learning and educational purposes in related technical domains
96
-
97
- ### Integration Considerations
98
- The repository can be integrated into existing projects, with consideration for:
99
- - Compatibility with existing technology stacks
100
- - Required dependencies and prerequisites
101
- - Potential customization needs
102
- - Performance and scalability requirements
103
  """
104
- return case_study
105
-
106
- def display_recommendations(recommendations: pd.DataFrame):
107
- """Display recommendations in a list format with all details"""
108
- st.markdown("### 🎯 Top Recommendations")
109
-
110
- # Create a list of recommendations
111
- for idx, row in recommendations.iterrows():
112
- with st.container():
113
- # Header with repository name and match score
114
- col1, col2 = st.columns([3, 1])
115
- with col1:
116
- st.markdown(f"### {idx + 1}. {row['repo']}")
117
- with col2:
118
- st.metric("Match Score", f"{row['similarity']:.2%}")
119
-
120
- # Repository details
121
- st.markdown(f"**URL:** [View Repository]({row['url']})")
122
- st.markdown(f"**Path:** `{row['path']}`")
123
-
124
- # Feedback buttons
125
- col1, col2, col3 = st.columns([1, 1, 4])
126
- with col1:
127
- if st.button("👍", key=f"like_{idx}"):
128
- st.session_state.feedback[row['repo']] = st.session_state.feedback.get(row['repo'], {'likes': 0, 'dislikes': 0})
129
- st.session_state.feedback[row['repo']]['likes'] += 1
130
- st.success("Thanks for your feedback!")
131
- with col2:
132
- if st.button("👎", key=f"dislike_{idx}"):
133
- st.session_state.feedback[row['repo']] = st.session_state.feedback.get(row['repo'], {'likes': 0, 'dislikes': 0})
134
- st.session_state.feedback[row['repo']]['dislikes'] += 1
135
- st.success("Thanks for your feedback!")
136
-
137
- # Documentation and case study in tabs
138
- tab1, tab2 = st.tabs(["📚 Documentation", "📑 Case Study"])
139
- with tab1:
140
- if row['docstring']:
141
- st.markdown(row['docstring'])
142
- else:
143
- st.info("No documentation available")
144
-
145
- with tab2:
146
- st.markdown(generate_case_study(row))
147
-
148
- st.markdown("---")
149
 
150
  @st.cache_resource
151
- def load_data_and_model():
152
- """Load the dataset and model with optimized memory usage"""
153
- try:
154
- # Load dataset
155
- dataset = load_dataset("frankjosh/filtered_dataset")
156
- data = pd.DataFrame(dataset['train'])
157
-
158
- # Take a random subset
159
- data = data.sample(n=min(SUBSET_SIZE, len(data)), random_state=42).reset_index(drop=True)
160
-
161
- # Combine text fields
162
- data['text'] = data['docstring'].fillna('') + ' ' + data['summary'].fillna('')
163
-
164
- # Load model and tokenizer
165
- model_name = "Salesforce/codet5-small"
166
- tokenizer = AutoTokenizer.from_pretrained(model_name)
167
- model = AutoModel.from_pretrained(model_name)
168
-
169
- if torch.cuda.is_available():
170
- model = model.to(device)
171
-
172
- model.eval()
173
- return data, tokenizer, model
174
-
175
- except Exception as e:
176
- st.error(f"Error in initialization: {str(e)}")
177
- st.stop()
178
-
179
- def collate_fn(batch, pad_token_id):
180
- max_length = max(inputs['input_ids'].shape[1] for inputs in batch)
181
- input_ids = []
182
- attention_mask = []
183
-
184
- for inputs in batch:
185
- input_ids.append(torch.nn.functional.pad(
186
- inputs['input_ids'].squeeze(),
187
- (0, max_length - inputs['input_ids'].shape[1]),
188
- value=pad_token_id
189
- ))
190
- attention_mask.append(torch.nn.functional.pad(
191
- inputs['attention_mask'].squeeze(),
192
- (0, max_length - inputs['attention_mask'].shape[1]),
193
- value=0
194
- ))
195
-
196
- return {
197
- 'input_ids': torch.stack(input_ids),
198
- 'attention_mask': torch.stack(attention_mask)
199
- }
200
 
201
- def generate_embeddings_batch(model, batch, device):
202
- """Generate embeddings for a batch of inputs"""
203
- with torch.no_grad():
204
- batch = {k: v.to(device) for k, v in batch.items()}
205
- outputs = model.encoder(**batch)
206
- embeddings = outputs.last_hidden_state.mean(dim=1)
207
- return embeddings.cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
- def precompute_embeddings(data: pd.DataFrame, model, tokenizer, batch_size: int = 16):
210
- """Precompute embeddings with batching and progress tracking"""
211
  dataset = TextDataset(data['text'].tolist(), tokenizer)
212
  dataloader = DataLoader(
213
- dataset,
214
- batch_size=batch_size,
215
- shuffle=False,
216
- collate_fn=partial(collate_fn, pad_token_id=tokenizer.pad_token_id),
217
- num_workers=2,
218
- pin_memory=True
219
  )
220
-
221
  embeddings = []
222
- total_batches = len(dataloader)
223
-
224
- # Create a progress bar
225
- progress_bar = st.progress(0)
226
- status_text = st.empty()
227
-
228
- start_time = datetime.now()
229
-
230
- for i, batch in enumerate(dataloader):
231
- # Generate embeddings for batch
232
  batch_embeddings = generate_embeddings_batch(model, batch, device)
233
  embeddings.extend(batch_embeddings)
234
-
235
- # Update progress
236
- progress = (i + 1) / total_batches
237
- progress_bar.progress(progress)
238
-
239
- # Calculate and display ETA
240
- elapsed_time = (datetime.now() - start_time).total_seconds()
241
- eta = (elapsed_time / (i + 1)) * (total_batches - (i + 1))
242
- status_text.text(f"Processing batch {i+1}/{total_batches}. ETA: {int(eta)} seconds")
243
-
244
- progress_bar.empty()
245
- status_text.empty()
246
-
247
- # Add embeddings to dataframe
248
  data['embedding'] = embeddings
249
  return data
250
 
251
  @torch.no_grad()
252
  def generate_query_embedding(model, tokenizer, query: str) -> np.ndarray:
253
- """Generate embedding for a single query"""
 
 
254
  inputs = tokenizer(
255
- query,
256
- return_tensors="pt",
257
- padding=True,
258
- truncation=True,
259
- max_length=512
260
  ).to(device)
261
-
262
  outputs = model.encoder(**inputs)
263
- embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
264
- return embedding.squeeze()
265
 
266
- def find_similar_repos(query_embedding: np.ndarray, data: pd.DataFrame, top_n: int = 5) -> pd.DataFrame:
267
- """Find similar repositories using vectorized operations"""
 
 
268
  similarities = cosine_similarity([query_embedding], np.stack(data['embedding'].values))[0]
269
  data['similarity'] = similarities
270
  return data.nlargest(top_n, 'similarity')
271
 
272
- # Load resources
273
- data, tokenizer, model = load_data_and_model()
274
-
275
- # Add info about subset size
276
- st.info(f"Running with a subset of {SUBSET_SIZE} repositories for testing purposes.")
277
-
278
- # Precompute embeddings for the subset
279
- data = precompute_embeddings(data, model, tokenizer)
 
280
 
281
- # Main App Interface
282
  st.title("Repository Recommender System 🚀")
283
- st.caption("Testing Version - Running on subset of data")
284
 
285
- # Main interface
 
 
 
 
 
286
  user_query = st.text_area(
287
- "Describe your project:",
288
- height=150,
289
- placeholder="Example: I need a machine learning project for customer churn prediction..."
290
  )
291
 
292
- # Search button and filters
293
- col1, col2 = st.columns([2, 1])
294
- with col1:
295
- search_button = st.button("🔍 Search Repositories", type="primary")
296
- with col2:
297
- top_n = st.selectbox("Number of results:", [3, 5, 10], index=1)
298
-
299
- if search_button and user_query.strip():
300
- with st.spinner("Finding relevant repositories..."):
301
- # Generate query embedding and get recommendations
302
- query_embedding = generate_query_embedding(model, tokenizer, user_query)
303
- recommendations = find_similar_repos(query_embedding, data, top_n)
304
-
305
- # Save to history
306
- st.session_state.history.append({
307
- 'query': user_query,
308
- 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
309
- 'results': recommendations['repo'].tolist()
310
- })
311
-
312
- # Display recommendations using the new function
313
- display_recommendations(recommendations)
314
-
315
- # Sidebar for History and Stats
316
- with st.sidebar:
317
- st.header("📊 Search History")
318
- if st.session_state.history:
319
- for idx, item in enumerate(reversed(st.session_state.history[-5:])):
320
- st.markdown(f"**Search {len(st.session_state.history)-idx}**")
321
- st.markdown(f"Query: _{item['query'][:30]}..._")
322
- st.caption(f"Time: {item['timestamp']}")
323
- st.caption(f"Results: {len(item['results'])} repositories")
324
- if st.button("Rerun this search", key=f"rerun_{idx}"):
325
- st.session_state.rerun_query = item['query']
326
- st.markdown("---")
327
  else:
328
- st.write("No search history yet")
329
-
330
- st.header("📈 Usage Statistics")
331
- st.write(f"Total Searches: {len(st.session_state.history)}")
332
- if st.session_state.feedback:
333
- feedback_df = pd.DataFrame(st.session_state.feedback).T
334
- feedback_df['Total'] = feedback_df['likes'] + feedback_df['dislikes']
335
- st.bar_chart(feedback_df[['likes', 'dislikes']])
336
-
337
- # Footer
338
- st.markdown("---")
339
- st.markdown(
340
- """
341
- Made with 🤖 using CodeT5 and Streamlit |
342
-
343
- """
344
- )
 
7
  from sklearn.metrics.pairwise import cosine_similarity
8
  from transformers import AutoTokenizer, AutoModel
9
  import torch
10
+ from torch.utils.data import DataLoader, Dataset
 
11
  from datetime import datetime
12
  from typing import List, Dict, Any
 
13
  from functools import partial
14
 
15
  # Configure GPU if available
 
18
  # Initialize session state
19
  if 'history' not in st.session_state:
20
  st.session_state.history = []
21
+
22
  if 'feedback' not in st.session_state:
23
  st.session_state.feedback = {}
24
 
25
  # Define subset size
26
+ SUBSET_SIZE = 1000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ # Caching key resources: Model, Tokenizer, and Precomputed Embeddings
29
+ @st.cache_resource
30
+ def load_model_and_tokenizer():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  """
32
+ Load the pre-trained model and tokenizer using Hugging Face Transformers.
33
+ Cached to ensure it loads only once.
34
+ """
35
+ model_name = "Salesforce/codet5-small"
36
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
37
+ model = AutoModel.from_pretrained(model_name).to(device)
38
+ model.eval()
39
+ return tokenizer, model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  @st.cache_resource
42
+ def load_data():
43
+ """
44
+ Load and sample the dataset from Hugging Face.
45
+ Returns a DataFrame with a fixed subset of repositories.
46
+ """
47
+ dataset = load_dataset("frankjosh/filtered_dataset")
48
+ data = pd.DataFrame(dataset['train'])
49
+ data = data.sample(n=min(SUBSET_SIZE, len(data)), random_state=42).reset_index(drop=True)
50
+ return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ @st.cache_resource
53
+ def precompute_embeddings(data: pd.DataFrame, tokenizer, model, batch_size=16):
54
+ """
55
+ Precompute embeddings for repository metadata to optimize query performance.
56
+ """
57
+ class TextDataset(Dataset):
58
+ def __init__(self, texts: List[str], tokenizer, max_length=512):
59
+ self.texts = texts
60
+ self.tokenizer = tokenizer
61
+ self.max_length = max_length
62
+
63
+ def __len__(self):
64
+ return len(self.texts)
65
+
66
+ def __getitem__(self, idx):
67
+ return self.tokenizer(
68
+ self.texts[idx],
69
+ padding='max_length',
70
+ truncation=True,
71
+ max_length=self.max_length,
72
+ return_tensors="pt"
73
+ )
74
+
75
+ def collate_fn(batch, pad_token_id):
76
+ max_length = max(inputs['input_ids'].shape[1] for inputs in batch)
77
+ input_ids, attention_mask = [], []
78
+ for inputs in batch:
79
+ input_ids.append(torch.nn.functional.pad(
80
+ inputs['input_ids'].squeeze(),
81
+ (0, max_length - inputs['input_ids'].shape[1]),
82
+ value=pad_token_id
83
+ ))
84
+ attention_mask.append(torch.nn.functional.pad(
85
+ inputs['attention_mask'].squeeze(),
86
+ (0, max_length - inputs['attention_mask'].shape[1]),
87
+ value=0
88
+ ))
89
+ return {
90
+ 'input_ids': torch.stack(input_ids),
91
+ 'attention_mask': torch.stack(attention_mask)
92
+ }
93
+
94
+ def generate_embeddings_batch(model, batch, device):
95
+ with torch.no_grad():
96
+ batch = {k: v.to(device) for k, v in batch.items()}
97
+ outputs = model.encoder(**batch)
98
+ return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
99
 
 
 
100
  dataset = TextDataset(data['text'].tolist(), tokenizer)
101
  dataloader = DataLoader(
102
+ dataset, batch_size=batch_size, shuffle=False,
103
+ collate_fn=partial(collate_fn, pad_token_id=tokenizer.pad_token_id)
 
 
 
 
104
  )
105
+
106
  embeddings = []
107
+ for batch in dataloader:
 
 
 
 
 
 
 
 
 
108
  batch_embeddings = generate_embeddings_batch(model, batch, device)
109
  embeddings.extend(batch_embeddings)
110
+
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  data['embedding'] = embeddings
112
  return data
113
 
114
  @torch.no_grad()
115
  def generate_query_embedding(model, tokenizer, query: str) -> np.ndarray:
116
+ """
117
+ Generate embedding for a user query using the pre-trained model.
118
+ """
119
  inputs = tokenizer(
120
+ query, return_tensors="pt", padding=True,
121
+ truncation=True, max_length=512
 
 
 
122
  ).to(device)
 
123
  outputs = model.encoder(**inputs)
124
+ return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
 
125
 
126
+ def find_similar_repos(query_embedding: np.ndarray, data: pd.DataFrame, top_n=5) -> pd.DataFrame:
127
+ """
128
+ Compute cosine similarity and return the top N most similar repositories.
129
+ """
130
  similarities = cosine_similarity([query_embedding], np.stack(data['embedding'].values))[0]
131
  data['similarity'] = similarities
132
  return data.nlargest(top_n, 'similarity')
133
 
134
+ def display_recommendations(recommendations: pd.DataFrame):
135
+ """
136
+ Display the recommended repositories in the Streamlit app interface.
137
+ """
138
+ st.markdown("### 🎯 Top Recommendations")
139
+ for idx, row in recommendations.iterrows():
140
+ st.markdown(f"### {idx + 1}. {row['repo']}")
141
+ st.metric("Match Score", f"{row['similarity']:.2%}")
142
+ st.markdown(f"[View Repository]({row['url']})")
143
 
144
+ # Main workflow
145
  st.title("Repository Recommender System 🚀")
146
+ st.caption("Find repositories based on your project description.")
147
 
148
+ # Load resources
149
+ tokenizer, model = load_model_and_tokenizer()
150
+ data = load_data()
151
+ data = precompute_embeddings(data, tokenizer, model)
152
+
153
+ # User input
154
  user_query = st.text_area(
155
+ "Describe your project:", height=150,
156
+ placeholder="Example: A machine learning project for customer churn prediction..."
 
157
  )
158
 
159
+ if st.button("🔍 Search Repositories"):
160
+ if user_query.strip():
161
+ with st.spinner("Finding relevant repositories..."):
162
+ query_embedding = generate_query_embedding(model, tokenizer, user_query)
163
+ recommendations = find_similar_repos(query_embedding, data)
164
+ display_recommendations(recommendations)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  else:
166
+ st.error("Please provide a project description.")