root/trunk/core/src/dijkstra.c

Revision 338, 12.4 KB (checked in by anton, 14 months ago)

See #160

Line 
1/*
2 * Shortest path algorithm for PostgreSQL
3 *
4 * Copyright (c) 2005 Sylvain Pasche
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
19 *
20 */
21
22#include "postgres.h"
23#include "executor/spi.h"
24#include "funcapi.h"
25#include "catalog/pg_type.h"
26
27#include "fmgr.h"
28
29#include "dijkstra.h"
30
31Datum shortest_path(PG_FUNCTION_ARGS);
32
33#undef DEBUG
34//#define DEBUG 1
35
36#ifdef DEBUG
37#define DBG(format, arg...)                     \
38    elog(NOTICE, format , ## arg)
39#else
40#define DBG(format, arg...) do { ; } while (0)
41#endif
42
43// The number of tuples to fetch from the SPI cursor at each iteration
44#define TUPLIMIT 1000
45
46#ifdef PG_MODULE_MAGIC
47PG_MODULE_MAGIC;
48#endif
49
50static char *
51text2char(text *in)
52{
53  char *out = palloc(VARSIZE(in));
54
55  memcpy(out, VARDATA(in), VARSIZE(in) - VARHDRSZ);
56  out[VARSIZE(in) - VARHDRSZ] = '\0';
57  return out;
58}
59
60static int
61finish(int code, int ret)
62{
63  code = SPI_finish();
64  if (code  != SPI_OK_FINISH )
65  {
66    elog(ERROR,"couldn't disconnect from SPI");
67    return -1 ;
68  }                     
69  return ret;
70}
71                         
/*
 * Attribute (column) numbers of the edge query's result, as returned by
 * SPI_fnumber(). The caller initialises every field to -1. reverse_cost
 * stays -1 when no reverse_cost column was requested, and fetch_edge()
 * tests for that.
 */
typedef struct edge_columns
{
  int id, source, target, cost, reverse_cost;
} edge_columns_t;
80
81static int
82fetch_edge_columns(SPITupleTable *tuptable, edge_columns_t *edge_columns,
83                   bool has_reverse_cost)
84{
85  edge_columns->id = SPI_fnumber(SPI_tuptable->tupdesc, "id");
86  edge_columns->source = SPI_fnumber(SPI_tuptable->tupdesc, "source");
87  edge_columns->target = SPI_fnumber(SPI_tuptable->tupdesc, "target");
88  edge_columns->cost = SPI_fnumber(SPI_tuptable->tupdesc, "cost");
89  if (edge_columns->id == SPI_ERROR_NOATTRIBUTE ||
90      edge_columns->source == SPI_ERROR_NOATTRIBUTE ||
91      edge_columns->target == SPI_ERROR_NOATTRIBUTE ||
92      edge_columns->cost == SPI_ERROR_NOATTRIBUTE)
93    {
94      elog(ERROR, "Error, query must return columns "
95           "'id', 'source', 'target' and 'cost'");
96      return -1;
97    }
98
99  if (SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->source) != INT4OID ||
100      SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->target) != INT4OID ||
101      SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->cost) != FLOAT8OID)
102    {
103      elog(ERROR, "Error, columns 'source', 'target' must be of type int4, 'cost' must be of type float8");
104      return -1;
105    }
106
107  DBG("columns: id %i source %i target %i cost %i",
108      edge_columns->id, edge_columns->source,
109      edge_columns->target, edge_columns->cost);
110
111  if (has_reverse_cost)
112    {
113      edge_columns->reverse_cost = SPI_fnumber(SPI_tuptable->tupdesc,
114                                               "reverse_cost");
115
116      if (edge_columns->reverse_cost == SPI_ERROR_NOATTRIBUTE)
117        {
118          elog(ERROR, "Error, reverse_cost is used, but query did't return "
119               "'reverse_cost' column");
120          return -1;
121        }
122
123      if (SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->reverse_cost)
124          != FLOAT8OID)
125        {
126          elog(ERROR, "Error, columns 'reverse_cost' must be of type float8");
127          return -1;
128        }
129
130      DBG("columns: reverse_cost cost %i", edge_columns->reverse_cost);
131    }
132   
133  return 0;
134}
135
136static void
137fetch_edge(HeapTuple *tuple, TupleDesc *tupdesc,
138           edge_columns_t *edge_columns, edge_t *target_edge)
139{
140  Datum binval;
141  bool isnull;
142
143  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->id, &isnull);
144  if (isnull)
145    elog(ERROR, "id contains a null value");
146  target_edge->id = DatumGetInt32(binval);
147
148  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->source, &isnull);
149  if (isnull)
150    elog(ERROR, "source contains a null value");
151  target_edge->source = DatumGetInt32(binval);
152
153  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->target, &isnull);
154  if (isnull)
155    elog(ERROR, "target contains a null value");
156  target_edge->target = DatumGetInt32(binval);
157
158  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->cost, &isnull);
159  if (isnull)
160    elog(ERROR, "cost contains a null value");
161  target_edge->cost = DatumGetFloat8(binval);
162
163  if (edge_columns->reverse_cost != -1)
164    {
165      binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->reverse_cost,
166                             &isnull);
167      if (isnull)
168        elog(ERROR, "reverse_cost contains a null value");
169      target_edge->reverse_cost =  DatumGetFloat8(binval);
170    }
171}
172
173
174static int compute_shortest_path(char* sql, int start_vertex,
175                                 int end_vertex, bool directed,
176                                 bool has_reverse_cost,
177                                 path_element_t **path, int *path_count)
178{
179
180  int SPIcode;
181  void *SPIplan;
182  Portal SPIportal;
183  bool moredata = TRUE;
184  int ntuples;
185  edge_t *edges = NULL;
186  int total_tuples = 0;
187  edge_columns_t edge_columns = {id: -1, source: -1, target: -1,
188                                 cost: -1, reverse_cost: -1};
189  int v_max_id=0;
190  int v_min_id=INT_MAX;
191
192  int s_count = 0;
193  int t_count = 0;
194
195  char *err_msg;
196  int ret = -1;
197  register int z;
198
199  DBG("start shortest_path\n");
200       
201  SPIcode = SPI_connect();
202  if (SPIcode  != SPI_OK_CONNECT)
203    {
204      elog(ERROR, "shortest_path: couldn't open a connection to SPI");
205      return -1;
206    }
207
208  SPIplan = SPI_prepare(sql, 0, NULL);
209  if (SPIplan  == NULL)
210    {
211      elog(ERROR, "shortest_path: couldn't create query plan via SPI");
212      return -1;
213    }
214
215  if ((SPIportal = SPI_cursor_open(NULL, SPIplan, NULL, NULL, true)) == NULL)
216    {
217      elog(ERROR, "shortest_path: SPI_cursor_open('%s') returns NULL", sql);
218      return -1;
219    }
220
221  while (moredata == TRUE)
222    {
223      SPI_cursor_fetch(SPIportal, TRUE, TUPLIMIT);
224
225      if (edge_columns.id == -1)
226        {
227          if (fetch_edge_columns(SPI_tuptable, &edge_columns,
228                                 has_reverse_cost) == -1)
229            return finish(SPIcode, ret);
230        }
231
232      ntuples = SPI_processed;
233      total_tuples += ntuples;
234      if (!edges)
235        edges = palloc(total_tuples * sizeof(edge_t));
236      else
237        edges = repalloc(edges, total_tuples * sizeof(edge_t));
238
239      if (edges == NULL)
240        {
241          elog(ERROR, "Out of memory");
242            return finish(SPIcode, ret);         
243        }
244
245      if (ntuples > 0)
246        {
247          int t;
248          SPITupleTable *tuptable = SPI_tuptable;
249          TupleDesc tupdesc = SPI_tuptable->tupdesc;
250               
251          for (t = 0; t < ntuples; t++)
252            {
253              HeapTuple tuple = tuptable->vals[t];
254              fetch_edge(&tuple, &tupdesc, &edge_columns,
255                         &edges[total_tuples - ntuples + t]);
256            }
257          SPI_freetuptable(tuptable);
258        }
259      else
260        {
261          moredata = FALSE;
262        }
263    }
264
265  //defining min and max vertex id
266     
267  DBG("Total %i tuples", total_tuples);
268   
269  for(z=0; z<total_tuples; z++)
270  {
271    if(edges[z].source<v_min_id)
272    v_min_id=edges[z].source;
273 
274    if(edges[z].source>v_max_id)
275      v_max_id=edges[z].source;
276                           
277    if(edges[z].target<v_min_id)
278      v_min_id=edges[z].target;
279
280    if(edges[z].target>v_max_id)
281      v_max_id=edges[z].target;     
282                                                                       
283    DBG("%i <-> %i", v_min_id, v_max_id);
284                                                       
285  }
286
287  //:::::::::::::::::::::::::::::::::::: 
288  //:: reducing vertex id (renumbering)
289  //::::::::::::::::::::::::::::::::::::
290  for(z=0; z<total_tuples; z++)
291  {
292    //check if edges[] contains source and target
293    if(edges[z].source == start_vertex || edges[z].target == start_vertex)
294      ++s_count;
295    if(edges[z].source == end_vertex || edges[z].target == end_vertex)
296      ++t_count;
297
298    edges[z].source-=v_min_id;
299    edges[z].target-=v_min_id;
300    DBG("%i - %i", edges[z].source, edges[z].target);     
301  }
302
303  DBG("Total %i tuples", total_tuples);
304
305  if(s_count == 0)
306  {
307    elog(ERROR, "Start vertex was not found.");
308    return -1;
309  }
310     
311  if(t_count == 0)
312  {
313    elog(ERROR, "Target vertex was not found.");
314    return -1;
315  }
316 
317  DBG("Calling boost_dijkstra\n");
318       
319  start_vertex -= v_min_id;
320  end_vertex   -= v_min_id;
321
322  ret = boost_dijkstra(edges, total_tuples, start_vertex, end_vertex,
323                       directed, has_reverse_cost,
324                       path, path_count, &err_msg);
325
326  DBG("SIZE %i\n",*path_count);
327
328  //::::::::::::::::::::::::::::::::
329  //:: restoring original vertex id
330  //::::::::::::::::::::::::::::::::
331  for(z=0;z<*path_count;z++)
332  {
333    //DBG("vetex %i\n",(*path)[z].vertex_id);
334    (*path)[z].vertex_id+=v_min_id;
335  }
336
337  DBG("ret = %i\n", ret);
338
339  DBG("*path_count = %i\n", *path_count);
340
341  DBG("ret = %i\n", ret);
342 
343  if (ret < 0)
344    {
345      //elog(ERROR, "Error computing path: %s", err_msg);
346      ereport(ERROR, (errcode(ERRCODE_E_R_E_CONTAINING_SQL_NOT_PERMITTED),
347        errmsg("Error computing path: %s", err_msg)));
348    }
349   
350  return finish(SPIcode, ret);
351}
352
353
354PG_FUNCTION_INFO_V1(shortest_path);
355Datum
356shortest_path(PG_FUNCTION_ARGS)
357{
358  FuncCallContext     *funcctx;
359  int                  call_cntr;
360  int                  max_calls;
361  TupleDesc            tuple_desc;
362  path_element_t      *path;
363
364  /* stuff done only on the first call of the function */
365  if (SRF_IS_FIRSTCALL())
366    {
367      MemoryContext   oldcontext;
368      int path_count = 0;
369      int ret;
370
371      /* create a function context for cross-call persistence */
372      funcctx = SRF_FIRSTCALL_INIT();
373
374      /* switch to memory context appropriate for multiple function calls */
375      oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
376
377
378      ret = compute_shortest_path(text2char(PG_GETARG_TEXT_P(0)),
379                                  PG_GETARG_INT32(1),
380                                  PG_GETARG_INT32(2),
381                                  PG_GETARG_BOOL(3),
382                                  PG_GETARG_BOOL(4), &path, &path_count);
383#ifdef DEBUG
384      DBG("Ret is %i", ret);
385      if (ret >= 0)
386        {
387          int i;
388          for (i = 0; i < path_count; i++)
389            {
390              DBG("Step %i vertex_id  %i ", i, path[i].vertex_id);
391              DBG("        edge_id    %i ", path[i].edge_id);
392              DBG("        cost       %f ", path[i].cost);
393            }
394        }
395#endif
396
397      /* total number of tuples to be returned */
398      funcctx->max_calls = path_count;
399      funcctx->user_fctx = path;
400
401      funcctx->tuple_desc =
402        BlessTupleDesc(RelationNameGetTupleDesc("path_result"));
403
404      MemoryContextSwitchTo(oldcontext);
405    }
406
407  /* stuff done on every call of the function */
408  funcctx = SRF_PERCALL_SETUP();
409
410  call_cntr = funcctx->call_cntr;
411  max_calls = funcctx->max_calls;
412  tuple_desc = funcctx->tuple_desc;
413  path = (path_element_t*) funcctx->user_fctx;
414
415  if (call_cntr < max_calls)    /* do when there is more left to send */
416    {
417      HeapTuple    tuple;
418      Datum        result;
419      Datum *values;
420      char* nulls;
421
422      /* This will work for some compilers. If it crashes with segfault, try to change the following block with this one   
423 
424      values = palloc(4 * sizeof(Datum));
425      nulls = palloc(4 * sizeof(char));
426 
427      values[0] = call_cntr;
428      nulls[0] = ' ';
429      values[1] = Int32GetDatum(path[call_cntr].vertex_id);
430      nulls[1] = ' ';
431      values[2] = Int32GetDatum(path[call_cntr].edge_id);
432      nulls[2] = ' ';
433      values[3] = Float8GetDatum(path[call_cntr].cost);
434      nulls[3] = ' ';
435      */
436   
437      values = palloc(3 * sizeof(Datum));
438      nulls = palloc(3 * sizeof(char));
439
440      values[0] = Int32GetDatum(path[call_cntr].vertex_id);
441      nulls[0] = ' ';
442      values[1] = Int32GetDatum(path[call_cntr].edge_id);
443      nulls[1] = ' ';
444      values[2] = Float8GetDatum(path[call_cntr].cost);
445      nulls[2] = ' ';
446                     
447      tuple = heap_formtuple(tuple_desc, values, nulls);
448
449      /* make the tuple into a datum */
450      result = HeapTupleGetDatum(tuple);
451
452      /* clean up (this is not really necessary) */
453      pfree(values);
454      pfree(nulls);
455
456      SRF_RETURN_NEXT(funcctx, result);
457    }
458  else    /* do when there is no more left */
459    {
460      SRF_RETURN_DONE(funcctx);
461    }
462}
463
Note: See TracBrowser for help on using the browser.