root/trunk/core/src/astar.c

Revision 338, 15.4 KB (checked in by anton, 14 months ago)

See #160

Line 
1/*
2 * A* Shortest path algorithm for PostgreSQL
3 *
4 * Copyright (c) 2006 Anton A. Patrushev, Orkney, Inc.
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
19 *
20 */
21
22#include "postgres.h"
23#include "executor/spi.h"
24#include "funcapi.h"
25#include "catalog/pg_type.h"
26
27#include <stdio.h>
28#include <stdlib.h>
29#include <search.h>
30
31#include "astar.h"
32
33//-------------------------------------------------------------------------
34
35/*
36 * Define this to have profiling enabled
37 */
38//#define PROFILE
39
#ifdef PROFILE
#include <sys/time.h>

struct timeval prof_astar, prof_store, prof_extract, prof_total;
long proftime[5];
long profipts1, profipts2, profopts;

/*
 * profstart(x)   - store the current time of day in the struct timeval x.
 * profstop(n, x) - emit a NOTICE labelled with the string n that reports the
 *                  microseconds (and milliseconds) elapsed since x was set.
 *
 * Both macros expand to a do { ... } while (0) WITHOUT a trailing semicolon,
 * so "profstart(x);" is exactly one statement and stays valid as the body of
 * an unbraced if/else.
 */
#define profstart(x) do { gettimeofday(&(x), NULL); } while (0)
#define profstop(n, x) do { struct timeval _profstop;   \
        long _proftime;                         \
        gettimeofday(&_profstop, NULL);                         \
        _proftime = (_profstop.tv_sec * 1000000 + _profstop.tv_usec) -  \
                ((x).tv_sec * 1000000 + (x).tv_usec); \
        elog(NOTICE, \
                "PRF(%s) %ld (%f ms)", \
                (n), \
             _proftime, _proftime / 1000.0);    \
        } while (0)

#else

/* Profiling disabled: both macros expand to an empty statement. */
#define profstart(x) do { } while (0)
#define profstop(n, x) do { } while (0)

#endif // PROFILE
65
66
67//-------------------------------------------------------------------------
68
69Datum shortest_path_astar(PG_FUNCTION_ARGS);
70
/*
 * Debug tracing.  Define DEBUG to turn every DBG(fmt, ...) call into an
 * elog(NOTICE, fmt, ...); otherwise DBG() expands to an empty statement and
 * its arguments are not evaluated.
 */
#undef DEBUG
/* #define DEBUG 1 */

#ifndef DEBUG
#define DBG(...) do { ; } while (0)
#else
#define DBG(...)                                \
    elog(NOTICE, __VA_ARGS__)
#endif
80
81// The number of tuples to fetch from the SPI cursor at each iteration
82#define TUPLIMIT 1000
83
84static char *
85text2char(text *in)
86{
87  char *out = palloc(VARSIZE(in));
88
89  memcpy(out, VARDATA(in), VARSIZE(in) - VARHDRSZ);
90  out[VARSIZE(in) - VARHDRSZ] = '\0';
91  return out;
92}
93
94static int
95finish(int code, int ret)
96{
97  code = SPI_finish();
98  if (code  != SPI_OK_FINISH )
99    {
100      elog(ERROR,"couldn't disconnect from SPI");
101      return -1 ;
102    }
103 
104  return ret;
105}
106 
/*
 * Attribute numbers (as returned by SPI_fnumber) of the columns read from
 * the edge query.  Callers in this file initialise every member to -1 and
 * treat -1 as "not looked up / not present".
 */
