root/tags/release-1.0-beta/dijkstra.c

Revision 44, 11.9 KB (checked in by anton, 3 years ago)

PG_MODULE_MAGIC check fixed

Line 
1/*
2 * Shortest path algorithm for PostgreSQL
3 *
4 * Copyright (c) 2005 Sylvain Pasche
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
19 *
20 */
21
22#include "postgres.h"
23#include "executor/spi.h"
24#include "funcapi.h"
25
26#include "fmgr.h"
27
28#include "dijkstra.h"
29
30Datum shortest_path(PG_FUNCTION_ARGS);
31
32#undef DEBUG
33//#define DEBUG 1
34
35#ifdef DEBUG
36#define DBG(format, arg...)                     \
37    elog(NOTICE, format , ## arg)
38#else
39#define DBG(format, arg...) do { ; } while (0)
40#endif
41
42// The number of tuples to fetch from the SPI cursor at each iteration
43#define TUPLIMIT 1000
44
45#ifdef PG_MODULE_MAGIC
46PG_MODULE_MAGIC;
47#endif
48
49static char *
50text2char(text *in)
51{
52  char *out = palloc(VARSIZE(in));
53
54  memcpy(out, VARDATA(in), VARSIZE(in) - VARHDRSZ);
55  out[VARSIZE(in) - VARHDRSZ] = '\0';
56  return out;
57}
58
59static int
60finish(int code, int ret)
61{
62  code = SPI_finish();
63  if (code  != SPI_OK_FINISH )
64  {
65    elog(ERROR,"couldn't disconnect from SPI");
66    return -1 ;
67  }                     
68  return ret;
69}
70                         
/*
 * SPI column numbers (from SPI_fnumber) of the fields read from each row of
 * the edge query.  The loader initialises every member to -1.  id == -1
 * means "columns not resolved yet"; reverse_cost == -1 means "reverse_cost
 * is not read".
 */
typedef struct edge_columns
{
  int id;            /* edge identifier column */
  int source;        /* source vertex column */
  int target;        /* target vertex column */
  int cost;          /* forward cost column */
  int reverse_cost;  /* reverse cost column, or -1 when unused */
} edge_columns_t;
79
80static int
81fetch_edge_columns(SPITupleTable *tuptable, edge_columns_t *edge_columns,
82                   bool has_reverse_cost)
83{
84  edge_columns->id = SPI_fnumber(SPI_tuptable->tupdesc, "id");
85  edge_columns->source = SPI_fnumber(SPI_tuptable->tupdesc, "source");
86  edge_columns->target = SPI_fnumber(SPI_tuptable->tupdesc, "target");
87  edge_columns->cost = SPI_fnumber(SPI_tuptable->tupdesc, "cost");
88  if (edge_columns->id == SPI_ERROR_NOATTRIBUTE ||
89      edge_columns->source == SPI_ERROR_NOATTRIBUTE ||
90      edge_columns->target == SPI_ERROR_NOATTRIBUTE ||
91      edge_columns->cost == SPI_ERROR_NOATTRIBUTE)
92    {
93      elog(ERROR, "Error, query must return columns "
94           "'id', 'source', 'target' and 'cost'");
95      return -1;
96    }
97
98  if (SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->source) != INT4OID ||
99      SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->target) != INT4OID ||
100      SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->cost) != FLOAT8OID)
101    {
102      elog(ERROR, "Error, columns 'source', 'target' must be of type int4, 'cost' must be of type float8");
103      return -1;
104    }
105
106  DBG("columns: id %i source %i target %i cost %i",
107      edge_columns->id, edge_columns->source,
108      edge_columns->target, edge_columns->cost);
109
110  if (has_reverse_cost)
111    {
112      edge_columns->reverse_cost = SPI_fnumber(SPI_tuptable->tupdesc,
113                                               "reverse_cost");
114
115      if (edge_columns->reverse_cost == SPI_ERROR_NOATTRIBUTE)
116        {
117          elog(ERROR, "Error, reverse_cost is used, but query did't return "
118               "'reverse_cost' column");
119          return -1;
120        }
121
122      if (SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->reverse_cost)
123          != FLOAT8OID)
124        {
125          elog(ERROR, "Error, columns 'reverse_cost' must be of type float8");
126          return -1;
127        }
128
129      DBG("columns: reverse_cost cost %i", edge_columns->reverse_cost);
130    }
131   
132  return 0;
133}
134
135static void
136fetch_edge(HeapTuple *tuple, TupleDesc *tupdesc,
137           edge_columns_t *edge_columns, edge_t *target_edge)
138{
139  Datum binval;
140  bool isnull;
141
142  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->id, &isnull);
143  if (isnull)
144    elog(ERROR, "id contains a null value");
145  target_edge->id = DatumGetInt32(binval);
146
147  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->source, &isnull);
148  if (isnull)
149    elog(ERROR, "source contains a null value");
150  target_edge->source = DatumGetInt32(binval);
151
152  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->target, &isnull);
153  if (isnull)
154    elog(ERROR, "target contains a null value");
155  target_edge->target = DatumGetInt32(binval);
156
157  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->cost, &isnull);
158  if (isnull)
159    elog(ERROR, "cost contains a null value");
160  target_edge->cost = DatumGetFloat8(binval);
161
162  if (edge_columns->reverse_cost != -1)
163    {
164      binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->reverse_cost,
165                             &isnull);
166      if (isnull)
167        elog(ERROR, "reverse_cost contains a null value");
168      target_edge->reverse_cost =  DatumGetFloat8(binval);
169    }
170}
171
172
173static int compute_shortest_path(char* sql, int start_vertex,
174                                 int end_vertex, bool directed,
175                                 bool has_reverse_cost,
176                                 path_element_t **path, int *path_count)
177{
178
179  int SPIcode;
180  void *SPIplan;
181  Portal SPIportal;
182  bool moredata = TRUE;
183  int ntuples;
184  edge_t *edges = NULL;
185  int total_tuples = 0;
186  edge_columns_t edge_columns = {id: -1, source: -1, target: -1,
187                                 cost: -1, reverse_cost: -1};
188  int v_max_id=0;
189  int v_min_id=INT_MAX;
190
191  int s_count = 0;
192  int t_count = 0;
193
194  char *err_msg;
195  int ret = -1;
196  register int z;
197
198  DBG("start shortest_path\n");
199       
200  SPIcode = SPI_connect();
201  if (SPIcode  != SPI_OK_CONNECT)
202    {
203      elog(ERROR, "shortest_path: couldn't open a connection to SPI");
204      return -1;
205    }
206
207  SPIplan = SPI_prepare(sql, 0, NULL);
208  if (SPIplan  == NULL)
209    {
210      elog(ERROR, "shortest_path: couldn't create query plan via SPI");
211      return -1;
212    }
213
214  if ((SPIportal = SPI_cursor_open(NULL, SPIplan, NULL, NULL, true)) == NULL)
215    {
216      elog(ERROR, "shortest_path: SPI_cursor_open('%s') returns NULL", sql);
217      return -1;
218    }
219
220  while (moredata == TRUE)
221    {
222      SPI_cursor_fetch(SPIportal, TRUE, TUPLIMIT);
223
224      if (edge_columns.id == -1)
225        {
226          if (fetch_edge_columns(SPI_tuptable, &edge_columns,
227                                 has_reverse_cost) == -1)
228            return finish(SPIcode, ret);
229        }
230
231      ntuples = SPI_processed;
232      total_tuples += ntuples;
233      if (!edges)
234        edges = palloc(total_tuples * sizeof(edge_t));
235      else
236        edges = repalloc(edges, total_tuples * sizeof(edge_t));
237
238      if (edges == NULL)
239        {
240          elog(ERROR, "Out of memory");
241            return finish(SPIcode, ret);         
242        }
243
244      if (ntuples > 0)
245        {
246          int t;
247          SPITupleTable *tuptable = SPI_tuptable;
248          TupleDesc tupdesc = SPI_tuptable->tupdesc;
249               
250          for (t = 0; t < ntuples; t++)
251            {
252              HeapTuple tuple = tuptable->vals[t];
253              fetch_edge(&tuple, &tupdesc, &edge_columns,
254                         &edges[total_tuples - ntuples + t]);
255            }
256          SPI_freetuptable(tuptable);
257        }
258      else
259        {
260          moredata = FALSE;
261        }
262    }
263
264  //defining min and max vertex id
265     
266  DBG("Total %i tuples", total_tuples);
267   
268  for(z=0; z<total_tuples; z++)
269  {
270    if(edges[z].source<v_min_id)
271    v_min_id=edges[z].source;
272 
273    if(edges[z].source>v_max_id)
274      v_max_id=edges[z].source;
275                           
276    if(edges[z].target<v_min_id)
277      v_min_id=edges[z].target;
278
279    if(edges[z].target>v_max_id)
280      v_max_id=edges[z].target;     
281                                                                       
282    DBG("%i <-> %i", v_min_id, v_max_id);
283                                                       
284  }
285
286  //:::::::::::::::::::::::::::::::::::: 
287  //:: reducing vertex id (renumbering)
288  //::::::::::::::::::::::::::::::::::::
289  for(z=0; z<total_tuples; z++)
290  {
291    //check if edges[] contains source and target
292    if(edges[z].source == start_vertex || edges[z].target == start_vertex)
293      ++s_count;
294    if(edges[z].source == end_vertex || edges[z].target == end_vertex)
295      ++t_count;
296
297    edges[z].source-=v_min_id;
298    edges[z].target-=v_min_id;
299    DBG("%i - %i", edges[z].source, edges[z].target);     
300  }
301
302  DBG("Total %i tuples", total_tuples);
303
304  if(s_count == 0)
305  {
306    elog(ERROR, "Start vertex was not found.");
307    return -1;
308  }
309     
310  if(t_count == 0)
311  {
312    elog(ERROR, "Target vertex was not found.");
313    return -1;
314  }
315 
316  DBG("Calling boost_dijkstra\n");
317       
318  start_vertex -= v_min_id;
319  end_vertex   -= v_min_id;
320
321  ret = boost_dijkstra(edges, total_tuples, start_vertex, end_vertex,
322                       directed, has_reverse_cost,
323                       path, path_count, &err_msg);
324
325  DBG("SIZE %i\n",*path_count);
326
327  //::::::::::::::::::::::::::::::::
328  //:: restoring original vertex id
329  //::::::::::::::::::::::::::::::::
330  for(z=0;z<*path_count;z++)
331  {
332    //DBG("vetex %i\n",(*path)[z].vertex_id);
333    (*path)[z].vertex_id+=v_min_id;
334  }
335
336  DBG("ret = %i\n", ret);
337
338  DBG("*path_count = %i\n", *path_count);
339
340  DBG("ret = %i\n", ret);
341 
342  if (ret < 0)
343    {
344      //elog(ERROR, "Error computing path: %s", err_msg);
345      ereport(ERROR, (errcode(ERRCODE_E_R_E_CONTAINING_SQL_NOT_PERMITTED),
346        errmsg("Error computing path: %s", err_msg)));
347    }
348   
349  return finish(SPIcode, ret);
350}
351
352
353PG_FUNCTION_INFO_V1(shortest_path);
354Datum
355shortest_path(PG_FUNCTION_ARGS)
356{
357  FuncCallContext     *funcctx;
358  int                  call_cntr;
359  int                  max_calls;
360  TupleDesc            tuple_desc;
361  path_element_t      *path;
362
363  /* stuff done only on the first call of the function */
364  if (SRF_IS_FIRSTCALL())
365    {
366      MemoryContext   oldcontext;
367      int path_count = 0;
368      int ret;
369
370      /* create a function context for cross-call persistence */
371      funcctx = SRF_FIRSTCALL_INIT();
372
373      /* switch to memory context appropriate for multiple function calls */
374      oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
375
376
377      ret = compute_shortest_path(text2char(PG_GETARG_TEXT_P(0)),
378                                  PG_GETARG_INT32(1),
379                                  PG_GETARG_INT32(2),
380                                  PG_GETARG_BOOL(3),
381                                  PG_GETARG_BOOL(4), &path, &path_count);
382#ifdef DEBUG
383      DBG("Ret is %i", ret);
384      if (ret >= 0)
385        {
386          int i;
387          for (i = 0; i < path_count; i++)
388            {
389              DBG("Step %i vertex_id  %i ", i, path[i].vertex_id);
390              DBG("        edge_id    %i ", path[i].edge_id);
391              DBG("        cost       %f ", path[i].cost);
392            }
393        }
394#endif
395
396      /* total number of tuples to be returned */
397      funcctx->max_calls = path_count;
398      funcctx->user_fctx = path;
399
400      funcctx->tuple_desc =
401        BlessTupleDesc(RelationNameGetTupleDesc("path_result"));
402
403      MemoryContextSwitchTo(oldcontext);
404    }
405
406  /* stuff done on every call of the function */
407  funcctx = SRF_PERCALL_SETUP();
408
409  call_cntr = funcctx->call_cntr;
410  max_calls = funcctx->max_calls;
411  tuple_desc = funcctx->tuple_desc;
412  path = (path_element_t*) funcctx->user_fctx;
413
414  if (call_cntr < max_calls)    /* do when there is more left to send */
415    {
416      HeapTuple    tuple;
417      Datum        result;
418      Datum *values;
419      char* nulls;
420
421      values = palloc(3 * sizeof(Datum));
422      nulls = palloc(3 * sizeof(char));
423
424
425      values[0] = Int32GetDatum(path[call_cntr].vertex_id);
426      nulls[0] = ' ';
427      values[1] = Int32GetDatum(path[call_cntr].edge_id);
428      nulls[1] = ' ';
429      values[2] = Float8GetDatum(path[call_cntr].cost);
430      nulls[2] = ' ';
431
432      tuple = heap_formtuple(tuple_desc, values, nulls);
433
434      /* make the tuple into a datum */
435      result = HeapTupleGetDatum(tuple);
436
437      /* clean up (this is not really necessary) */
438      pfree(values);
439      pfree(nulls);
440
441      SRF_RETURN_NEXT(funcctx, result);
442    }
443  else    /* do when there is no more left */
444    {
445      SRF_RETURN_DONE(funcctx);
446    }
447}
448
Note: See TracBrowser for help on using the browser.