1 | """ |
---|
2 | Define the csv model base classes |
---|
3 | """ |
---|
4 | import copy |
---|
5 | |
---|
6 | import csv |
---|
7 | from django.db.models.base import Model |
---|
8 | from adaptor.fields import Field, IgnoredField, ComposedKeyField, XMLRootField |
---|
9 | from adaptor.exceptions import ForeignKeyFieldError, FieldValueMissing |
---|
10 | |
---|
11 | |
---|
class ImproperlyConfigured(Exception):
    """Signal that a model or importer configuration value is missing or invalid."""
    pass
---|
16 | |
---|
17 | |
---|
class CsvException(Exception):
    """Signal a problem detected in the imported file itself."""
    pass


class CsvDataException(CsvException):
    """Signal a value that does not match what the model expects.

    ``self.line`` is ``line + 1``: callers such as CsvImporter pass the
    0-based index of the offending line, and the message reports it 1-based.
    """

    def __init__(self, line, error=None, field_error=None):
        self.line = line + 1
        self.error = error
        self.field_error = field_error
        # A general ``error`` takes precedence over ``field_error``.
        detail = error or field_error
        message = u"Line %d: %s" % (self.line, detail)
        super(CsvDataException, self).__init__(message)


class CsvFieldDataException(CsvDataException):
    """CsvDataException that also records the model and value of the failing field."""

    def __init__(self, line, field_error, model, value):
        self.model = model
        self.value = value
        super(CsvFieldDataException, self).__init__(line, field_error=field_error)
---|
41 | |
---|
42 | |
---|
class SkipRow(Exception):
    """Raised while building a model to make the importer ignore the current line."""
