root/branches/anton/core/src/dijkstra.c

Revision 28, 11.9 KB (checked in by anton, 3 years ago)

cmake branch added

Line 
1/*
2 * Shortest path algorithm for PostgreSQL
3 *
4 * Copyright (c) 2005 Sylvain Pasche
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
19 *
20 */
21
22#include "postgres.h"
23#include "executor/spi.h"
24#include "funcapi.h"
25
26#include "dijkstra.h"
27
28Datum shortest_path(PG_FUNCTION_ARGS);
29
30#undef DEBUG
31//#define DEBUG 1
32
33#ifdef DEBUG
34#define DBG(format, arg...)                     \
35    elog(NOTICE, format , ## arg)
36#else
37#define DBG(format, arg...) do { ; } while (0)
38#endif
39
40// The number of tuples to fetch from the SPI cursor at each iteration
41#define TUPLIMIT 1000
42
43#ifndef PG_MODULE_MAGIC
44PG_MODULE_MAGIC;
45#endif
46
47static char *
48text2char(text *in)
49{
50  char *out = palloc(VARSIZE(in));
51
52  memcpy(out, VARDATA(in), VARSIZE(in) - VARHDRSZ);
53  out[VARSIZE(in) - VARHDRSZ] = '\0';
54  return out;
55}
56
57static int
58finish(int code, int ret)
59{
60  code = SPI_finish();
61  if (code  != SPI_OK_FINISH )
62  {
63    elog(ERROR,"couldn't disconnect from SPI");
64    return -1 ;
65  }                     
66  return ret;
67}
68                         
/*
 * Column numbers, as returned by SPI_fnumber(), of the attributes that the
 * edge query must provide.  Callers in this file initialise every field to
 * -1, and -1 in `reverse_cost` is treated as "column not requested".
 */
typedef struct edge_columns
{
  int id;            /* column number of "id" */
  int source;        /* column number of "source" */
  int target;        /* column number of "target" */
  int cost;          /* column number of "cost" */
  int reverse_cost;  /* column number of "reverse_cost", or -1 when unused */
} edge_columns_t;
77
78static int
79fetch_edge_columns(SPITupleTable *tuptable, edge_columns_t *edge_columns,
80                   bool has_reverse_cost)
81{
82  edge_columns->id = SPI_fnumber(SPI_tuptable->tupdesc, "id");
83  edge_columns->source = SPI_fnumber(SPI_tuptable->tupdesc, "source");
84  edge_columns->target = SPI_fnumber(SPI_tuptable->tupdesc, "target");
85  edge_columns->cost = SPI_fnumber(SPI_tuptable->tupdesc, "cost");
86  if (edge_columns->id == SPI_ERROR_NOATTRIBUTE ||
87      edge_columns->source == SPI_ERROR_NOATTRIBUTE ||
88      edge_columns->target == SPI_ERROR_NOATTRIBUTE ||
89      edge_columns->cost == SPI_ERROR_NOATTRIBUTE)
90    {
91      elog(ERROR, "Error, query must return columns "
92           "'id', 'source', 'target' and 'cost'");
93      return -1;
94    }
95
96  if (SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->source) != INT4OID ||
97      SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->target) != INT4OID ||
98      SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->cost) != FLOAT8OID)
99    {
100      elog(ERROR, "Error, columns 'source', 'target' must be of type int4, 'cost' must be of type float8");
101      return -1;
102    }
103
104  DBG("columns: id %i source %i target %i cost %i",
105      edge_columns->id, edge_columns->source,
106      edge_columns->target, edge_columns->cost);
107
108  if (has_reverse_cost)
109    {
110      edge_columns->reverse_cost = SPI_fnumber(SPI_tuptable->tupdesc,
111                                               "reverse_cost");
112
113      if (edge_columns->reverse_cost == SPI_ERROR_NOATTRIBUTE)
114        {
115          elog(ERROR, "Error, reverse_cost is used, but query did't return "
116               "'reverse_cost' column");
117          return -1;
118        }
119
120      if (SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->reverse_cost)
121          != FLOAT8OID)
122        {
123          elog(ERROR, "Error, columns 'reverse_cost' must be of type float8");
124          return -1;
125        }
126
127      DBG("columns: reverse_cost cost %i", edge_columns->reverse_cost);
128    }
129   
130  return 0;
131}
132
133static void
134fetch_edge(HeapTuple *tuple, TupleDesc *tupdesc,
135           edge_columns_t *edge_columns, edge_t *target_edge)
136{
137  Datum binval;
138  bool isnull;
139
140  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->id, &isnull);
141  if (isnull)
142    elog(ERROR, "id contains a null value");
143  target_edge->id = DatumGetInt32(binval);
144
145  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->source, &isnull);
146  if (isnull)
147    elog(ERROR, "source contains a null value");
148  target_edge->source = DatumGetInt32(binval);
149
150  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->target, &isnull);
151  if (isnull)
152    elog(ERROR, "target contains a null value");
153  target_edge->target = DatumGetInt32(binval);
154
155  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->cost, &isnull);
156  if (isnull)
157    elog(ERROR, "cost contains a null value");
158  target_edge->cost = DatumGetFloat8(binval);
159
160  if (edge_columns->reverse_cost != -1)
161    {
162      binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->reverse_cost,
163                             &isnull);
164      if (isnull)
165        elog(ERROR, "reverse_cost contains a null value");
166      target_edge->reverse_cost =  DatumGetFloat8(binval);
167    }
168}
169
170
171static int compute_shortest_path(char* sql, int start_vertex,
172                                 int end_vertex, bool directed,
173                                 bool has_reverse_cost,
174                                 path_element_t **path, int *path_count)
175{
176
177  int SPIcode;
178  void *SPIplan;
179  Portal SPIportal;
180  bool moredata = TRUE;
181  int ntuples;
182  edge_t *edges = NULL;
183  int total_tuples = 0;
184  edge_columns_t edge_columns = {id: -1, source: -1, target: -1,
185                                 cost: -1, reverse_cost: -1};
186  int v_max_id=0;
187  int v_min_id=INT_MAX;
188
189  int s_count = 0;
190  int t_count = 0;
191
192  char *err_msg;
193  int ret = -1;
194  register int z;
195
196  DBG("start shortest_path\n");
197       
198  SPIcode = SPI_connect();
199  if (SPIcode  != SPI_OK_CONNECT)
200    {
201      elog(ERROR, "shortest_path: couldn't open a connection to SPI");
202      return -1;
203    }
204
205  SPIplan = SPI_prepare(sql, 0, NULL);
206  if (SPIplan  == NULL)
207    {
208      elog(ERROR, "shortest_path: couldn't create query plan via SPI");
209      return -1;
210    }
211
212  if ((SPIportal = SPI_cursor_open(NULL, SPIplan, NULL, NULL, true)) == NULL)
213    {
214      elog(ERROR, "shortest_path: SPI_cursor_open('%s') returns NULL", sql);
215      return -1;
216    }
217
218  while (moredata == TRUE)
219    {
220      SPI_cursor_fetch(SPIportal, TRUE, TUPLIMIT);
221
222      if (edge_columns.id == -1)
223        {
224          if (fetch_edge_columns(SPI_tuptable, &edge_columns,
225                                 has_reverse_cost) == -1)
226            return finish(SPIcode, ret);
227        }
228
229      ntuples = SPI_processed;
230      total_tuples += ntuples;
231      if (!edges)
232        edges = palloc(total_tuples * sizeof(edge_t));
233      else
234        edges = repalloc(edges, total_tuples * sizeof(edge_t));
235
236      if (edges == NULL)
237        {
238          elog(ERROR, "Out of memory");
239            return finish(SPIcode, ret);         
240        }
241
242      if (ntuples > 0)
243        {
244          int t;
245          SPITupleTable *tuptable = SPI_tuptable;
246          TupleDesc tupdesc = SPI_tuptable->tupdesc;
247               
248          for (t = 0; t < ntuples; t++)
249            {
250              HeapTuple tuple = tuptable->vals[t];
251              fetch_edge(&tuple, &tupdesc, &edge_columns,
252                         &edges[total_tuples - ntuples + t]);
253            }
254          SPI_freetuptable(tuptable);
255        }
256      else
257        {
258          moredata = FALSE;
259        }
260    }
261
262  //defining min and max vertex id
263     
264  DBG("Total %i tuples", total_tuples);
265   
266  for(z=0; z<total_tuples; z++)
267  {
268    if(edges[z].source<v_min_id)
269    v_min_id=edges[z].source;
270 
271    if(edges[z].source>v_max_id)
272      v_max_id=edges[z].source;
273                           
274    if(edges[z].target<v_min_id)
275      v_min_id=edges[z].target;
276
277    if(edges[z].target>v_max_id)
278      v_max_id=edges[z].target;     
279                                                                       
280    DBG("%i <-> %i", v_min_id, v_max_id);
281                                                       
282  }
283
284  //:::::::::::::::::::::::::::::::::::: 
285  //:: reducing vertex id (renumbering)
286  //::::::::::::::::::::::::::::::::::::
287  for(z=0; z<total_tuples; z++)
288  {
289    //check if edges[] contains source and target
290    if(edges[z].source == start_vertex || edges[z].target == start_vertex)
291      ++s_count;
292    if(edges[z].source == end_vertex || edges[z].target == end_vertex)
293      ++t_count;
294
295    edges[z].source-=v_min_id;
296    edges[z].target-=v_min_id;
297    DBG("%i - %i", edges[z].source, edges[z].target);     
298  }
299
300  DBG("Total %i tuples", total_tuples);
301
302  if(s_count == 0)
303  {
304    elog(ERROR, "Start vertex was not found.");
305    return -1;
306  }
307     
308  if(t_count == 0)
309  {
310    elog(ERROR, "Target vertex was not found.");
311    return -1;
312  }
313 
314  DBG("Calling boost_dijkstra\n");
315       
316  start_vertex -= v_min_id;
317  end_vertex   -= v_min_id;
318
319  ret = boost_dijkstra(edges, total_tuples, start_vertex, end_vertex,
320                       directed, has_reverse_cost,
321                       path, path_count, &err_msg);
322
323  DBG("SIZE %i\n",*path_count);
324
325  //::::::::::::::::::::::::::::::::
326  //:: restoring original vertex id
327  //::::::::::::::::::::::::::::::::
328  for(z=0;z<*path_count;z++)
329  {
330    //DBG("vetex %i\n",(*path)[z].vertex_id);
331    (*path)[z].vertex_id+=v_min_id;
332  }
333
334  DBG("ret = %i\n", ret);
335
336  DBG("*path_count = %i\n", *path_count);
337
338  DBG("ret = %i\n", ret);
339 
340  if (ret < 0)
341    {
342      //elog(ERROR, "Error computing path: %s", err_msg);
343      ereport(ERROR, (errcode(ERRCODE_E_R_E_CONTAINING_SQL_NOT_PERMITTED),
344        errmsg("Error computing path: %s", err_msg)));
345    }
346   
347  return finish(SPIcode, ret);
348}
349
350
351PG_FUNCTION_INFO_V1(shortest_path);
352Datum
353shortest_path(PG_FUNCTION_ARGS)
354{
355  FuncCallContext     *funcctx;
356  int                  call_cntr;
357  int                  max_calls;
358  TupleDesc            tuple_desc;
359  path_element_t      *path;
360
361  /* stuff done only on the first call of the function */
362  if (SRF_IS_FIRSTCALL())
363    {
364      MemoryContext   oldcontext;
365      int path_count = 0;
366      int ret;
367
368      /* create a function context for cross-call persistence */
369      funcctx = SRF_FIRSTCALL_INIT();
370
371      /* switch to memory context appropriate for multiple function calls */
372      oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
373
374
375      ret = compute_shortest_path(text2char(PG_GETARG_TEXT_P(0)),
376                                  PG_GETARG_INT32(1),
377                                  PG_GETARG_INT32(2),
378                                  PG_GETARG_BOOL(3),
379                                  PG_GETARG_BOOL(4), &path, &path_count);
380#ifdef DEBUG
381      DBG("Ret is %i", ret);
382      if (ret >= 0)
383        {
384          int i;
385          for (i = 0; i < path_count; i++)
386            {
387              DBG("Step %i vertex_id  %i ", i, path[i].vertex_id);
388              DBG("        edge_id    %i ", path[i].edge_id);
389              DBG("        cost       %f ", path[i].cost);
390            }
391        }
392#endif
393
394      /* total number of tuples to be returned */
395      funcctx->max_calls = path_count;
396      funcctx->user_fctx = path;
397
398      funcctx->tuple_desc =
399        BlessTupleDesc(RelationNameGetTupleDesc("path_result"));
400
401      MemoryContextSwitchTo(oldcontext);
402    }
403
404  /* stuff done on every call of the function */
405  funcctx = SRF_PERCALL_SETUP();
406
407  call_cntr = funcctx->call_cntr;
408  max_calls = funcctx->max_calls;
409  tuple_desc = funcctx->tuple_desc;
410  path = (path_element_t*) funcctx->user_fctx;
411
412  if (call_cntr < max_calls)    /* do when there is more left to send */
413    {
414      HeapTuple    tuple;
415      Datum        result;
416      Datum *values;
417      char* nulls;
418
419      values = palloc(3 * sizeof(Datum));
420      nulls = palloc(3 * sizeof(char));
421
422
423      values[0] = Int32GetDatum(path[call_cntr].vertex_id);
424      nulls[0] = ' ';
425      values[1] = Int32GetDatum(path[call_cntr].edge_id);
426      nulls[1] = ' ';
427      values[2] = Float8GetDatum(path[call_cntr].cost);
428      nulls[2] = ' ';
429
430      tuple = heap_formtuple(tuple_desc, values, nulls);
431
432      /* make the tuple into a datum */
433      result = HeapTupleGetDatum(tuple);
434
435      /* clean up (this is not really necessary) */
436      pfree(values);
437      pfree(nulls);
438
439      SRF_RETURN_NEXT(funcctx, result);
440    }
441  else    /* do when there is no more left */
442    {
443      SRF_RETURN_DONE(funcctx);
444    }
445}
446
Note: See TracBrowser for help on using the browser.