davanstrien HF Staff commited on
Commit
72c9292
·
1 Parent(s): 1ef69ae

refactor model data loading to handle optional 'param_count' column and improve logging

Browse files
Files changed (1) hide show
  1. main.py +27 -13
main.py CHANGED
@@ -193,24 +193,38 @@ def setup_database():
193
  )
194
 
195
  # Load model data
196
- model_df = pl.scan_parquet(
197
  "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
198
  )
199
- model_row_count = model_df.select(pl.len()).collect().item()
200
  logger.info(f"Row count of new model data: {model_row_count}")
201
 
202
  if model_collection.count() < model_row_count:
203
- model_df = model_df.select(
204
- [
205
- "modelId",
206
- "summary",
207
- "likes",
208
- "downloads",
209
- "last_modified",
210
- "param_count",
211
- ]
212
- )
213
- model_df = model_df.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  total_rows = len(model_df)
215
 
216
  for i in range(0, total_rows, BATCH_SIZE):
 
193
  )
194
 
195
  # Load model data
196
+ model_lazy_df = pl.scan_parquet(
197
  "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
198
  )
199
+ model_row_count = model_lazy_df.select(pl.len()).collect().item()
200
  logger.info(f"Row count of new model data: {model_row_count}")
201
 
202
  if model_collection.count() < model_row_count:
203
+ schema = model_lazy_df.schema
204
+ select_columns = [
205
+ "modelId",
206
+ "summary",
207
+ "likes",
208
+ "downloads",
209
+ "last_modified",
210
+ ]
211
+ if "param_count" in schema:
212
+ logger.info("Found 'param_count' column in model data schema.")
213
+ select_columns.append("param_count")
214
+ else:
215
+ logger.warning(
216
+ "'param_count' column not found in model data schema. Will add it with null values."
217
+ )
218
+
219
+ # Select specified columns and then collect
220
+ model_df = model_lazy_df.select(select_columns).collect()
221
+
222
+ # If param_count was not in the original schema, add it now to the collected DataFrame
223
+ if "param_count" not in model_df.columns:
224
+ model_df = model_df.with_columns(
225
+ pl.lit(None).cast(pl.Int64).alias("param_count")
226
+ )
227
+
228
  total_rows = len(model_df)
229
 
230
  for i in range(0, total_rows, BATCH_SIZE):