root/trunk/extra/driving_distance/src/drivedist.c

Revision 338, 13.4 KB (checked in by anton, 14 months ago)

See #160

Line 
1/*
2 * Finding the Driving Distance (isochrone/isodist) for PostgreSQL
3 *
4 * Copyright (c) 2006 Mario H. Basa, Orkney, Inc.
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
19 *
20 */
21
22#include "postgres.h"
23#include "executor/spi.h"
24#include "funcapi.h"
25#include "catalog/pg_type.h"
26
27#include "drivedist.h"
28
//---------------------------------------------------------------------------

/*
 * Define this to have profiling enabled
 */
//#define PROFILE

#ifdef PROFILE
#include <sys/time.h>

/*
 * Timestamps captured by profstart() and consumed by profstop().
 * NOTE(review): these are file-level globals shared by every call, so the
 * profiling output is only meaningful for one invocation at a time.
 */
struct timeval prof_dijkstra, prof_store, prof_extract, prof_total;
long proftime[5];
long profipts1, profipts2, profopts;

/* Record the current wall-clock time into the struct timeval `x`. */
#define profstart(x) do { gettimeofday(&(x), NULL); } while (0)

/*
 * Emit a NOTICE with the wall-clock time elapsed since profstart(x),
 * labelled with the string `n`.  The seconds and microseconds deltas are
 * taken before scaling so that multiplying an absolute tv_sec by 10^6
 * cannot overflow a 32-bit long.  No trailing semicolon after while (0):
 * the caller supplies it, so the macro is a single statement and is safe
 * inside an unbraced if/else.
 */
#define profstop(n, x) do {                                             \
        struct timeval _profstop;                                       \
        long _proftime;                                                 \
        gettimeofday(&_profstop, NULL);                                 \
        _proftime = (long) (_profstop.tv_sec - (x).tv_sec) * 1000000L   \
                  + (long) (_profstop.tv_usec - (x).tv_usec);           \
        elog(NOTICE,                                                    \
             "PRF(%s) %ld (%f ms)",                                     \
             (n),                                                       \
             _proftime, _proftime / 1000.0);                            \
    } while (0)

#else

/* Profiling disabled: both macros expand to a single no-op statement. */
#define profstart(x) do { } while (0)
#define profstop(n, x) do { } while (0)

#endif // PROFILE
60
61
62//----------------------------------------------------------------------------
63
64Datum driving_distance(PG_FUNCTION_ARGS);
65
#undef DEBUG
//#define DEBUG 1

/*
 * DBG(format, ...) emits a NOTICE through elog() when DEBUG is defined and
 * compiles to a single no-op statement otherwise.  Written with standard
 * C99 variadic macros rather than the GNU-only "arg..." / ", ## arg" form;
 * every call site passes at least a format string, so __VA_ARGS__ is
 * never empty.
 */
#ifdef DEBUG
#define DBG(...) elog(NOTICE, __VA_ARGS__)
#else
#define DBG(...) do { } while (0)
#endif

// The number of tuples to fetch from the SPI cursor at each iteration
#define TUPLIMIT 1000
78
79static char *
80text2char(text *in)
81{
82  char *out = palloc(VARSIZE(in));
83
84  memcpy(out, VARDATA(in), VARSIZE(in) - VARHDRSZ);
85  out[VARSIZE(in) - VARHDRSZ] = '\0';
86  return out;
87}
88
89static int
90finish(int code, int ret)
91{
92  code = SPI_finish();
93  if (code  != SPI_OK_FINISH )
94  {
95    elog(ERROR,"couldn't disconnect from SPI");
96    return -1 ;
97  }
98                       
99  return ret;
100}
101                         
102
/*
 * Attribute numbers, as returned by SPI_fnumber(), of the edge columns in
 * the user query's result set.  compute_driving_distance() initializes
 * every field to -1, which means "not resolved yet"; reverse_cost stays -1
 * when no reverse cost is requested (fetch_edge() checks for that).
 */