---|
45 | |
---|
46 | |
---|
class BaseModel(object):
    """Behaviour shared by the CSV and XML models.

    Subclasses declare ``Field`` instances as class attributes and may define
    an inner ``Meta`` class. The optional ``Meta`` attributes read here are
    ``dbModel``, ``delimiter``, ``has_header``, ``update``,
    ``silent_failure`` and ``raise_exception``.
    """

    def __init__(self, data, delimiter=None):
        # ``data`` and ``delimiter`` are consumed by the subclasses; the base
        # class only prepares the per-instance fields and error bookkeeping.
        self.cls = self.__class__
        self.attrs = self.get_fields()
        self.errors = []
        # Errors are collected in ``self.errors`` instead of raised only when
        # Meta.raise_exception is explicitly set to a false value.
        self.dont_raise_exception = (hasattr(self.cls, "Meta")
                                     and hasattr(self.cls.Meta, "raise_exception")
                                     and not self.cls.Meta.raise_exception)

    def is_valid(self):
        """Return True when no error was recorded in ``self.errors``."""
        return not self.errors

    @classmethod
    def get_fields(cls):
        """Return the ``(name, field)`` pairs declared on the class, sorted by ``position``.

        Class attributes are gathered along the MRO from the most generic
        class down to ``cls``, so a field redefined in a subclass overrides
        the inherited one. Each field is shallow-copied so that per-instance
        changes (``fieldname``, ``position``) do not leak between instances
        of the same class.
        """
        all_cls_dict = {}
        # Most generic class first, ``cls`` last: later updates win.
        for klass in reversed(cls.__mro__):
            all_cls_dict.update(klass.__dict__)

        attributes = [(name, copy.copy(value))
                      for name, value in all_cls_dict.items()
                      if isinstance(value, Field)]
        for fieldname, field in attributes:
            field.fieldname = fieldname
        return sorted(attributes, key=lambda item: item[1].position)

    @classmethod
    def get_data_fields(cls):
        """Return the field names, minus those listed in ``_exclude_data_fields``."""
        excluded = getattr(cls, "_exclude_data_fields", [])
        return [fieldname for fieldname, _ in cls.get_fields() if fieldname not in excluded]

    def as_dict(self):
        """Return a ``{field name: value}`` dict of the data fields."""
        return dict((field, getattr(self, field)) for field in self.get_data_fields())

    def get_value(self, attr_name, field, value):
        """Convert ``value`` with ``field``, store it as ``attr_name`` and return it.

        Also sets ``self.field_matching_name`` to the field's ``match`` option
        when present, else to ``attr_name``.
        """
        # Convert only once: get_prep_value may be costly or have side effects.
        prepared = field.get_prep_value(value)
        self.__dict__[attr_name] = prepared
        self.field_matching_name = field.__dict__.get("match", attr_name)
        return prepared

    def update_object(self, dict_values, object, update_dict):
        """Copy values onto an existing model instance, then save it.

        When ``update_dict`` has an ``'update'`` entry only the listed names
        are copied; otherwise every entry of ``dict_values`` is.
        """
        if 'update' in update_dict:
            new_dict_values = dict((name, dict_values[name]) for name in update_dict['update'])
        else:
            new_dict_values = dict_values
        for field_name in new_dict_values:
            setattr(object, field_name, new_dict_values[field_name])
        object.save()

    def base_create_model(self, model, **dict_values):
        """Create a ``model`` row from ``dict_values`` and store it in ``self.object``.

        With ``Meta.update``, the row matching ``Meta.update['keys']`` is
        updated instead when it exists.

        Raises:
            ImproperlyConfigured: if ``Meta.update`` has no ``'keys'`` entry, or
                if the keys match several existing rows.
        """
        if not self.cls.has_update_method():
            self.object = model.objects.create(**dict_values)
            return

        update_dict = self.cls.Meta.update
        try:
            keys = update_dict['keys']
        except KeyError:
            raise ImproperlyConfigured("The update dict should contains a keys value")
        filter_values = dict((key, dict_values[key]) for key in keys)
        try:
            obj = model.objects.get(**filter_values)
        except model.DoesNotExist:
            obj = model.objects.create(**dict_values)
        except model.MultipleObjectsReturned:
            raise ImproperlyConfigured(
                "Multiple values returned for the update key %s. "
                "Keys provided are not unique" % (filter_values,))
        else:
            self.update_object(dict_values, obj, update_dict)
        self.object = obj

    def get_object(self):
        """Return the database object created for this row, or None.

        None is returned for non-database models and when no object has been
        created yet.
        """
        if self.cls.is_db_model():
            return getattr(self, "object", None)
        return None

    def create_model_instance(self, values):
        """Create the database object(s) described by ``values``.

        When ``self.multiple_creation_field`` is set, ``values`` holds a list
        under that name and one object is created per list item, all other
        values being shared. ``values`` is modified in that case.
        """
        model = self.cls.Meta.dbModel
        multiple_field = getattr(self, "multiple_creation_field", None)
        if multiple_field:
            multiple_values = values.pop(multiple_field)
            for value in multiple_values:
                dict_values = values.copy()
                dict_values[multiple_field] = value
                self.base_create_model(model, **dict_values)
        else:
            self.base_create_model(model, **values)

    def set_values(self, values_dict, fields_name, values):
        """Store ``values`` in ``values_dict`` under one name or under each name of a list."""
        if isinstance(fields_name, list):
            for field_name in fields_name:
                values_dict[field_name] = values
        else:
            values_dict[fields_name] = values

    def construct_obj_from_model(self, object):
        """Copy attributes from a model instance onto this object and return self.

        For each field the attribute read is its ``match`` option when
        present, else the field name; missing attributes become None.
        """
        for field_name, field in self.get_fields():
            setattr(self, field_name,
                    getattr(object, field.__dict__.get("match", field_name), None))
        return self

    def export(self):
        """Return the field values joined by ``self.delimiter`` (set by CsvModel).

        Joining, rather than stripping a trailing delimiter, keeps trailing
        empty values. A missing delimiter is treated as empty, which only
        matters for single-field models.
        """
        delimiter = self.delimiter or u""
        return delimiter.join(unicode(getattr(self, field_name))
                              for field_name, _ in self.get_fields())

    @classmethod
    def is_db_model(cls):
        """Return ``Meta.dbModel`` when defined (truthy), else a false value."""
        return hasattr(cls, "Meta") and hasattr(cls.Meta, "dbModel") and cls.Meta.dbModel

    @classmethod
    def has_class_delimiter(cls):
        """Return True when ``Meta.delimiter`` is defined."""
        return hasattr(cls, "Meta") and hasattr(cls.Meta, "delimiter")

    @classmethod
    def has_header(cls):
        """Return ``Meta.has_header`` when defined, else False."""
        return hasattr(cls, "Meta") and hasattr(cls.Meta, "has_header") and cls.Meta.has_header

    @classmethod
    def has_update_method(cls):
        """Return True when ``Meta.update`` is defined.

        Raises:
            ImproperlyConfigured: if ``Meta.update`` is set without ``Meta.dbModel``.
        """
        has_update = hasattr(cls, "Meta") and hasattr(cls.Meta, "update")
        if has_update and not cls.is_db_model():
            raise ImproperlyConfigured("You should define a model when using the update option")
        return has_update

    @classmethod
    def silent_failure(cls):
        """Return ``Meta.silent_failure`` when defined, else False."""
        if not hasattr(cls, "Meta") or not hasattr(cls.Meta, "silent_failure"):
            return False
        return cls.Meta.silent_failure

    @classmethod
    def import_data(cls, data, extra_fields=None):
        """Import ``data`` with the class's importer and return its result."""
        importer = cls.get_importer([] if extra_fields is None else extra_fields)
        return importer.import_data(data)

    @classmethod
    def import_from_filename(cls, filename, extra_fields=None):
        """Import the file at ``filename`` with the class's importer."""
        importer = cls.get_importer(extra_fields=[] if extra_fields is None else extra_fields)
        return importer.import_from_filename(filename)

    @classmethod
    def import_from_file(cls, file, extra_fields=None):
        """Import the open ``file`` with the class's importer."""
        importer = cls.get_importer(extra_fields=[] if extra_fields is None else extra_fields)
        return importer.import_from_file(file)