typedef struct edge_astar_columns
{
  /* edge id and its two end vertices */
  int id, source, target;

  /* forward cost and the optional reverse cost */
  int cost, reverse_cost;

  /* source point coordinates: result columns "x1" and "y1" */
  int s_x, s_y;

  /* target point coordinates: result columns "x2" and "y2" */
  int t_x, t_y;
} edge_astar_columns_t;
119
120
121static int
122fetch_edge_astar_columns(SPITupleTable *tuptable,
123                         edge_astar_columns_t *edge_columns,
124                         bool has_reverse_cost)
125{
126  edge_columns->id = SPI_fnumber(SPI_tuptable->tupdesc, "id");
127  edge_columns->source = SPI_fnumber(SPI_tuptable->tupdesc, "source");
128  edge_columns->target = SPI_fnumber(SPI_tuptable->tupdesc, "target");
129  edge_columns->cost = SPI_fnumber(SPI_tuptable->tupdesc, "cost");
130  if (edge_columns->id == SPI_ERROR_NOATTRIBUTE ||
131      edge_columns->source == SPI_ERROR_NOATTRIBUTE ||
132      edge_columns->target == SPI_ERROR_NOATTRIBUTE ||
133      edge_columns->cost == SPI_ERROR_NOATTRIBUTE)
134    {
135      elog(ERROR, "Error, query must return columns "
136           "'id', 'source', 'target' and 'cost'");
137      return -1;
138    }
139
140  if (SPI_gettypeid(SPI_tuptable->tupdesc,
141                    edge_columns->source) != INT4OID ||
142      SPI_gettypeid(SPI_tuptable->tupdesc,
143                    edge_columns->target) != INT4OID ||
144      SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->cost) != FLOAT8OID)
145    {
146      elog(ERROR, "Error, columns 'source', 'target' must be of type int4, "
147           "'cost' must be of type float8");
148      return -1;
149    }
150
151  DBG("columns: id %i source %i target %i cost %i",
152      edge_columns->id, edge_columns->source,
153      edge_columns->target, edge_columns->cost);
154
155  if (has_reverse_cost)
156    {
157      edge_columns->reverse_cost = SPI_fnumber(SPI_tuptable->tupdesc,
158                                               "reverse_cost");
159
160      if (edge_columns->reverse_cost == SPI_ERROR_NOATTRIBUTE)
161        {
162          elog(ERROR, "Error, reverse_cost is used, but query did't return "
163               "'reverse_cost' column");
164          return -1;
165        }
166
167      if (SPI_gettypeid(SPI_tuptable->tupdesc,
168                        edge_columns->reverse_cost) != FLOAT8OID)
169        {
170          elog(ERROR, "Error, columns 'reverse_cost' must be of type float8");
171          return -1;
172        }
173
174      DBG("columns: reverse_cost cost %i", edge_columns->reverse_cost);
175    }
176
177  edge_columns->s_x = SPI_fnumber(SPI_tuptable->tupdesc, "x1");
178  edge_columns->s_y = SPI_fnumber(SPI_tuptable->tupdesc, "y1");
179  edge_columns->t_x = SPI_fnumber(SPI_tuptable->tupdesc, "x2");
180  edge_columns->t_y = SPI_fnumber(SPI_tuptable->tupdesc, "y2");
181
182  if (edge_columns->s_x == SPI_ERROR_NOATTRIBUTE ||
183      edge_columns->s_y == SPI_ERROR_NOATTRIBUTE ||
184      edge_columns->t_x == SPI_ERROR_NOATTRIBUTE ||
185      edge_columns->t_y == SPI_ERROR_NOATTRIBUTE)
186    {
187      elog(ERROR, "Error, query must return columns "
188           "'x1', 'x2', 'y1' and 'y2'");
189      return -1;
190    }
191
192  DBG("columns: x1 %i y1 %i x2 %i y2 %i",
193      edge_columns->s_x, edge_columns->s_y,
194      edge_columns->t_x,edge_columns->t_y);
195   
196  return 0;
197}
198
199static void
200fetch_edge_astar(HeapTuple *tuple, TupleDesc *tupdesc,
201                 edge_astar_columns_t *edge_columns,
202                 edge_astar_t *target_edge)
203{
204  Datum binval;
205  bool isnull;
206
207  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->id, &isnull);
208  if (isnull)
209    elog(ERROR, "id contains a null value");
210  target_edge->id = DatumGetInt32(binval);
211
212  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->source, &isnull);
213  if (isnull)
214    elog(ERROR, "source contains a null value");
215  target_edge->source = DatumGetInt32(binval);
216
217  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->target, &isnull);
218  if (isnull)
219    elog(ERROR, "target contains a null value");
220  target_edge->target = DatumGetInt32(binval);
221
222  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->cost, &isnull);
223  if (isnull)
224    elog(ERROR, "cost contains a null value");
225  target_edge->cost = DatumGetFloat8(binval);
226
227  if (edge_columns->reverse_cost != -1)
228    {
229      binval = SPI_getbinval(*tuple, *tupdesc,
230                             edge_columns->reverse_cost, &isnull);
231      if (isnull)
232        elog(ERROR, "reverse_cost contains a null value");
233      target_edge->reverse_cost =  DatumGetFloat8(binval);
234    }
235
236  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->s_x, &isnull);
237  if (isnull)
238    elog(ERROR, "source x contains a null value");
239  target_edge->s_x = DatumGetFloat8(binval);
240
241  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->s_y, &isnull);
242  if (isnull)
243    elog(ERROR, "source y contains a null value");
244  target_edge->s_y = DatumGetFloat8(binval);
245 
246  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->t_x, &isnull);
247  if (isnull)
248    elog(ERROR, "target x contains a null value");
249  target_edge->t_x = DatumGetFloat8(binval);
250 
251  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->t_y, &isnull);
252  if (isnull)
253    elog(ERROR, "target y contains a null value");
254  target_edge->t_y = DatumGetFloat8(binval);
255}
256
257
258static int compute_shortest_path_astar(char* sql, int source_vertex_id,
259                                       int target_vertex_id, bool directed,
260                                       bool has_reverse_cost,
261                                       path_element_t **path, int *path_count)
262{
263 
264  int SPIcode;
265  void *SPIplan;
266  Portal SPIportal;
267  bool moredata = TRUE;
268  int ntuples;
269  edge_astar_t *edges = NULL;
270  int total_tuples = 0;
271 
272  int v_max_id=0;
273  int v_min_id=INT_MAX; 
274   
275  edge_astar_columns_t edge_columns = {id: -1, source: -1, target: -1,
276                                       cost: -1, reverse_cost: -1,
277                                       s_x: -1, s_y: -1, t_x: -1, t_y: -1};
278  char *err_msg;
279  int ret = -1;
280  register int z;
281 
282  int s_count=0;
283  int t_count=0;
284 
285  struct vItem
286  {
287    int id;
288    int key;
289  };
290 
291  DBG("start shortest_path_astar\n");
292       
293  SPIcode = SPI_connect();
294  if (SPIcode  != SPI_OK_CONNECT)
295    {
296      elog(ERROR, "shortest_path_astar: couldn't open a connection to SPI");
297      return -1;
298    }
299
300  SPIplan = SPI_prepare(sql, 0, NULL);
301  if (SPIplan  == NULL)
302    {
303      elog(ERROR, "shortest_path_astar: couldn't create query plan via SPI");
304      return -1;
305    }
306
307  if ((SPIportal = SPI_cursor_open(NULL, SPIplan, NULL, NULL, true)) == NULL)
308    {
309      elog(ERROR, "shortest_path_astar: SPI_cursor_open('%s') returns NULL",
310           sql);
311      return -1;
312    }
313
314  while (moredata == TRUE)
315    {
316      SPI_cursor_fetch(SPIportal, TRUE, TUPLIMIT);
317
318      if (edge_columns.id == -1)
319        {
320          if (fetch_edge_astar_columns(SPI_tuptable, &edge_columns,
321                                       has_reverse_cost) == -1)
322            return finish(SPIcode, ret);
323        }
324
325      ntuples = SPI_processed;
326      total_tuples += ntuples;
327      if (!edges)
328        edges = palloc(total_tuples * sizeof(edge_astar_t));
329      else
330        edges = repalloc(edges, total_tuples * sizeof(edge_astar_t));
331
332      if (edges == NULL)
333        {
334          elog(ERROR, "Out of memory");
335          return finish(SPIcode, ret);
336        }
337
338      if (ntuples > 0)
339        {
340          int t;
341          SPITupleTable *tuptable = SPI_tuptable;
342          TupleDesc tupdesc = SPI_tuptable->tupdesc;
343         
344          for (t = 0; t < ntuples; t++)
345            {
346              HeapTuple tuple = tuptable->vals[t];
347              fetch_edge_astar(&tuple, &tupdesc, &edge_columns,
348                               &edges[total_tuples - ntuples + t]);
349            }
350          SPI_freetuptable(tuptable);
351        }
352      else
353        {
354          moredata = FALSE;
355        }
356    }
357   
358  //defining min and max vertex id
359     
360  DBG("Total %i tuples", total_tuples);
361   
362  for(z=0; z<total_tuples; z++)
363  {
364    if(edges[z].source<v_min_id)
365    v_min_id=edges[z].source;
366 
367    if(edges[z].source>v_max_id)
368      v_max_id=edges[z].source;
369                                   
370    if(edges[z].target<v_min_id)
371      v_min_id=edges[z].target;
372
373    if(edges[z].target>v_max_id)
374      v_max_id=edges[z].target;     
375                                                                       
376    DBG("%i <-> %i", v_min_id, v_max_id);
377                                                               
378  }
379
380  //:::::::::::::::::::::::::::::::::::: 
381  //:: reducing vertex id (renumbering)
382  //::::::::::::::::::::::::::::::::::::
383  for(z=0; z<total_tuples; z++)
384  {
385
386    //check if edges[] contains source and target
387    if(edges[z].source == source_vertex_id ||
388       edges[z].target == source_vertex_id)
389      ++s_count;
390    if(edges[z].source == target_vertex_id ||
391       edges[z].target == target_vertex_id)
392      ++t_count;
393
394    edges[z].source-=v_min_id;
395    edges[z].target-=v_min_id;
396    DBG("%i - %i", edges[z].source, edges[z].target);
397                                                               
398
399  }
400   
401  DBG("Total %i tuples", total_tuples);
402
403  if(s_count == 0)
404  {
405    elog(ERROR, "Start vertex was not found.");
406    return -1;
407  }
408             
409  if(t_count == 0)
410  {
411    elog(ERROR, "Target vertex was not found.");
412    return -1;
413  }
414                           
415  DBG("Total %i tuples", total_tuples);
416
417  profstop("extract", prof_extract);
418  profstart(prof_astar);
419 
420  DBG("Calling boost_astar <%i>\n", total_tuples);
421
422  // calling C++ A* function   
423  ret = boost_astar(edges, total_tuples, source_vertex_id-v_min_id,
424                    target_vertex_id-v_min_id,
425                    directed, has_reverse_cost,
426                    path, path_count, &err_msg);
427
428  DBG("SIZE %i\n",*path_count);
429
430  DBG("ret =  %i\n",ret);
431 
432  //::::::::::::::::::::::::::::::::
433  //:: restoring original vertex id
434  //::::::::::::::::::::::::::::::::
435  for(z=0;z<*path_count;z++)
436  {
437    //DBG("vetex %i\n",(*path)[z].vertex_id);
438    (*path)[z].vertex_id+=v_min_id;
439  } 
440
441  profstop("astar", prof_astar);
442  profstart(prof_store);
443
444  if (ret < 0)
445    {
446      //elog(ERROR, "Error computing path: %s", err_msg);
447      ereport(ERROR, (errcode(ERRCODE_E_R_E_CONTAINING_SQL_NOT_PERMITTED),
448        errmsg("Error computing path: %s", err_msg)));
449    }
450  return finish(SPIcode, ret);
451}
452
453
454PG_FUNCTION_INFO_V1(shortest_path_astar);
455Datum
456shortest_path_astar(PG_FUNCTION_ARGS)
457{
458  FuncCallContext     *funcctx;
459  int                  call_cntr;
460  int                  max_calls;
461  TupleDesc            tuple_desc;
462  path_element_t      *path;
463 
464  /* stuff done only on the first call of the function */
465  if (SRF_IS_FIRSTCALL())
466    {
467      MemoryContext   oldcontext;
468      int path_count = 0;
469      int ret;
470
471      // XXX profiling messages are not thread safe
472      profstart(prof_total);
473      profstart(prof_extract);
474     
475      /* create a function context for cross-call persistence */
476      funcctx = SRF_FIRSTCALL_INIT();
477     
478      /* switch to memory context appropriate for multiple function calls */
479      oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
480
481
482      ret = compute_shortest_path_astar(text2char(PG_GETARG_TEXT_P(0)),
483                                        PG_GETARG_INT32(1),
484                                        PG_GETARG_INT32(2),
485                                        PG_GETARG_BOOL(3),
486                                        PG_GETARG_BOOL(4),
487                                        &path, &path_count);
488
489#ifdef DEBUG
490      DBG("Ret is %i", ret);
491      if (ret >= 0)
492        {
493          int i;
494          for (i = 0; i < path_count; i++)
495            {
496              DBG("Step # %i vertex_id  %i ", i, path[i].vertex_id);
497              DBG("        edge_id    %i ", path[i].edge_id);
498              DBG("        cost       %f ", path[i].cost);
499            }
500        }
501#endif
502
503      /* total number of tuples to be returned */
504      DBG("Conting tuples number\n");
505      funcctx->max_calls = path_count;
506      funcctx->user_fctx = path;
507     
508      DBG("Path count %i", path_count);
509     
510      funcctx->tuple_desc =
511        BlessTupleDesc(RelationNameGetTupleDesc("path_result"));
512
513      MemoryContextSwitchTo(oldcontext);
514    }
515
516  /* stuff done on every call of the function */
517  DBG("Strange stuff doing\n");
518
519  funcctx = SRF_PERCALL_SETUP();
520 
521  call_cntr = funcctx->call_cntr;
522  max_calls = funcctx->max_calls;
523  tuple_desc = funcctx->tuple_desc;
524  path = (path_element_t*) funcctx->user_fctx;
525 
526  DBG("Trying to allocate some memory\n");
527
528  if (call_cntr < max_calls)    /* do when there is more left to send */
529    {
530      HeapTuple    tuple;
531      Datum        result;
532      Datum *values;
533      char* nulls;
534     
535      /* This will work for some compilers. If it crashes with segfault, try to change the following block with this one   
536 
537      values = palloc(4 * sizeof(Datum));
538      nulls = palloc(4 * sizeof(char));
539 
540      values[0] = call_cntr;
541      nulls[0] = ' ';
542      values[1] = Int32GetDatum(path[call_cntr].vertex_id);
543      nulls[1] = ' ';
544       values[2] = Int32GetDatum(path[call_cntr].edge_id);
545      nulls[2] = ' ';
546      values[3] = Float8GetDatum(path[call_cntr].cost);
547      nulls[3] = ' ';
548      */
549   
550      values = palloc(3 * sizeof(Datum));
551      nulls = palloc(3 * sizeof(char));
552 
553      values[0] = Int32GetDatum(path[call_cntr].vertex_id);
554      nulls[0] = ' ';
555      values[1] = Int32GetDatum(path[call_cntr].edge_id);
556      nulls[1] = ' ';
557      values[2] = Float8GetDatum(path[call_cntr].cost);
558      nulls[2] = ' ';
559                 
560      DBG("Heap making\n");
561     
562      tuple = heap_formtuple(tuple_desc, values, nulls);
563     
564      DBG("Datum making\n");
565     
566      /* make the tuple into a datum */
567      result = HeapTupleGetDatum(tuple);
568     
569
570      DBG("Trying to free some memory\n");
571   
572      /* clean up (this is not really necessary) */
573      pfree(values);
574      pfree(nulls);
575     
576      SRF_RETURN_NEXT(funcctx, result);
577    }
578  else    /* do when there is no more left */
579    {
580      profstop("store", prof_store);
581      profstop("total", prof_total);
582#ifdef PROFILE
583      elog(NOTICE, "_________");
584#endif
585      SRF_RETURN_DONE(funcctx);
586    }
587}
Note: See TracBrowser for help on using the browser.