struct edge_columns
{
  int id;            /* edge identifier column */
  int source;        /* source vertex column */
  int target;        /* target vertex column */
  int cost;          /* forward cost column */
  int reverse_cost;  /* reverse cost column, or -1 when unused */
};
typedef struct edge_columns edge_columns_t;
111
112static int
113fetch_edge_columns(SPITupleTable *tuptable, edge_columns_t *edge_columns,
114                   bool has_reverse_cost)
115{
116  edge_columns->id     = SPI_fnumber(SPI_tuptable->tupdesc, "id");
117  edge_columns->source = SPI_fnumber(SPI_tuptable->tupdesc, "source");
118  edge_columns->target = SPI_fnumber(SPI_tuptable->tupdesc, "target");
119  edge_columns->cost   = SPI_fnumber(SPI_tuptable->tupdesc, "cost");
120
121  if (edge_columns->id     == SPI_ERROR_NOATTRIBUTE ||
122      edge_columns->source == SPI_ERROR_NOATTRIBUTE ||
123      edge_columns->target == SPI_ERROR_NOATTRIBUTE ||
124      edge_columns->cost   == SPI_ERROR_NOATTRIBUTE)  {
125    elog(ERROR, "Error, query must return columns "
126         "'id', 'source', 'target' and 'cost'");
127    return -1;
128  }
129 
130  if (SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->source) != INT4OID ||
131      SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->target) != INT4OID ||
132      SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->cost) != FLOAT8OID) {
133    elog(ERROR, "Error, columns 'source', 'target' must be of type int4, 'cost' must be of type float8");
134    return -1;
135  }
136 
137  DBG("columns: id %i source %i target %i cost %i",
138      edge_columns->id, edge_columns->source,
139      edge_columns->target, edge_columns->cost);
140 
141  if (has_reverse_cost) {
142    edge_columns->reverse_cost = SPI_fnumber(SPI_tuptable->tupdesc,
143                                             "reverse_cost");
144   
145    if (edge_columns->reverse_cost == SPI_ERROR_NOATTRIBUTE)  {
146      elog(ERROR, "Error, reverse_cost is used, but query did't return "
147           "'reverse_cost' column");
148      return -1;
149    }
150     
151    if (SPI_gettypeid(SPI_tuptable->tupdesc,
152                      edge_columns->reverse_cost) != FLOAT8OID) {
153      elog(ERROR, "Error, columns 'reverse_cost' must be of type float8");
154      return -1;
155    }
156     
157    DBG("columns: reverse_cost cost %i", edge_columns->reverse_cost);
158  }
159 
160  return 0;
161}
162
163static void
164fetch_edge(HeapTuple *tuple, TupleDesc *tupdesc, edge_columns_t *edge_columns,
165           edge_t *target_edge)
166{
167  Datum binval;
168  bool isnull;
169 
170  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->id, &isnull);
171 
172  if (isnull)
173    elog(ERROR, "id contains a null value");
174  target_edge->id = DatumGetInt32(binval);
175 
176  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->source, &isnull);
177 
178  if (isnull)
179    elog(ERROR, "source contains a null value");
180
181  target_edge->source = DatumGetInt32(binval);
182 
183  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->target, &isnull);
184
185  if (isnull)
186    elog(ERROR, "target contains a null value");
187 
188  target_edge->target = DatumGetInt32(binval);
189 
190  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->cost, &isnull);
191 
192  if (isnull)
193    elog(ERROR, "cost contains a null value");
194 
195  target_edge->cost = DatumGetFloat8(binval);
196 
197  if (edge_columns->reverse_cost != -1) {
198    binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->reverse_cost,
199                           &isnull);
200   
201    if (isnull)
202      elog(ERROR, "reverse_cost contains a null value");
203    target_edge->reverse_cost =  DatumGetFloat8(binval);
204  }
205}
206
207
208static int compute_driving_distance(char* sql, int source_vertex_id,
209                                    float8 distance, bool directed,
210                                    bool has_reverse_cost,
211                                    path_element_t **path, int *path_count)
212{
213  int SPIcode;
214  void *SPIplan;
215  Portal SPIportal;
216  bool moredata = TRUE;
217  int ntuples;
218  edge_t *edges = NULL;
219  int total_tuples = 0;
220  edge_columns_t edge_columns = {id: -1, source: -1, target: -1,
221                                 cost: -1, reverse_cost: -1};
222
223  int v_max_id=0;
224  int v_min_id=INT_MAX;
225
226  char *err_msg;
227  int ret = -1;
228 
229  int s_count = 0;
230 
231  register int z;
232 
233  DBG("start driving_distance\n");
234 
235  SPIcode = SPI_connect();
236  if (SPIcode  != SPI_OK_CONNECT) {
237    elog(ERROR, "driving_distance: couldn't open a connection to SPI");
238    return -1;
239  }
240 
241  SPIplan = SPI_prepare(sql, 0, NULL);
242
243  if (SPIplan  == NULL) {
244    elog(ERROR, "driving_distance: couldn't create query plan via SPI");
245    return -1;
246  }
247
248  if ((SPIportal = SPI_cursor_open(NULL, SPIplan, NULL,
249                                   NULL, true)) == NULL) { 
250    elog(ERROR, "driving_distance: SPI_cursor_open('%s') returns NULL", sql);
251    return -1;
252  }
253
254  while (moredata == TRUE) {
255    SPI_cursor_fetch(SPIportal, TRUE, TUPLIMIT);
256
257
258    if (edge_columns.id == -1)  {
259      if (fetch_edge_columns(SPI_tuptable, &edge_columns,
260                             has_reverse_cost) == -1)
261        return finish(SPIcode, ret);
262    }
263
264    ntuples = SPI_processed;
265    total_tuples += ntuples;
266    if (!edges)
267      edges = palloc(total_tuples * sizeof(edge_t));
268    else
269      edges = repalloc(edges, total_tuples * sizeof(edge_t));
270
271    if (edges == NULL) {
272      elog(ERROR, "Out of memory");
273      return finish(SPIcode, ret);
274    }
275
276    if (ntuples > 0) {
277      int t;
278      SPITupleTable *tuptable = SPI_tuptable;
279      TupleDesc tupdesc = SPI_tuptable->tupdesc;
280     
281      for (t = 0; t < ntuples; t++) {
282        HeapTuple tuple = tuptable->vals[t];
283        fetch_edge(&tuple, &tupdesc, &edge_columns,
284                   &edges[total_tuples - ntuples + t]);
285      }
286      SPI_freetuptable(tuptable);
287    }
288    else {
289      moredata = FALSE;
290    }
291  }
292
293
294  //defining min and max vertex id
295     
296  DBG("Total %i tuples", total_tuples);
297   
298  for(z=0; z<total_tuples; z++)
299  {
300    if(edges[z].source<v_min_id)
301      v_min_id=edges[z].source;
302
303    if(edges[z].source>v_max_id)
304      v_max_id=edges[z].source;
305                                           
306    if(edges[z].target<v_min_id)
307      v_min_id=edges[z].target;
308
309    if(edges[z].target>v_max_id)
310      v_max_id=edges[z].target;     
311                                                                       
312    DBG("%i <-> %i", v_min_id, v_max_id);
313                               
314  }
315
316  //:::::::::::::::::::::::::::::::::::: 
317  //:: reducing vertex id (renumbering)
318  //::::::::::::::::::::::::::::::::::::
319  for(z=0; z<total_tuples; z++)
320  {
321    //check if edges[] contains source
322    if(edges[z].source == source_vertex_id ||
323       edges[z].target == source_vertex_id)
324      ++s_count;
325
326    edges[z].source-=v_min_id;
327    edges[z].target-=v_min_id;
328    DBG("%i - %i", edges[z].source, edges[z].target);     
329  }
330
331  if(s_count == 0)
332  {
333    elog(ERROR, "Start vertex was not found.");
334    return -1;
335  }
336                         
337  source_vertex_id -= v_min_id;
338
339  DBG("Calling boost_dijkstra\n");
340       
341  profstop("extract", prof_extract);
342  profstart(prof_dijkstra);
343 
344  ret = boost_dijkstra_dist(edges, total_tuples, source_vertex_id,
345                            distance, directed, has_reverse_cost,
346                            path, path_count, &err_msg);
347   
348  profstop("dijkstra", prof_dijkstra);
349  profstart(prof_store);
350   
351  //::::::::::::::::::::::::::::::::
352  //:: restoring original vertex id
353  //::::::::::::::::::::::::::::::::
354  for(z=0;z<*path_count;z++)
355  {
356    //DBG("vetex %i\n",(*path)[z].vertex_id);
357    (*path)[z].vertex_id+=v_min_id;
358  }
359
360  if (ret < 0) {
361    //elog(ERROR, "Error computing path: %s", err_msg);
362    ereport(ERROR, (errcode(ERRCODE_E_R_E_CONTAINING_SQL_NOT_PERMITTED),
363                    errmsg("Error computing path: %s", err_msg)));
364  }
365   
366  return finish(SPIcode, ret);
367}
368
369
370PG_FUNCTION_INFO_V1(driving_distance);
371Datum
372driving_distance(PG_FUNCTION_ARGS)
373{
374  FuncCallContext     *funcctx;
375  int                  call_cntr;
376  int                  max_calls;
377  TupleDesc            tuple_desc;
378  path_element_t      *path;
379
380  /* stuff done only on the first call of the function */
381  if (SRF_IS_FIRSTCALL()) {
382    MemoryContext   oldcontext;
383    int path_count = 0;
384    int ret;
385   
386    // XXX profiling messages are not thread safe
387    profstart(prof_total);
388    profstart(prof_extract);
389   
390    /* create a function context for cross-call persistence */
391    funcctx = SRF_FIRSTCALL_INIT();
392   
393    /* switch to memory context appropriate for multiple function calls */
394    oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
395   
396    ret = compute_driving_distance(text2char(PG_GETARG_TEXT_P(0)), // sql
397                                PG_GETARG_INT32(1),   // source vertex
398                                PG_GETARG_FLOAT8(2),  // distance or time
399                                PG_GETARG_BOOL(3),
400                                PG_GETARG_BOOL(4), &path, &path_count);
401
402#ifdef DEBUG
403    DBG("Ret is %i", ret);
404    if (ret >= 0) {
405      int i;
406      for (i = 0; i < path_count; i++) {
407        DBG("Step %i vertex_id  %i ", i, path[i].vertex_id);
408        DBG("        edge_id    %i ", path[i].edge_id);
409        DBG("        cost       %f ", path[i].cost);
410      }
411    }
412#endif
413
414    /* total number of tuples to be returned */
415    funcctx->max_calls = path_count;
416    funcctx->user_fctx = path;
417
418    funcctx->tuple_desc = BlessTupleDesc(
419                             RelationNameGetTupleDesc("path_result"));
420   
421    MemoryContextSwitchTo(oldcontext);
422  }
423 
424  /* stuff done on every call of the function */
425  funcctx = SRF_PERCALL_SETUP();
426
427  call_cntr = funcctx->call_cntr;
428  max_calls = funcctx->max_calls;
429  tuple_desc = funcctx->tuple_desc;
430  path = (path_element_t*) funcctx->user_fctx;
431 
432  if (call_cntr < max_calls) {   /* do when there is more left to send */
433    HeapTuple    tuple;
434    Datum        result;
435    Datum *values;
436    char* nulls;
437   
438    /* This will work for some compilers. If it crashes with segfault, try to change the following block with this one   
439
440    values = palloc(4 * sizeof(Datum));
441    nulls = palloc(4 * sizeof(char));
442
443    values[0] = call_cntr;
444    nulls[0] = ' ';
445    values[1] = Int32GetDatum(path[call_cntr].vertex_id);
446    nulls[1] = ' ';
447    values[2] = Int32GetDatum(path[call_cntr].edge_id);
448    nulls[2] = ' ';
449    values[3] = Float8GetDatum(path[call_cntr].cost);
450    nulls[3] = ' ';
451    */
452   
453    values = palloc(3 * sizeof(Datum));
454    nulls = palloc(3 * sizeof(char));
455
456    values[0] = Int32GetDatum(path[call_cntr].vertex_id);
457    nulls[0] = ' ';
458    values[1] = Int32GetDatum(path[call_cntr].edge_id);
459    nulls[1] = ' ';
460    values[2] = Float8GetDatum(path[call_cntr].cost);
461    nulls[2] = ' ';
462     
463    tuple = heap_formtuple(tuple_desc, values, nulls);
464   
465
466    /* make the tuple into a datum */
467    result = HeapTupleGetDatum(tuple);
468   
469    /* clean up (this is not really necessary) */
470    pfree(values);
471    pfree(nulls);
472
473    SRF_RETURN_NEXT(funcctx, result);
474  }
475  else {    /* do when there is no more left */
476    profstop("store", prof_store);
477    profstop("total", prof_total);
478#ifdef PROFILE
479    elog(NOTICE, "_________");
480#endif
481    DBG("Returning value");
482
483    SRF_RETURN_DONE(funcctx);
484  }
485}
Note: See TracBrowser for help on using the browser.