---|
208 | |
---|
209 | |
---|
class CsvModel(BaseModel):
    """Model describing one line of a delimited (CSV) file.

    The values of a line are converted by the declared fields; when
    ``Meta.dbModel`` is set, database object(s) are created from them.
    """

    def __init__(self, data, delimiter=None):
        """Build the instance from a list of raw values or from a Django model instance.

        ``delimiter`` takes precedence over ``Meta.delimiter``.
        """
        super(CsvModel, self).__init__(data)
        if delimiter:
            self.delimiter = delimiter
        elif self.has_class_delimiter():
            self.delimiter = self.cls.Meta.delimiter
        else:
            self.delimiter = None
        if isinstance(data, Model):
            self.construct_obj_from_model(data)
        else:
            self.construct_obj_from_data(data)

    def validate(self):
        """Check that the model declares fields and, if several, a delimiter.

        Raises:
            ImproperlyConfigured: on a missing field or delimiter.
        """
        if len(self.attrs) == 0:
            raise ImproperlyConfigured("No field defined. Should have at least one field in the model.")
        if not self.cls.has_class_delimiter() and not getattr(self, "delimiter", False) and len(self.attrs) > 1:
            raise ImproperlyConfigured(
                "More than a single field and no delimiter defined. You should define a delimiter.")

    @classmethod
    def get_importer(cls, extra_fields=None):
        """Return a CsvImporter bound to this model."""
        return CsvImporter(csvModel=cls, extra_fields=[] if extra_fields is None else extra_fields)

    def construct_obj_from_data(self, data):
        """Convert the raw values of ``data`` (a list, consumed in place) with the fields.

        ComposedKeyFields consume no value. Once the other fields are
        converted, and only for database models, the values named by their
        ``keys`` are removed from the collected values and passed, as a dict,
        to the composed field.

        Raises:
            SkipRow: if a value is invalid and ``Meta.silent_failure`` is set.
            ValueError: if a value is invalid otherwise.
        """
        self.validate()
        values = {}
        silent_failure = self.cls.silent_failure()
        self.multiple_creation_field = None
        composed_fields = []
        for position, (attr_name, field) in enumerate(self.attrs):
            field.position = position
            if isinstance(field, ComposedKeyField):
                composed_fields.append((attr_name, field))
                continue
            # Values are consumed from the front of ``data`` in field order.
            value = data.pop(0)
            try:
                if isinstance(field, IgnoredField):
                    continue
                if getattr(field, 'has_multiple', False):
                    # ``value`` was already popped: put it back in front of the
                    # remaining values, which all belong to this field.
                    remaining_data = [value] + data[:]
                    multiple_values = [self.get_value(attr_name, field, item)
                                       for item in remaining_data]
                    self.set_values(values, self.field_matching_name, multiple_values)
                    self.multiple_creation_field = self.field_matching_name
                else:
                    value = self.get_value(attr_name, field, value)
                    self.set_values(values, self.field_matching_name, value)
            except ValueError:
                if silent_failure:
                    raise SkipRow()
                raise
        if self.cls.is_db_model():
            for composed_name, field in composed_fields:
                keys = {}
                for key in field.keys:
                    keys[key] = values.pop(key)
                # get_value updates field_matching_name for this composed field.
                composed_value = self.get_value(composed_name, field, keys)
                values[self.field_matching_name] = composed_value
            self.create_model_instance(values)
---|
281 | |
---|
282 | |
---|
class CsvDbModel(CsvModel):
    """CsvModel whose columns are the fields of ``Meta.dbModel``.

    Declaring csv fields on the class is not allowed. ``Meta.exclude`` lists
    model field names to leave out; ``id`` is always left out.
    """

    def validate(self):
        """Raise ImproperlyConfigured when ``Meta.dbModel`` is missing or false."""
        if not self.cls.is_db_model():
            raise ImproperlyConfigured("dbModel attribute is missing or "
                                       "wrongly configured in the CsvDbModel "
                                       "class.")

    @classmethod
    def get_fields(cls):
        """Return ``(name, model field)`` pairs from ``Meta.dbModel._meta.fields``, minus exclusions.

        Raises:
            ImproperlyConfigured: if csv fields are declared on the class.
        """
        cls_attrs = super(CsvDbModel, cls).get_fields()
        if len(cls_attrs) != 0:
            raise ImproperlyConfigured("A Db model should not have any csv field defined.")
        attrs = []
        if cls.is_db_model():
            attrs = [(field.name, field) for field in cls.Meta.dbModel._meta.fields]
        excluded_fields = cls.get_exclusion_fields()
        return [attr for attr in attrs if attr[0] not in excluded_fields]

    @classmethod
    def get_exclusion_fields(cls):
        """Return the names listed in ``Meta.exclude`` plus ``'id'``."""
        list_exclusion = []
        if hasattr(cls, "Meta") and hasattr(cls.Meta, "exclude"):
            # extend, not append(*...): append accepts exactly one argument and
            # raised TypeError for zero or several excluded names.
            list_exclusion.extend(cls.Meta.exclude)
        if 'id' not in list_exclusion:
            list_exclusion.append('id')
        return list_exclusion
---|
312 | |
---|
313 | |
---|
class XMLModel(BaseModel):
    """Model whose field values are read from an XML document.

    One instance is built per element located by the importer; every field
    receives that element through ``set_root`` before extracting its value.
    """

    # A field named ``root`` is left out of get_data_fields() / as_dict().
    _exclude_data_fields = ['root']

    def __init__(self, data, element=None):
        super(XMLModel, self).__init__(data)
        self._base_root = element
        self.construct_obj_from_data(data)

    def validate(self):
        # Nothing to check for XML models.
        return None

    @classmethod
    def get_root_field(cls):
        """Return ``(name, field)`` for the first field whose exact type is XMLRootField, else None."""
        return next(((name, field) for name, field in cls.get_fields()
                     if type(field) is XMLRootField), None)

    def set_field_value(self, field_name, field, data):
        """Store on the instance the value ``field`` extracts from ``data``.

        Raises:
            FieldValueMissing: if the extraction raises IndexError.
        """
        try:
            prepared = field.get_prep_value(data, instance=self)
        except IndexError:
            raise FieldValueMissing(field_name)
        self.__dict__[field_name] = prepared

    def construct_obj_from_data(self, data):
        """Extract every field; errors are raised or, if configured, collected."""
        for field_name, field in self.attrs:
            field.set_root(self._base_root)
            try:
                self.set_field_value(field_name, field, data)
            except Exception as e:
                if not self.dont_raise_exception:
                    raise
                # Meta.raise_exception is false: record the error and go on.
                self.errors.append((field_name, e.message))

    @classmethod
    def get_importer(cls, *args):
        """Return an XMLImporter bound to this model (extra arguments are ignored)."""
        return XMLImporter(model=cls)
---|
352 | |
---|
353 | |
---|
class XMLImporter(object):
    """Build one ``model`` instance per element found by the model's root field."""

    def __init__(self, model):
        self.model = model

    def import_data(self, data):
        """Return a list of ``self.model(data, element)``, one per root element of ``data``.

        Raises:
            ImproperlyConfigured: if the model declares no XMLRootField.
        """
        root_field_info = self.model.get_root_field()
        if root_field_info is None:
            # Fail with a clear message instead of a tuple-unpacking TypeError.
            raise ImproperlyConfigured(
                "An XML model should define a XMLRootField to locate its elements.")
        _, root_field = root_field_info
        return [self.model(data, element) for element in root_field.get_root(data)]
---|
365 | |
---|
366 | |
---|
class LinearLayout(object):
    """Layout where one CSV line describes one model instance.

    If the model has a ``has_multiple`` field, every value from that field's
    column to the end of the line yields its own instance, sharing the values
    of the preceding columns.
    """

    def process_line(self, lines, line, model, delimiter):
        """Instantiate ``model`` for ``line``, append the result(s) to ``lines``.

        Returns the last instance created.

        Raises:
            ValueError: if the line has no value for the multiple field.
        """
        multiple_index = None
        multiple_fieldname = None
        for index, (fieldname, field) in enumerate(model.get_fields()):
            if getattr(field, "has_multiple", False):
                multiple_index = index
                multiple_fieldname = fieldname
                break
        # Compare with None: a multiple field in the first column has index 0.
        if multiple_index is None:
            # Pass ``line`` itself, not a copy: the model consumes values from it.
            value = model(data=line, delimiter=delimiter)
            lines.append(value)
            return value
        shared_values = line[:multiple_index]
        multiple_values = line[multiple_index:]
        if not multiple_values:
            raise ValueError("No value found for column %s" % multiple_fieldname)
        for item in multiple_values:
            value = model(data=shared_values + [item], delimiter=delimiter)
            lines.append(value)
        return value
---|
388 | |
---|
389 | |
---|
class TabularLayout(object):
    """Layout for a cross table: the first line holds column headers, column 0 row keys.

    Every cell after the first column yields one model instance built from
    ``[row_key, column_header, cell_value]``.
    """

    def __init__(self):
        self.line_no = 0
        # Kept for backward compatibility; always 1 between two calls.
        self.column_no = 1
        self.headers = None

    def process_line(self, lines, line, model, delimiter):
        """Store the header line, or append one instance per data cell to ``lines``.

        Returns the last instance created, or None for the header line or a
        line without data cells.
        """
        if self.line_no == 0:
            self.headers = line
            self.line_no += 1
            return None
        value = None
        row_key = line[0]
        # A local column counter: one kept on ``self`` stayed out of sync for
        # the following lines when ``model`` raised in the middle of a row.
        for column_no, cell in enumerate(line[1:], 1):
            value = model(data=[row_key, self.headers[column_no], cell], delimiter=delimiter)
            lines.append(value)
        self.column_no = 1
        return value
---|
409 | |
---|
410 | |
---|
class GroupedCsvModel(CsvModel):
    """Model that imports each line through the models listed in ``csv_models``.

    It declares no field of its own.
    """

    @classmethod
    def get_importer(cls, extra_fields=None):
        """Return a GroupedCsvImporter bound to this model."""
        return GroupedCsvImporter(csvModel=cls,
                                  extra_fields=[] if extra_fields is None else extra_fields)

    @classmethod
    def has_csv_models(cls):
        """Return True when the class defines a non-empty ``csv_models`` list."""
        csv_models = getattr(cls, "csv_models", None)
        return isinstance(csv_models, list) and len(csv_models) > 0

    def validate(self):
        """Check that no field is declared and that ``csv_models`` is a non-empty list.

        Raises:
            ImproperlyConfigured: if either condition is violated.
        """
        if len(self.attrs) != 0:
            raise ImproperlyConfigured("You cannot define fields in a grouped csv model.")
        if not self.cls.has_csv_models():
            raise ImproperlyConfigured("Group csv models should define a "
                                       "non empty csv_models list attribute.")
---|
430 | |
---|
431 | |
---|
class CsvImporter(object):
    """Parse delimited data and build ``csvModel`` instances line by line.

    Per-line work is delegated to a layout: the ``layout`` argument, else
    ``Meta.layout``, else LinearLayout. The results of the last import are
    also available by indexing or iterating the importer.
    """

    def __init__(self, csvModel, extra_fields=None, layout=None):
        self.csvModel = csvModel
        # Values added to every line, see process_extra_fields().
        self.extra_fields = [] if extra_fields is None else extra_fields
        self.dialect = None
        self.delimiter = None
        self.lines = []
        if layout is not None:
            # Accept a layout class (as Meta.layout is) or a ready instance.
            self.layout = layout() if isinstance(layout, type) else layout
        elif hasattr(self.csvModel, 'Meta') and hasattr(self.csvModel.Meta, 'layout'):
            self.layout = self.csvModel.Meta.layout()
        else:
            self.layout = LinearLayout()

    def process_extra_fields(self, data, line):
        """Add the configured extra values to ``line`` in place.

        A string entry is appended. A dict entry must have a ``'value'`` key
        and may have a ``'position'`` key; without a position the value is
        appended. ``data`` is unused and only kept for compatibility.

        Raises:
            CsvException: if a dict entry has no ``'value'`` key.
            ImproperlyConfigured: if an entry is neither a string nor a dict.
        """
        for extra in self.extra_fields or []:
            if isinstance(extra, str):
                line.append(extra)
            elif isinstance(extra, dict):
                if 'value' not in extra:
                    raise CsvException("If a positional extra argument is defined, "
                                       "a value key should be present.")
                line.insert(extra.get('position', len(line)), extra['value'])
            else:
                raise ImproperlyConfigured("Extra field should be a string or a dict")

    def import_data(self, data):
        """Parse ``data`` (an iterable of text lines) and return the created instances."""
        self.lines = lines = []
        self.get_class_delimiter()
        reader_options = {}
        # csv.reader rejects delimiter=None: fall back to its default instead.
        if self.delimiter:
            reader_options['delimiter'] = self.delimiter
        for line_number, line in enumerate(csv.reader(data, **reader_options)):
            self.process_line(data, line, lines, line_number, self.csvModel)
        return lines

    def process_line(self, data, line, lines, line_number, model):
        """Process one parsed line through the layout; return its result, or None if skipped.

        Raises:
            CsvFieldDataException: on a ForeignKeyFieldError.
            CsvDataException: on a ValueError (except on the first line of a
                model with a header) or an IndexError.
        """
        self.process_extra_fields(data, line)
        value = None
        try:
            value = self.layout.process_line(lines, line, model, delimiter=self.delimiter)
        except SkipRow:
            pass
        except ForeignKeyFieldError as e:
            raise CsvFieldDataException(line_number, field_error=e.message, model=e.model, value=e.value)
        except ValueError as e:
            # A conversion error on the first line is tolerated for a header.
            if not (line_number == 0 and self.csvModel.has_header()):
                raise CsvDataException(line_number, field_error=e.message)
        except IndexError:
            raise CsvDataException(line_number, error="Number of fields invalid")
        return value

    def get_class_delimiter(self):
        """Use ``Meta.delimiter`` when no delimiter has been set yet."""
        if not self.delimiter and hasattr(self.csvModel, 'Meta') and hasattr(self.csvModel.Meta, 'delimiter'):
            self.delimiter = self.csvModel.Meta.delimiter

    def import_from_filename(self, filename):
        """Import the file at ``filename``; the file is closed afterwards."""
        with open(filename) as csv_file:
            return self.import_from_file(csv_file)

    def import_from_file(self, csv_file):
        """Import an open file, sniffing its delimiter from the first 1024 characters if needed."""
        self.get_class_delimiter()
        if not self.delimiter:
            dialect = csv.Sniffer().sniff(csv_file.read(1024))
            self.delimiter = dialect.delimiter
            csv_file.seek(0)
        return self.import_data(csv_file)

    def __getitem__(self, item):
        return self.lines[item]

    def __iter__(self):
        return iter(self.lines)
---|
513 | |
---|
514 | |
---|
class GroupedCsvImporter(CsvImporter):
    """Importer that feeds every line to each model of ``csvModel.csv_models`` in turn.

    An entry of ``csv_models`` is either a model class or a dict with a
    ``'model'`` key. A dict that also has a ``'use'`` key first inserts, at
    the front of the line, the ``id`` of the object produced by the previous
    dict entry.
    """

    def process_line(self, data, line, lines, line_number, model):
        # ``model`` is not used: the models come from ``csvModel.csv_models``.
        previous_value = None
        for sub_model in self.csvModel.csv_models:
            if not isinstance(sub_model, dict):
                super(GroupedCsvImporter, self).process_line(data, line, lines, line_number, sub_model)
                continue
            if "use" in sub_model:
                # ``line`` is shared by all the models and modified in place.
                line.insert(0, previous_value.get_object().id)
            previous_value = super(GroupedCsvImporter, self).process_line(
                data, line, lines, line_number, sub_model['model'])
---|